Saketh12345 commited on
Commit
f742c10
·
1 Parent(s): 52aec6b

Initial commit for Hugging Face Space deployment with ResNet9 model

Browse files
Files changed (11) hide show
  1. .gitattributes +3 -0
  2. .gitignore +33 -3
  3. Dockerfile +2 -2
  4. LICENSE +21 -0
  5. README.md +4 -4
  6. app.py +180 -326
  7. app_resnet9.py +209 -0
  8. class_indices.json +1 -1
  9. plant_disease_model.pth +3 -0
  10. requirements.txt +41 -8
  11. space.yml +10 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
2
+ plant_disease_model.pth filter=lfs diff=lfs merge=lfs -text
3
+ examples/*.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
  *.py[cod]
@@ -44,9 +48,35 @@ coverage.xml
44
  .hypothesis/
45
  .pytest_cache/
46
 
47
- # Translations
48
- *.mo
49
- *.pot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Django stuff:
52
  *.log
 
1
+ # Dataset and large files
2
+ dataset/
3
+ kaggle_plant_disease/
4
+
5
  # Byte-compiled / optimized / DLL files
6
  __pycache__/
7
  *.py[cod]
 
48
  .hypothesis/
49
  .pytest_cache/
50
 
51
+ # Virtual environment
52
+ venv/
53
+
54
+ # IDE
55
+ .vscode/
56
+ .idea/
57
+ *.swp
58
+ *.swo
59
+
60
+ # OS generated files
61
+ .DS_Store
62
+ .DS_Store?
63
+ ._*
64
+ .Spotlight-V100
65
+ .Trashes
66
+ ehthumbs.db
67
+ Thumbs.db
68
+
69
+ # Local development
70
+ .env
71
+ *.pth
72
+ *.pth.*
73
+
74
+ # Training scripts and logs
75
+ resnet9_train.py
76
+ test_model.py
77
+
78
+ # Test images
79
+ test_image.JPG
80
 
81
  # Django stuff:
82
  *.log
Dockerfile CHANGED
@@ -23,5 +23,5 @@ ENV MPLCONFIGDIR=/tmp/matplotlib
23
  RUN mkdir -p /tmp/matplotlib && \
24
  chmod -R 777 /tmp/matplotlib
25
 
26
- # Run app.py when the container launches
27
- CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
 
23
  RUN mkdir -p /tmp/matplotlib && \
24
  chmod -R 777 /tmp/matplotlib
25
 
26
+ # Run app_resnet9.py when the container launches
27
+ CMD ["streamlit", "run", "app_resnet9.py", "--server.port=8501", "--server.address=0.0.0.0"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Saketh Jangala
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -11,11 +11,11 @@ pinned: false
11
 
12
  # 🌿 Plant Disease Detection
13
 
14
- A deep learning-based web application that identifies plant diseases from leaf images using EfficientNetB0.
15
 
16
  ## 🚀 Features
17
 
18
- - 🌱 Identify 38 different plant diseases
19
  - 📊 Interactive prediction visualization
20
  - 📱 Mobile-responsive design
21
  - ⚡ Fast inference with model caching
@@ -26,7 +26,7 @@ A deep learning-based web application that identifies plant diseases from leaf i
26
 
27
  1. Clone the repository:
28
  ```bash
29
- git clone https://huggingface.co/spaces/your-username/plant-disease-detection
30
  cd plant-disease-detection
31
  ```
32
 
@@ -42,7 +42,7 @@ A deep learning-based web application that identifies plant diseases from leaf i
42
 
43
  ## 🌐 Deployment
44
 
45
- This app is deployed on Hugging Face Spaces. You can access it [here](https://huggingface.co/spaces/your-username/plant-disease-detection).
46
 
47
  ## 📝 Note
48
 
 
11
 
12
  # 🌿 Plant Disease Detection
13
 
14
+ A deep learning-based web application that identifies plant diseases from leaf images using a ResNet9 model, providing fast and accurate predictions.
15
 
16
  ## 🚀 Features
17
 
18
+ - 🌱 Identify 38 different plant diseases using a ResNet9 model
19
  - 📊 Interactive prediction visualization
20
  - 📱 Mobile-responsive design
21
  - ⚡ Fast inference with model caching
 
26
 
27
  1. Clone the repository:
28
  ```bash
29
+ git clone https://huggingface.co/spaces/saketh-005/plant-disease-detection
30
  cd plant-disease-detection
31
  ```
32
 
 
42
 
43
  ## 🌐 Deployment
44
 
45
+ This app is deployed on Hugging Face Spaces. You can access it [here](https://huggingface.co/spaces/saketh-005/plant-disease-detection).
46
 
47
  ## 📝 Note
48
 
app.py CHANGED
@@ -1,345 +1,199 @@
1
- import os
2
- from io import BytesIO
3
- # Configure Matplotlib to use a non-interactive backend
4
- import matplotlib
5
- matplotlib.use('Agg') # Use the 'Agg' backend which doesn't require a display
6
- os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
7
-
8
- # Create the directory if it doesn't exist
9
- try:
10
- os.makedirs('/tmp/matplotlib', exist_ok=True)
11
- os.chmod('/tmp/matplotlib', 0o777)
12
- except Exception as e:
13
- pass # Directory creation is not critical
14
-
15
  import streamlit as st
16
- import numpy as np
17
- import pandas as pd
18
- from PIL import Image
19
- import matplotlib.pyplot as plt
20
- import plotly.express as px
21
  import torch
22
- from torchvision import transforms, models
23
  import torch.nn as nn
 
 
24
  import json
25
- from pathlib import Path
26
- import time
27
-
28
- # Set random seed for reproducibility
29
- SEED = 42
30
- np.random.seed(SEED)
31
- torch.manual_seed(SEED)
32
 
33
- # Set device
34
- device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Set page config
37
  st.set_page_config(
38
- page_title="Plant Disease Prediction",
39
- page_icon="🌱",
40
- layout="wide"
 
41
  )
42
 
43
- # Title and description
44
- st.title("🌿 Plant Disease Prediction")
45
- st.write("Upload an image of a plant leaf to detect potential diseases.")
46
-
47
- # Sidebar with information
48
  with st.sidebar:
49
- st.header("About")
50
- st.write("This app uses a deep learning model to detect plant diseases from leaf images.")
51
- st.write("### How to use:")
52
- st.write("1. Upload an image of a plant leaf")
53
- st.write("2. The app will process the image")
54
- st.write("3. View the prediction results")
55
- st.write("\nNote: This is a demo application. For production use, please train on a larger dataset.")
56
-
57
- # Constants
58
- MODEL_PATH = 'plant_disease_model.h5'
59
- IMG_SIZE = 224
60
-
61
- # Class mapping for PlantVillage dataset
62
- CLASS_NAMES = [
63
- 'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
64
- 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
65
- 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_',
66
- 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot',
67
- 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
68
- 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy',
69
- 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight',
70
- 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy',
71
- 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
72
- 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
73
- 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite',
74
- 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
75
- 'Tomato___healthy'
76
- ]
77
-
78
- # Disease information
79
- DISEASE_INFO = {
80
- 'Apple___Apple_scab': {
81
- 'symptoms': 'Olive-green to black, circular spots on leaves that may become raised and velvety.',
82
- 'treatment': 'Apply fungicides in early spring and remove fallen leaves in autumn.'
83
- },
84
- 'Tomato___Early_blight': {
85
- 'symptoms': 'Dark, concentric spots on lower leaves that may develop a target-like appearance.',
86
- 'treatment': 'Use fungicides, remove infected leaves, and ensure good air circulation.'
87
- },
88
- 'default': {
89
- 'symptoms': 'Consult with a plant pathologist for accurate diagnosis.',
90
- 'treatment': 'Isolate the plant and consult with a local agricultural extension service.'
91
- }
92
- }
93
-
94
- # Default class indices (fallback if file not found)
95
- DEFAULT_CLASS_INDICES = {
96
- 'Pepper__bell___Bacterial_spot': 0,
97
- 'Pepper__bell___healthy': 1,
98
- 'Potato___Early_blight': 2,
99
- 'Potato___Late_blight': 3,
100
- 'Potato___healthy': 4,
101
- 'Tomato_Bacterial_spot': 5,
102
- 'Tomato_Early_blight': 6,
103
- 'Tomato_Late_blight': 7,
104
- 'Tomato_Leaf_Mold': 8,
105
- 'Tomato_Septoria_leaf_spot': 9,
106
- 'Tomato_Spider_mites_Two_spotted_spider_mite': 10,
107
- 'Tomato__Target_Spot': 11,
108
- 'Tomato__Tomato_YellowLeaf__Curl_Virus': 12,
109
- 'Tomato__Tomato_mosaic_virus': 13,
110
- 'Tomato_healthy': 14
111
- }
112
-
113
- # Load the model and class indices
114
- @st.cache_resource
115
- def load_model():
116
- try:
117
- # Try to load class indices from file, fall back to default if not found
118
  try:
119
  with open('class_indices.json', 'r') as f:
120
- class_indices = json.load(f)
121
- print("Loaded class indices from file")
122
- except FileNotFoundError:
123
- class_indices = DEFAULT_CLASS_INDICES
124
- print("Using default class indices")
125
-
126
- # Create model
127
- print("Creating model...")
128
- model = models.resnet18(weights=None) # We'll load our own weights
129
- num_ftrs = model.fc.in_features
130
- model.fc = nn.Linear(num_ftrs, len(class_indices))
131
-
132
- # Load weights if available
133
- try:
134
- checkpoint = torch.load('plant_disease_model.pth', map_location=device)
135
- model.load_state_dict(checkpoint['model_state_dict'])
136
- print("Loaded model weights")
137
- except FileNotFoundError:
138
- print("Warning: Model weights not found. Using random weights.")
139
-
140
- model = model.to(device)
141
- model.eval()
142
-
143
- # Reverse the dictionary to get index to class name mapping
144
- idx_to_class = {v: k for k, v in class_indices.items()}
145
- print(f"Model loaded with {len(idx_to_class)} classes")
146
- return model, idx_to_class
147
- except Exception as e:
148
- import traceback
149
- error_details = traceback.format_exc()
150
- print(f"Error loading model: {error_details}")
151
- st.error(f"Error loading model: {e}")
152
- return None, None
153
 
154
- model, idx_to_class = load_model()
 
 
155
 
156
- def preprocess_image(img):
157
- """Preprocess the image for prediction"""
158
- transform = transforms.Compose([
159
- transforms.Resize((224, 224)),
160
- transforms.ToTensor(),
161
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
162
- ])
163
- return transform(img).unsqueeze(0).to(device) # Add batch dimension and move to device
164
 
165
- def predict_disease(image):
166
- if model is None or idx_to_class is None:
167
- return "Model not loaded", 0.0
168
-
169
  try:
170
- # Preprocess the image
171
- input_tensor = preprocess_image(image)
 
172
 
173
  # Make prediction
174
- with torch.no_grad():
175
- output = model(input_tensor)
176
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
177
- confidence, predicted_idx = torch.max(probabilities, 0)
178
- predicted_class = idx_to_class.get(str(predicted_idx.item()), "Unknown")
179
-
180
- return predicted_class, confidence.item()
181
- except Exception as e:
182
- st.error(f"Error during prediction: {e}")
183
- return "Prediction error", 0.0
184
-
185
- def random_predictions():
186
- """Generate random predictions for demo purposes."""
187
- np.random.seed(42)
188
- indices = np.random.choice(len(CLASS_NAMES), 3, replace=False)
189
- scores = np.random.dirichlet(np.ones(3), size=1)[0]
190
- return list(zip([CLASS_NAMES[i] for i in indices], scores)), 0.1
191
-
192
- def display_disease_info(disease_name):
193
- """Display detailed information about the predicted disease."""
194
- # Get disease information or use default
195
- info = DISEASE_INFO.get(disease_name, DISEASE_INFO['default'])
196
-
197
- st.markdown("### Disease Information")
198
- st.markdown(f"**{disease_name.replace('_', ' ').title()}**")
199
-
200
- # Display symptoms and treatment
201
- with st.expander("ℹ️ Symptoms"):
202
- st.write(info['symptoms'])
203
-
204
- with st.expander("💊 Treatment"):
205
- st.write(info['treatment'])
206
-
207
- # Add prevention tips
208
- with st.expander("🛡️ Prevention Tips"):
209
- st.write("""
210
- - Ensure proper spacing between plants for good air circulation
211
- - Water at the base of plants to keep foliage dry
212
- - Rotate crops to prevent disease buildup in soil
213
- - Remove and destroy infected plant material
214
- - Use disease-resistant varieties when available
215
- """)
216
-
217
- def main():
218
- # Load the model (cached for performance)
219
- model, idx_to_class = load_model()
220
-
221
- # Check if model loaded successfully
222
- if model is None or idx_to_class is None:
223
- st.error("⚠️ Failed to load the model. Some features may not work correctly.")
224
- st.info("Please check if the model files are present in the correct location.")
225
-
226
- # Sidebar options
227
- with st.sidebar.expander("⚙️ Settings"):
228
- confidence_threshold = st.slider(
229
- "Minimum Confidence Threshold",
230
- min_value=0.1,
231
- max_value=0.9,
232
- value=0.5,
233
- step=0.1,
234
- help="Adjust the minimum confidence level for predictions"
235
- )
236
-
237
- # Main content
238
- st.markdown("## 🌱 Plant Disease Detection")
239
- st.markdown("Upload an image of a plant leaf to detect potential diseases.")
240
-
241
- # File uploader
242
- st.subheader("Upload an Image")
243
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
244
-
245
- # If no file is uploaded or selected, show a message
246
- if uploaded_file is None and 'uploaded_file' not in st.session_state:
247
- st.info("👆 Upload an image to get started")
248
 
249
- if uploaded_file is not None:
250
- try:
251
- # Reset the file pointer to the beginning
252
- if hasattr(uploaded_file, 'seek'):
253
- uploaded_file.seek(0)
254
-
255
- # Display the uploaded image
256
- image = Image.open(uploaded_file)
257
- st.image(image, caption="Uploaded Image", use_column_width=True)
258
-
259
- with st.spinner("🔍 Analyzing the image..."):
260
- # Preprocess the image
261
- processed_image = preprocess_image(image)
262
-
263
- # Get predictions
264
- predictions, inference_time = predict_with_model(model, processed_image, top_k=3)
265
-
266
- # Filter predictions by confidence threshold
267
- filtered_predictions = [(d, s) for d, s in predictions if s >= confidence_threshold]
268
-
269
- if not filtered_predictions:
270
- st.warning("No predictions met the confidence threshold. Try adjusting the threshold or upload a clearer image.")
271
- else:
272
- # Display results in two columns
273
- col1, col2 = st.columns([1, 1])
274
-
275
- with col1:
276
- st.markdown("### 📊 Prediction Results")
277
-
278
- # Create a DataFrame for the predictions
279
- df = pd.DataFrame({
280
- 'Disease': [p[0].replace('_', ' ').title() for p in filtered_predictions],
281
- 'Confidence': [p[1] for p in filtered_predictions]
282
- })
283
-
284
- # Display as a bar chart
285
- fig = px.bar(
286
- df,
287
- x='Confidence',
288
- y='Disease',
289
- orientation='h',
290
- title='Prediction Confidence',
291
- labels={'Confidence': 'Confidence Score', 'Disease': 'Disease'},
292
- color='Confidence',
293
- color_continuous_scale='Viridis'
294
- )
295
- fig.update_layout(showlegend=False)
296
- st.plotly_chart(fig, use_container_width=True)
297
-
298
- # Display inference time
299
- st.caption(f"⚡ Inference time: {inference_time*1000:.1f}ms")
300
-
301
- with col2:
302
- # Display detailed information about the top prediction
303
- top_disease = filtered_predictions[0][0]
304
- display_disease_info(top_disease)
305
-
306
- # Add a section for user feedback
307
- st.markdown("---")
308
- with st.expander("📝 Provide Feedback"):
309
- st.write("Help us improve our model!")
310
- feedback = st.radio(
311
- "Was this prediction accurate?",
312
- ("Yes", "Partially", "No")
313
- )
314
- if st.button("Submit Feedback"):
315
- # In a real app, you would save this feedback
316
- st.success("Thank you for your feedback!")
317
-
318
- # Add some space at the bottom
319
- st.markdown("---")
320
- st.markdown("""
321
- ### ℹ️ About This Tool
322
- This application uses deep learning to identify plant diseases from leaf images.
323
- The model has been trained on the [PlantVillage dataset](https://plantvillage.psu.edu/)
324
- and can detect various plant diseases.
325
-
326
- **Note:** This is a demonstration application. For real-world use, consult with
327
- agricultural experts and use laboratory testing for accurate disease diagnosis.
328
- """)
329
-
330
- except Exception as e:
331
- st.error(f"An error occurred: {str(e)}")
332
- st.error("Please try another image or check the console for details.")
333
-
334
- # Add a footer
335
- st.markdown("---")
336
- st.markdown("""
337
- <div style="text-align: center; color: gray;">
338
- <p>🌿 Plant Disease Detection App • Built with Streamlit •
339
- <a href="#" style="color: gray;">Terms of Use</a> •
340
- <a href="#" style="color: gray;">Privacy Policy</a></p>
341
- </div>
342
- """, unsafe_allow_html=True)
343
-
344
- if __name__ == "__main__":
345
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
2
  import torch
 
3
  import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
  import json
7
+ import numpy as np
 
 
 
 
 
 
8
 
9
+ # Define the model architecture (same as used in training)
10
+ class SimpleCNN(nn.Module):
11
+ def __init__(self, num_classes=38):
12
+ super().__init__()
13
+ self.features = nn.Sequential(
14
+ nn.Conv2d(3, 16, kernel_size=3, padding=1),
15
+ nn.BatchNorm2d(16),
16
+ nn.ReLU(inplace=True),
17
+ nn.MaxPool2d(2),
18
+
19
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
20
+ nn.BatchNorm2d(32),
21
+ nn.ReLU(inplace=True),
22
+ nn.MaxPool2d(2),
23
+
24
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
25
+ nn.BatchNorm2d(64),
26
+ nn.ReLU(inplace=True),
27
+ nn.MaxPool2d(2),
28
+
29
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
30
+ nn.BatchNorm2d(128),
31
+ nn.ReLU(inplace=True),
32
+ nn.MaxPool2d(2),
33
+ )
34
+
35
+ self.flattened_size = 128 * 16 * 16
36
+ self.classifier = nn.Sequential(
37
+ nn.Dropout(0.3),
38
+ nn.Linear(self.flattened_size, 512),
39
+ nn.ReLU(inplace=True),
40
+ nn.Dropout(0.3),
41
+ nn.Linear(512, num_classes)
42
+ )
43
+
44
+ def forward(self, x):
45
+ x = self.features(x)
46
+ x = x.view(x.size(0), -1)
47
+ x = self.classifier(x)
48
+ return x
49
+
50
+ # Load class indices
51
+ with open('class_indices.json', 'r') as f:
52
+ class_indices = json.load(f)
53
+ # Convert string keys to integers
54
+ class_indices = {int(k): v for k, v in class_indices.items()}
55
+ idx_to_class = {v: k for k, v in class_indices.items()}
56
+
57
+ # Initialize model
58
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
59
+ model = SimpleCNN(num_classes=len(class_indices)).to(device)
60
+ model.load_state_dict(torch.load('plant_disease_model.pth', map_location=device))
61
+ model.eval()
62
+
63
+ # Image transformations
64
+ image_transforms = transforms.Compose([
65
+ transforms.Resize((256, 256)),
66
+ transforms.ToTensor(),
67
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
68
+ ])
69
+
70
+ def predict(image):
71
+ """Predict the class of an image"""
72
+ # Preprocess
73
+ image = image_transforms(image).unsqueeze(0).to(device)
74
+
75
+ # Predict
76
+ with torch.no_grad():
77
+ outputs = model(image)
78
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
79
+
80
+ # Get top 3 predictions
81
+ top3_prob, top3_catid = torch.topk(probabilities, 3)
82
+ predictions = []
83
+ for i in range(top3_prob.size(0)):
84
+ class_idx = top3_catid[i].item()
85
+ class_name = class_indices.get(class_idx, f"Class {class_idx}")
86
+ # Convert to human-readable format
87
+ class_name = class_name.replace('___', ' ').replace('_', ' ').title()
88
+ predictions.append({
89
+ 'class': class_name,
90
+ 'probability': f"{top3_prob[i].item() * 100:.2f}%"
91
+ })
92
+
93
+ return predictions
94
 
95
+ # Page configuration (must be the first Streamlit command)
96
  st.set_page_config(
97
+ page_title="Plant Disease Classifier",
98
+ page_icon="🌿",
99
+ layout="wide",
100
+ initial_sidebar_state="expanded"
101
  )
102
 
103
+ # Sidebar with additional information
 
 
 
 
104
  with st.sidebar:
105
+ st.title("🌿 Plant Disease Classifier")
106
+ st.write("---")
107
+
108
+ # About section
109
+ st.subheader("About")
110
+ st.markdown("""
111
+ This app helps identify plant diseases using deep learning.
112
+ Upload an image of a plant leaf, and the model will predict
113
+ the most likely disease affecting it.
114
+ """)
115
+
116
+ # Supported diseases section
117
+ st.subheader("Supported Diseases")
118
+ if st.checkbox("Show supported diseases"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  try:
120
  with open('class_indices.json', 'r') as f:
121
+ diseases = json.load(f).values()
122
+ # Format disease names for better readability
123
+ formatted_diseases = [d.replace('___', ' ').replace('_', ' ').title()
124
+ for d in diseases]
125
+ formatted_diseases.sort() # Sort alphabetically
126
+
127
+ # Display in two columns for better layout
128
+ col1, col2 = st.columns(2)
129
+ half = (len(formatted_diseases) + 1) // 2
130
+
131
+ with col1:
132
+ for disease in formatted_diseases[:half]:
133
+ st.markdown(f"- {disease}")
134
+ with col2:
135
+ for disease in formatted_diseases[half:]:
136
+ st.markdown(f"- {disease}")
137
+ except Exception as e:
138
+ st.error("Could not load disease list.")
139
+
140
+ st.write("---")
141
+ st.markdown("*Upload an image to get started!*")
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ # Main content
144
+ st.title("🌱 Plant Disease Classifier")
145
+ st.write("Upload an image of a plant leaf to detect potential diseases")
146
 
147
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
148
 
149
+ if uploaded_file is not None:
 
 
 
150
  try:
151
+ # Display the uploaded image
152
+ image = Image.open(uploaded_file).convert('RGB')
153
+ st.image(image, caption='Uploaded Image', use_column_width=True)
154
 
155
  # Make prediction
156
+ with st.spinner('Analyzing...'):
157
+ predictions = predict(image)
158
+
159
+ # Display results
160
+ st.subheader("Top Predictions:")
161
+ for i, pred in enumerate(predictions, 1):
162
+ st.write(f"{i}. {pred['class']} - {pred['probability']}")
163
+
164
+ # Show confidence level
165
+ confidence = float(predictions[0]['probability'].strip('%'))
166
+ if confidence > 80:
167
+ st.success("✅ High confidence in prediction!")
168
+ elif confidence > 50:
169
+ st.warning("⚠️ Moderate confidence in prediction.")
170
+ else:
171
+ st.info("ℹ️ Low confidence in prediction. Please ensure the image is clear and shows a plant leaf.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ except Exception as e:
174
+ st.error(f"An error occurred: {str(e)}")
175
+ st.info("Please try with a different image or make sure the image is clear and shows a plant leaf.")
176
+
177
+ # Add some styling
178
+ st.markdown("""
179
+ <style>
180
+ .stApp {
181
+ max-width: 1400px;
182
+ margin: 0 auto;
183
+ padding: 1rem;
184
+ }
185
+ .stButton>button {
186
+ background-color: #4CAF50;
187
+ color: white;
188
+ }
189
+ .sidebar .sidebar-content {
190
+ background-color: #f8f9fa;
191
+ }
192
+ .stMarkdown h2 {
193
+ color: #2e7d32;
194
+ }
195
+ .stMarkdown h3 {
196
+ color: #388e3c;
197
+ }
198
+ </style>
199
+ """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_resnet9.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import json
7
+ import numpy as np
8
+ from pathlib import Path
9
+
10
+ # Page config
11
+ st.set_page_config(
12
+ page_title="Plant Disease Classifier",
13
+ page_icon="🌱",
14
+ layout="wide"
15
+ )
16
+
17
+ # Custom CSS
18
+ st.markdown("""
19
+ <style>
20
+ .main {
21
+ max-width: 1000px;
22
+ padding: 2rem;
23
+ }
24
+ .title {
25
+ text-align: center;
26
+ color: #2e8b57;
27
+ }
28
+ .prediction {
29
+ font-size: 1.2rem;
30
+ padding: 1rem;
31
+ border-radius: 0.5rem;
32
+ margin-top: 1rem;
33
+ }
34
+ .healthy {
35
+ background-color: #d4edda;
36
+ color: #155724;
37
+ }
38
+ .diseased {
39
+ background-color: #f8d7da;
40
+ color: #721c24;
41
+ }
42
+ </style>
43
+ """, unsafe_allow_html=True)
44
+
45
+ # Model class (same as in training)
46
+ class ConvBlock(nn.Module):
47
+ def __init__(self, in_channels, out_channels, pool=False):
48
+ super().__init__()
49
+ layers = [
50
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
51
+ nn.BatchNorm2d(out_channels),
52
+ nn.ReLU(inplace=True)
53
+ ]
54
+ if pool:
55
+ layers.append(nn.MaxPool2d(2))
56
+ self.conv = nn.Sequential(*layers)
57
+
58
+ def forward(self, x):
59
+ return self.conv(x)
60
+
61
+ class ResNet9(nn.Module):
62
+ def __init__(self, in_channels, num_classes):
63
+ super().__init__()
64
+
65
+ self.conv1 = ConvBlock(in_channels, 64)
66
+ self.conv2 = ConvBlock(64, 128, pool=True)
67
+ self.res1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
68
+
69
+ self.conv3 = ConvBlock(128, 256, pool=True)
70
+ self.conv4 = ConvBlock(256, 512, pool=True)
71
+ self.res2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
72
+
73
+ self.classifier = nn.Sequential(
74
+ nn.AdaptiveAvgPool2d(1),
75
+ nn.Flatten(),
76
+ nn.Dropout(0.2),
77
+ nn.Linear(512, num_classes)
78
+ )
79
+
80
+ def forward(self, xb):
81
+ out = self.conv1(xb)
82
+ out = self.conv2(out)
83
+ out = self.res1(out) + out
84
+ out = self.conv3(out)
85
+ out = self.conv4(out)
86
+ out = self.res2(out) + out
87
+ out = self.classifier(out)
88
+ return out
89
+
90
+ # Load class indices
91
+ @st.cache_data
92
+ def load_class_indices():
93
+ with open('class_indices.json', 'r') as f:
94
+ return json.load(f)
95
+
96
+ # Load model
97
+ @st.cache_resource
98
+ def load_model():
99
+ class_indices = load_class_indices()
100
+ model = ResNet9(3, len(class_indices))
101
+ model.load_state_dict(torch.load('plant_disease_resnet9.pth', map_location=torch.device('cpu')))
102
+ model.eval()
103
+ return model
104
+
105
+ # Preprocess image
106
+ def preprocess_image(image):
107
+ transform = transforms.Compose([
108
+ transforms.Resize((256, 256)),
109
+ transforms.ToTensor(),
110
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
111
+ ])
112
+ return transform(image).unsqueeze(0)
113
+
114
+ # Predict function
115
+ def predict(image, model, class_indices):
116
+ idx_to_class = {int(k): v for k, v in class_indices.items()}
117
+
118
+ # Preprocess
119
+ input_tensor = preprocess_image(image)
120
+
121
+ # Predict
122
+ with torch.no_grad():
123
+ output = model(input_tensor)
124
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
125
+ confidence, predicted_idx = torch.max(probabilities, 0)
126
+ predicted_class = idx_to_class[predicted_idx.item()]
127
+
128
+ return predicted_class, confidence.item()
129
+
130
+ # Main app
131
+ def main():
132
+ st.title("🌱 Plant Disease Classifier")
133
+ st.markdown("---")
134
+
135
+ # Load model and class indices
136
+ try:
137
+ model = load_model()
138
+ class_indices = load_class_indices()
139
+ idx_to_class = {int(k): v for k, v in class_indices.items()}
140
+ except Exception as e:
141
+ st.error(f"Error loading model: {str(e)}")
142
+ st.info("Please make sure you have trained the model first by running 'python resnet9_train.py'")
143
+ return
144
+
145
+ # File uploader
146
+ uploaded_file = st.file_uploader("Upload an image of a plant leaf", type=["jpg", "jpeg", "png"])
147
+
148
+ if uploaded_file is not None:
149
+ # Display image
150
+ image = Image.open(uploaded_file).convert('RGB')
151
+ st.image(image, caption='Uploaded Image', use_column_width=True)
152
+
153
+ # Make prediction
154
+ with st.spinner('Analyzing...'):
155
+ predicted_class, confidence = predict(image, model, class_indices)
156
+
157
+ # Display result
158
+ plant, status = predicted_class.split('___')
159
+ is_healthy = status == 'healthy'
160
+
161
+ st.markdown("### Prediction Result")
162
+ col1, col2 = st.columns(2)
163
+
164
+ with col1:
165
+ st.metric("Plant", plant.replace('_', ' ').title())
166
+ with col2:
167
+ status_display = "Healthy 🟢" if is_healthy else "Diseased 🔴"
168
+ st.metric("Status", status_display)
169
+
170
+ if not is_healthy:
171
+ st.metric("Disease", status.replace('_', ' ').title())
172
+
173
+ st.metric("Confidence", f"{confidence*100:.2f}%")
174
+
175
+ # Show info based on prediction
176
+ if is_healthy:
177
+ st.success(f"This {plant.replace('_', ' ').lower()} leaf appears to be healthy!")
178
+ else:
179
+ st.warning(f"This {plant.replace('_', ' ').lower()} leaf shows signs of {status.replace('_', ' ').lower()}.")
180
+
181
+ # Add some general advice (you can expand this)
182
+ st.info("""
183
+ **Recommendations:**
184
+ - Isolate the affected plant to prevent spread
185
+ - Remove severely infected leaves
186
+ - Consider using appropriate fungicides/pesticides
187
+ - Ensure proper spacing and air circulation
188
+ - Maintain optimal watering practices
189
+ """)
190
+ else:
191
+ st.info("Please upload an image of a plant leaf to check for diseases.")
192
+
193
+ # Add some information about the model
194
+ st.markdown("---")
195
+ st.markdown("""
196
+ ### About this App
197
+ This app uses a ResNet9 deep learning model to identify plant diseases from leaf images.
198
+ It can detect 38 different classes of plant diseases across 14 plant species.
199
+
200
+ **How to use:**
201
+ 1. Upload an image of a plant leaf
202
+ 2. The model will analyze the image
203
+ 3. View the prediction and recommendations
204
+
205
+ **Note:** For best results, use clear, well-lit photos of individual leaves.
206
+ """)
207
+
208
+ if __name__ == "__main__":
209
+ main()
class_indices.json CHANGED
@@ -1 +1 @@
1
- {"Pepper__bell___Bacterial_spot": 0, "Pepper__bell___healthy": 1, "Potato___Early_blight": 2, "Potato___Late_blight": 3, "Potato___healthy": 4, "Tomato_Bacterial_spot": 5, "Tomato_Early_blight": 6, "Tomato_Late_blight": 7, "Tomato_Leaf_Mold": 8, "Tomato_Septoria_leaf_spot": 9, "Tomato_Spider_mites_Two_spotted_spider_mite": 10, "Tomato__Target_Spot": 11, "Tomato__Tomato_YellowLeaf__Curl_Virus": 12, "Tomato__Tomato_mosaic_virus": 13, "Tomato_healthy": 14}
 
1
+ {"0": "Apple___Apple_scab", "1": "Apple___Black_rot", "2": "Apple___Cedar_apple_rust", "3": "Apple___healthy", "4": "Blueberry___healthy", "5": "Cherry_including_sour___Powdery_mildew", "6": "Cherry_including_sour___healthy", "7": "Corn_maize___Cercospora_leaf_spot_Gray_leaf_spot", "8": "Corn_maize___Common_rust_", "9": "Corn_maize___Northern_Leaf_Blight", "10": "Corn_maize___healthy", "11": "Grape___Black_rot", "12": "Grape___Esca_Black_Measles", "13": "Grape___Leaf_blight_Isariopsis_Leaf_Spot", "14": "Grape___healthy", "15": "Orange___Haunglongbing_Citrus_greening", "16": "Peach___Bacterial_spot", "17": "Peach___healthy", "18": "Pepper_bell___Bacterial_spot", "19": "Pepper_bell___healthy", "20": "Potato___Early_blight", "21": "Potato___Late_blight", "22": "Potato___healthy", "23": "Raspberry___healthy", "24": "Soybean___healthy", "25": "Squash___Powdery_mildew", "26": "Strawberry___Leaf_scorch", "27": "Strawberry___healthy", "28": "Tomato___Bacterial_spot", "29": "Tomato___Early_blight", "30": "Tomato___Late_blight", "31": "Tomato___Leaf_Mold", "32": "Tomato___Septoria_leaf_spot", "33": "Tomato___Spider_mites_Two-spotted_spider_mite", "34": "Tomato___Target_Spot", "35": "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "36": "Tomato___Tomato_mosaic_virus", "37": "Tomato___healthy"}
plant_disease_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e06c9f952fa40bd3e7470ca13a3ca780b38694cddcad1ff18599f7f90c1596b8
3
+ size 67593702
requirements.txt CHANGED
@@ -1,9 +1,42 @@
 
1
  streamlit>=1.29.0
2
- torch>=2.0.0
3
- torchvision>=0.15.0
4
- numpy>=1.19.0
5
- Pillow>=8.0.0
6
- matplotlib>=3.3.0
7
- plotly>=5.0.0
8
- pandas>=1.0.0
9
- tqdm>=4.65.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
  streamlit>=1.29.0
3
+
4
+ # PyTorch and vision
5
+ torch>=2.2.2
6
+ torchvision>=0.17.2
7
+ torchaudio>=2.2.2
8
+
9
+ # Data processing
10
+ numpy>=1.26.0
11
+ Pillow>=10.0.0
12
+
13
+ # Image processing
14
+ opencv-python>=4.8.0
15
+
16
+ # Utilities
17
+ tqdm>=4.66.0
18
+ matplotlib>=3.7.0
19
+ pandas>=2.2.0
20
+ Pillow>=10.0.0
21
+
22
+ # Visualization
23
+ matplotlib>=3.8.0
24
+ plotly>=5.18.0
25
+ seaborn>=0.13.0
26
+
27
+ # Model management
28
+ huggingface_hub>=0.20.0
29
+
30
+ # Utilities
31
+ tqdm>=4.66.0
32
+ requests>=2.31.0
33
+ python-dotenv>=1.0.0
34
+
35
+ # Image processing
36
+ opencv-python-headless>=4.8.0
37
+ albumentations>=1.3.1
38
+
39
+ # Development
40
+ scikit-learn>=1.4.0
41
+ ipykernel>=6.29.0
42
+ jupyter>=1.0.0
space.yml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ title: Plant Disease Detection
2
+ emoji: 🌿
3
+ colorFrom: green
4
+ colorTo: blue
5
+ sdk: docker
6
+ app_file: app_resnet9.py
7
+ app_port: 8501
8
+ pinned: false
9
+ duplicate: false
10
+ license: mit