madan2248c commited on
Commit
1ab4ffa
Β·
verified Β·
1 Parent(s): 8d1f6f8

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +63 -6
  2. app.py +235 -0
  3. best_student_attention_kd.pth +3 -0
  4. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,69 @@
1
  ---
2
- title: Swapnith
3
- emoji: πŸ’»
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.1.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Plant Disease Detection
3
+ emoji: 🌱
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # 🌱 Plant Disease Detection AI
14
+
15
+ An AI-powered plant disease detection system that can identify diseases in various crops including Chilli, Pepper Bell, Potato, Tomato, and GroundNut plants.
16
+
17
+ ## Features
18
+
19
+ - **19 Disease Classes**: Comprehensive detection across multiple plant types
20
+ - **TinyViT Model**: Efficient Vision Transformer for fast inference
21
+ - **Treatment Recommendations**: Detailed information about causes, symptoms, and treatments
22
+ - **User-Friendly Interface**: Simple drag-and-drop image upload
23
+
24
+ ## Supported Plants & Diseases
25
+
26
+ ### 🌢️ Chilli
27
+ - Healthy
28
+ - Leaf Curl Virus
29
+
30
+ ### πŸ«‘ Pepper Bell
31
+ - Healthy
32
+ - Bacterial Spot
33
+
34
+ ### πŸ₯” Potato
35
+ - Healthy
36
+ - Early Blight
37
+ - Late Blight
38
+
39
+ ### πŸ… Tomato
40
+ - Healthy
41
+ - Bacterial Spot
42
+ - Early Blight
43
+ - Late Blight
44
+ - Leaf Mold
45
+ - Mosaic Virus
46
+ - Septoria Leaf Spot
47
+ - Target Spot
48
+ - Two Spotted Spider Mite
49
+ - Yellow Leaf Curl Virus
50
+
51
+ ### πŸ₯œ GroundNut
52
+ - Healthy
53
+ - Rust
54
+
55
+ ## How to Use
56
+
57
+ 1. Upload an image of a plant leaf
58
+ 2. Click "Analyze Plant Disease"
59
+ 3. Get instant results with:
60
+ - Disease identification
61
+ - Confidence score
62
+ - Treatment recommendations
63
+
64
+ ## Model Details
65
+
66
+ - **Architecture**: TinyViT (WinKawaks/vit-tiny-patch16-224)
67
+ - **Parameters**: ~5.5M parameters
68
+ - **Input Size**: 224x224 RGB images
69
+ - **Training**: Knowledge distillation from larger teacher model
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import os
7
+ from transformers import ViTForImageClassification
8
+
9
+ # Disease class mappings - 19 classes from your trained model
10
+ DISEASE_CLASSES = [
11
+ "Chilli - Healthy", # 00
12
+ "Chilli - Leaf Curl Virus", # 01
13
+ "Pepper Bell - Bacterial Spot", # 02
14
+ "Pepper Bell - Healthy", # 03
15
+ "Potato - Early Blight", # 04
16
+ "Potato - Healthy", # 05
17
+ "Potato - Late Blight", # 06
18
+ "Tomato - Bacterial Spot", # 07
19
+ "Tomato - Early Blight", # 08
20
+ "Tomato - Healthy", # 09
21
+ "Tomato - Late Blight", # 10
22
+ "Tomato - Leaf Mold", # 11
23
+ "Tomato - Mosaic Virus", # 12
24
+ "Tomato - Septoria Leaf Spot", # 13
25
+ "Tomato - Target Spot", # 14
26
+ "Tomato - Two Spotted Spider Mite", # 15
27
+ "Tomato - Yellow Leaf Curl Virus", # 16
28
+ "GroundNut - Healthy", # 17
29
+ "GroundNut - Rust" # 18
30
+ ]
31
+
32
+ # Disease information database (shortened for demo)
33
+ DISEASE_INFO = {
34
+ "Chilli - Healthy": {
35
+ "description": "The chilli plant appears healthy with no visible signs of disease.",
36
+ "treatment": "Continue good agricultural practices and regular monitoring."
37
+ },
38
+ "Chilli - Leaf Curl Virus": {
39
+ "description": "Leaf curl virus causes leaves to curl, wrinkle, and become distorted.",
40
+ "treatment": "Remove infected plants, control whiteflies with neem oil, use yellow sticky traps."
41
+ },
42
+ "Tomato - Early Blight": {
43
+ "description": "Early blight causes characteristic target-spot patterns on older leaves.",
44
+ "treatment": "Apply fungicides (chlorothalonil, mancozeb), remove infected leaves, improve air circulation."
45
+ },
46
+ "Potato - Late Blight": {
47
+ "description": "Late blight is a devastating disease that can destroy entire crops rapidly.",
48
+ "treatment": "Apply systemic fungicides immediately, destroy infected plants, improve air circulation."
49
+ },
50
+ # Add more as needed...
51
+ }
52
+
53
+ # Global variables
54
+ model = None
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+
57
+ # Image preprocessing
58
+ transform = transforms.Compose([
59
+ transforms.Resize((224, 224)),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
62
+ ])
63
+
64
+ def load_model():
65
+ """Load the TinyViT student model"""
66
+ global model
67
+
68
+ # Look for model file
69
+ model_paths = [
70
+ "best_student_attention_kd.pth",
71
+ "model/best_student_attention_kd.pth"
72
+ ]
73
+
74
+ model_path = None
75
+ for path in model_paths:
76
+ if os.path.exists(path):
77
+ model_path = path
78
+ break
79
+
80
+ if model_path is None:
81
+ raise FileNotFoundError("Model file not found")
82
+
83
+ try:
84
+ print("Loading TinyViT student model...")
85
+
86
+ # Initialize TinyViT model architecture
87
+ model = ViTForImageClassification.from_pretrained(
88
+ "WinKawaks/vit-tiny-patch16-224",
89
+ num_labels=len(DISEASE_CLASSES),
90
+ ignore_mismatched_sizes=True
91
+ )
92
+
93
+ # Load trained weights
94
+ checkpoint = torch.load(model_path, map_location=device)
95
+
96
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
97
+ model.load_state_dict(checkpoint['model_state_dict'])
98
+ else:
99
+ model.load_state_dict(checkpoint)
100
+
101
+ model.to(device)
102
+ model.eval()
103
+
104
+ print(f"βœ“ Model loaded successfully on {device}")
105
+ return True
106
+
107
+ except Exception as e:
108
+ print(f"Error loading model: {e}")
109
+ return False
110
+
111
+ def predict_disease(image):
112
+ """Predict plant disease from image"""
113
+ if model is None:
114
+ return "❌ Model not loaded", "", ""
115
+
116
+ try:
117
+ # Preprocess image
118
+ if image is None:
119
+ return "❌ No image provided", "", ""
120
+
121
+ # Convert to RGB if needed
122
+ if image.mode != 'RGB':
123
+ image = image.convert('RGB')
124
+
125
+ image_tensor = transform(image).unsqueeze(0).to(device)
126
+
127
+ # Make prediction
128
+ with torch.no_grad():
129
+ outputs = model(image_tensor)
130
+ logits = outputs.logits
131
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
132
+ confidence, predicted_idx = torch.max(probabilities, 1)
133
+
134
+ predicted_class = DISEASE_CLASSES[predicted_idx.item()]
135
+ confidence_score = confidence.item()
136
+
137
+ # Parse results
138
+ parts = predicted_class.split(" - ")
139
+ crop_type = parts[0] if len(parts) > 0 else "Unknown"
140
+ disease = parts[1] if len(parts) > 1 else predicted_class
141
+ status = "🟒 Healthy" if "healthy" in disease.lower() else "��� Diseased"
142
+
143
+ # Get disease info
144
+ disease_info = DISEASE_INFO.get(predicted_class, {
145
+ "description": f"Information about {disease}.",
146
+ "treatment": "Consult with a plant pathologist or agricultural extension service."
147
+ })
148
+
149
+ # Format results
150
+ result = f"""
151
+ ## 🌱 **Crop Type:** {crop_type}
152
+ ## 🦠 **Disease:** {disease}
153
+ ## πŸ“Š **Status:** {status}
154
+ ## 🎯 **Confidence:** {confidence_score:.2%}
155
+
156
+ ### πŸ“ **Description:**
157
+ {disease_info['description']}
158
+
159
+ ### πŸ’Š **Treatment:**
160
+ {disease_info['treatment']}
161
+ """
162
+
163
+ return result, predicted_class, f"Confidence: {confidence_score:.2%}"
164
+
165
+ except Exception as e:
166
+ return f"❌ Error processing image: {str(e)}", "", ""
167
+
168
+ # Load model on startup
169
+ print("Initializing Plant Disease Detection Model...")
170
+ model_loaded = load_model()
171
+
172
+ if not model_loaded:
173
+ print("⚠️ Model failed to load - running in demo mode")
174
+
175
+ # Create Gradio interface
176
+ with gr.Blocks(title="🌱 Plant Disease Detection", theme=gr.themes.Soft()) as demo:
177
+ gr.Markdown("""
178
+ # 🌱 Plant Disease Detection AI
179
+
180
+ Upload an image of a plant leaf to detect diseases and get treatment recommendations.
181
+
182
+ **Supported Plants:** Chilli, Pepper Bell, Potato, Tomato, GroundNut
183
+ **Supported Diseases:** 19 different disease classes including healthy plants
184
+ """)
185
+
186
+ with gr.Row():
187
+ with gr.Column():
188
+ image_input = gr.Image(
189
+ type="pil",
190
+ label="πŸ“Έ Upload Plant Image",
191
+ height=400
192
+ )
193
+
194
+ predict_btn = gr.Button(
195
+ "πŸ” Analyze Plant Disease",
196
+ variant="primary",
197
+ size="lg"
198
+ )
199
+
200
+ with gr.Column():
201
+ result_output = gr.Markdown(
202
+ label="πŸ“‹ Analysis Results",
203
+ value="Upload an image and click 'Analyze' to get started!"
204
+ )
205
+
206
+ with gr.Row():
207
+ with gr.Column():
208
+ predicted_class_output = gr.Textbox(
209
+ label="🏷️ Predicted Class",
210
+ interactive=False
211
+ )
212
+ with gr.Column():
213
+ confidence_output = gr.Textbox(
214
+ label="πŸ“Š Confidence Score",
215
+ interactive=False
216
+ )
217
+
218
+ # Example images (you can add these)
219
+ gr.Markdown("### πŸ“Έ Example Images")
220
+ gr.Markdown("Try uploading images of plant leaves with various diseases or healthy plants.")
221
+
222
+ # Connect the prediction function
223
+ predict_btn.click(
224
+ fn=predict_disease,
225
+ inputs=[image_input],
226
+ outputs=[result_output, predicted_class_output, confidence_output]
227
+ )
228
+
229
+ # Launch the app
230
+ if __name__ == "__main__":
231
+ demo.launch(
232
+ server_name="0.0.0.0",
233
+ server_port=7860,
234
+ share=False
235
+ )
best_student_attention_kd.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64435d4efb8325520d8a7702176008344447566e9380d76458e6e68d2651d86d
3
+ size 22196075
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ torch==2.1.0
3
+ torchvision==0.16.0
4
+ transformers==4.35.2
5
+ Pillow==10.1.0