sidd-harth011 commited on
Commit
237c700
·
1 Parent(s): f25f10c
Files changed (3) hide show
  1. .gitignore +1 -1
  2. app.py +50 -8
  3. class_indices.json +1 -0
.gitignore CHANGED
@@ -1 +1 @@
1
- class_indices.json
 
1
+ plant_disease_model.tflite
app.py CHANGED
@@ -1,14 +1,56 @@
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def say_hello(name):
4
- return f"Hello {name}, the model is ready on Hugging Face Spaces!"
 
5
 
6
- demo = gr.Interface(
7
- fn=say_hello,
8
- inputs="text",
9
- outputs="text",
10
- title="Test Space"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  )
12
 
13
  if __name__ == "__main__":
14
- demo.launch()
 
1
+ import os
2
+ import json
3
+ from PIL import Image
4
+ import numpy as np
5
+ import tensorflow as tf
6
  import gradio as gr
7
 
8
+ # Paths
9
+ working_dir = os.path.dirname(os.path.abspath(__file__))
10
+ model_path = f"{working_dir}/plant_disease_model.tflite"
11
 
12
+ # Load TFLite model
13
+ interpreter = tf.lite.Interpreter(model_path=model_path)
14
+ interpreter.allocate_tensors()
15
+
16
+ # Get input and output details
17
+ input_details = interpreter.get_input_details()
18
+ output_details = interpreter.get_output_details()
19
+
20
+ # Load class indices
21
+ class_indices = json.load(open(f"{working_dir}/class_indices.json"))
22
+
23
+ # Function to preprocess the image
24
+ def load_and_preprocess_image(image, target_size=(224, 224)):
25
+ img = image.resize(target_size)
26
+ img_array = np.array(img, dtype=np.float32)
27
+ img_array = np.expand_dims(img_array, axis=0)
28
+ img_array = img_array / 255.0
29
+ return img_array
30
+
31
+ # Prediction function
32
+ def predict_image_class(image):
33
+ preprocessed_img = load_and_preprocess_image(image)
34
+
35
+ # Set input tensor
36
+ interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
37
+ interpreter.invoke()
38
+
39
+ # Get predictions
40
+ predictions = interpreter.get_tensor(output_details[0]['index'])
41
+ predicted_class_index = np.argmax(predictions, axis=1)[0]
42
+ predicted_class_name = class_indices[str(predicted_class_index)]
43
+
44
+ return f"Prediction: {predicted_class_name}"
45
+
46
+ # Gradio Interface
47
+ interface = gr.Interface(
48
+ fn=predict_image_class,
49
+ inputs=gr.Image(type="pil", label="Upload an Image"),
50
+ outputs=gr.Textbox(label="Prediction"),
51
+ title="🌱 Plant Disease Classifier (TFLite)",
52
+ description="Upload a plant leaf image to classify its disease using a compressed TFLite model."
53
  )
54
 
55
  if __name__ == "__main__":
56
+ interface.launch()
class_indices.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "Apple___Apple_scab", "1": "Apple___Black_rot", "2": "Apple___Cedar_apple_rust", "3": "Apple___healthy", "4": "Blueberry___healthy", "5": "Cherry_(including_sour)___Powdery_mildew", "6": "Cherry_(including_sour)___healthy", "7": "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot", "8": "Corn_(maize)___Common_rust_", "9": "Corn_(maize)___Northern_Leaf_Blight", "10": "Corn_(maize)___healthy", "11": "Grape___Black_rot", "12": "Grape___Esca_(Black_Measles)", "13": "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)", "14": "Grape___healthy", "15": "Orange___Haunglongbing_(Citrus_greening)", "16": "Peach___Bacterial_spot", "17": "Peach___healthy", "18": "Pepper,_bell___Bacterial_spot", "19": "Pepper,_bell___healthy", "20": "Potato___Early_blight", "21": "Potato___Late_blight", "22": "Potato___healthy", "23": "Raspberry___healthy", "24": "Soybean___healthy", "25": "Squash___Powdery_mildew", "26": "Strawberry___Leaf_scorch", "27": "Strawberry___healthy", "28": "Tomato___Bacterial_spot", "29": "Tomato___Early_blight", "30": "Tomato___Late_blight", "31": "Tomato___Leaf_Mold", "32": "Tomato___Septoria_leaf_spot", "33": "Tomato___Spider_mites Two-spotted_spider_mite", "34": "Tomato___Target_Spot", "35": "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "36": "Tomato___Tomato_mosaic_virus", "37": "Tomato___healthy"}