Gautamgiri commited on
Commit
cfae6d1
·
verified ·
1 Parent(s): c7eaf3b

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +208 -0
  3. requirements.txt +9 -0
  4. trained_model.keras +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ trained_model.keras filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ from collections import deque
6
+
7
+ # -----------------------------
8
+ # Load Model (Global)
9
+ # -----------------------------
10
+ print("Loading model...")
11
+ model = tf.keras.models.load_model("trained_model.keras")
12
+ print("Model loaded.")
13
+
14
+ # -----------------------------
15
+ # Class Labels
16
+ # -----------------------------
17
+ class_names = [
18
+ 'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
19
+ 'Blueberry___healthy',
20
+ 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
21
+ 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_',
22
+ 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy',
23
+ 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
24
+ 'Orange___Haunglongbing_(Citrus_greening)',
25
+ 'Peach___Bacterial_spot', 'Peach___healthy',
26
+ 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy',
27
+ 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy',
28
+ 'Raspberry___healthy',
29
+ 'Soybean___healthy',
30
+ 'Squash___Powdery_mildew',
31
+ 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
32
+ 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
33
+ 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot',
34
+ 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot',
35
+ 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy'
36
+ ]
37
+
38
+ # -----------------------------
39
+ # Global state for stabilization (for streaming only)
40
+ # -----------------------------
41
+ history = deque(maxlen=5)
42
+
43
+ # -----------------------------
44
+ # Preprocessing Function (CRITICAL FIX)
45
+ # -----------------------------
46
+ def preprocess_frame(frame):
47
+ """
48
+ Handles any input frame (RGB, RGBA, Grayscale) from Gradio and
49
+ converts it to the exact (64, 64) BGR format your model was trained on.
50
+ """
51
+
52
+ # --- 1. Robustly convert to 3-channel RGB ---
53
+ if len(frame.shape) == 2:
54
+ # It's Grayscale (H, W)
55
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
56
+ elif frame.shape[2] == 1:
57
+ # It's Grayscale (H, W, 1)
58
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
59
+ elif frame.shape[2] == 4:
60
+ # It's RGBA (H, W, 4)
61
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
62
+ else:
63
+ # It's already 3-channel RGB (H, W, 3)
64
+ frame_rgb = frame
65
+
66
+ # --- 2. Convert from RGB to BGR (as model expects) ---
67
+ # Your original script used cv2.VideoCapture, which provides BGR frames.
68
+ frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
69
+
70
+ # --- 3. Resize to model's input size (64, 64) ---
71
+ # Your original script used (64, 64).
72
+ img_resized = cv2.resize(frame_bgr, (64, 64))
73
+
74
+ # --- 4. Normalize and add batch dimension ---
75
+ img_normalized = img_resized / 255.0
76
+ img_normalized = img_normalized.astype(np.float32)
77
+ img_batch = np.expand_dims(img_normalized, axis=0)
78
+
79
+ return img_batch
80
+
81
+ # -----------------------------
82
+ # Prediction Function for STREAMING (Webcam)
83
+ # -----------------------------
84
+ def predict_stream(frame):
85
+ """
86
+ Takes a single RGB frame, predicts, and returns an annotated RGB frame.
87
+ Uses 'history' deque for stabilization.
88
+ """
89
+ if frame is None:
90
+ return None
91
+
92
+ # 1. Preprocess and predict
93
+ # (preprocess_frame now handles RGB -> BGR conversion)
94
+ preprocessed_img = preprocess_frame(frame)
95
+ prediction = model.predict(preprocessed_img, verbose=0)
96
+ predicted_class = np.argmax(prediction)
97
+ history.append(predicted_class)
98
+
99
+ # 2. Stabilize prediction
100
+ if len(history) > 0:
101
+ label_index = max(set(history), key=history.count)
102
+ label = class_names[label_index]
103
+ else:
104
+ label = "Initializing..."
105
+
106
+ # 3. Return the original frame and the label text
107
+ # (All cv2.putText and color conversion logic removed)
108
+ return frame, label
109
+
110
+ # -----------------------------
111
+ # Prediction Function for UPLOAD (Single Image)
112
+ # -----------------------------
113
+ def predict_upload(frame):
114
+ """
115
+ Takes a single RGB frame, predicts, and returns an annotated RGB frame.
116
+ Does NOT use 'history' deque.
117
+ """
118
+ if frame is None:
119
+ return None
120
+
121
+ # 1. Preprocess and predict
122
+ # (preprocess_frame now handles RGB -> BGR conversion)
123
+ preprocessed_img = preprocess_frame(frame)
124
+ prediction = model.predict(preprocessed_img, verbose=0)
125
+ predicted_class = np.argmax(prediction)
126
+
127
+ # 2. Get label (no stabilization needed)
128
+ label = class_names[predicted_class]
129
+
130
+ # 3. Robustly convert original frame to RGB for display
131
+ # (This logic is still needed so the output image displays correctly)
132
+ if len(frame.shape) == 2:
133
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
134
+ elif frame.shape[2] == 1:
135
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
136
+ elif frame.shape[2] == 4:
137
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
138
+ else:
139
+ frame_rgb = frame.copy()
140
+
141
+ # 4. Return the RGB frame and the label text
142
+ # (All cv2.putText and color conversion logic removed)
143
+ return frame_rgb, label
144
+
145
+ # -----------------------------
146
+ # Gradio Interface (with Tabs)
147
+ # -----------------------------
148
+
149
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
150
+ gr.Markdown(
151
+ """
152
+ # 🌱 Real-Time Plant Disease Detection
153
+ This app uses a trained CNN to detect plant diseases.
154
+ Use the tabs below to either start a live webcam feed or upload an image.
155
+ """
156
+ )
157
+
158
+ with gr.Tabs():
159
+ # --- Tab 1: Live Detection ---
160
+ with gr.TabItem("Live Detection"):
161
+ with gr.Row():
162
+ webcam_input = gr.Image(
163
+ sources=["webcam"],
164
+ streaming=True,
165
+ label="Webcam Feed"
166
+ )
167
+ webcam_output = gr.Image(label="Prediction")
168
+
169
+ # --- NEW: Add a Label output for the prediction ---
170
+ webcam_label = gr.Label(label="Result")
171
+
172
+ webcam_input.stream(
173
+ predict_stream,
174
+ webcam_input,
175
+ [webcam_output, webcam_label] # --- UPDATED: Output to both components ---
176
+ )
177
+
178
+ # --- Tab 2: Upload Image ---
179
+ with gr.TabItem("Upload Image"):
180
+ with gr.Row():
181
+ upload_input = gr.Image(
182
+ sources=["upload"],
183
+ label="Upload a plant image",
184
+ type="numpy"
185
+ )
186
+ upload_output = gr.Image(label="Prediction")
187
+
188
+ # --- NEW: Add a Label output for the prediction ---
189
+ upload_label = gr.Label(label="Result")
190
+
191
+ upload_input.upload(
192
+ predict_upload,
193
+ upload_input,
194
+ [upload_output, upload_label] # --- UPDATED: Output to both components ---
195
+ )
196
+
197
+ # --- Accordions for extra info ---
198
+ with gr.Accordion("About this App"):
199
+ gr.Markdown("This project uses a TensorFlow/Keras CNN model to classify 38 different plant disease categories in real-time.")
200
+
201
+ with gr.Accordion("Show all 38 classes"):
202
+ gr.JSON(class_names)
203
+
204
+ # -----------------------------
205
+ # Launch the App
206
+ # -----------------------------
207
+ if __name__ == "__main__":
208
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ scikit-learn
3
+ matplotlib
4
+ seaborn
5
+ pandas
6
+ librosa
7
+ tensorflow
8
+ opencv-python-headless
9
+ numpy
trained_model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b2762962ee4a591eef3013445a26c7d04d27a28620ef05908f1122d80bdc13c
3
+ size 10201218