Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files- README.md +63 -6
- app.py +235 -0
- best_student_attention_kd.pth +3 -0
- requirements.txt +5 -0
README.md
CHANGED
|
@@ -1,12 +1,69 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Plant Disease Detection
|
| 3 |
+
emoji: π±
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# π± Plant Disease Detection AI
|
| 14 |
+
|
| 15 |
+
An AI-powered plant disease detection system that can identify diseases in various crops including Chilli, Pepper Bell, Potato, Tomato, and GroundNut plants.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- **19 Disease Classes**: Comprehensive detection across multiple plant types
|
| 20 |
+
- **TinyViT Model**: Efficient Vision Transformer for fast inference
|
| 21 |
+
- **Treatment Recommendations**: Detailed information about causes, symptoms, and treatments
|
| 22 |
+
- **User-Friendly Interface**: Simple drag-and-drop image upload
|
| 23 |
+
|
| 24 |
+
## Supported Plants & Diseases
|
| 25 |
+
|
| 26 |
+
### πΆοΈ Chilli
|
| 27 |
+
- Healthy
|
| 28 |
+
- Leaf Curl Virus
|
| 29 |
+
|
| 30 |
+
### π« Pepper Bell
|
| 31 |
+
- Healthy
|
| 32 |
+
- Bacterial Spot
|
| 33 |
+
|
| 34 |
+
### π₯ Potato
|
| 35 |
+
- Healthy
|
| 36 |
+
- Early Blight
|
| 37 |
+
- Late Blight
|
| 38 |
+
|
| 39 |
+
### π
Tomato
|
| 40 |
+
- Healthy
|
| 41 |
+
- Bacterial Spot
|
| 42 |
+
- Early Blight
|
| 43 |
+
- Late Blight
|
| 44 |
+
- Leaf Mold
|
| 45 |
+
- Mosaic Virus
|
| 46 |
+
- Septoria Leaf Spot
|
| 47 |
+
- Target Spot
|
| 48 |
+
- Two Spotted Spider Mite
|
| 49 |
+
- Yellow Leaf Curl Virus
|
| 50 |
+
|
| 51 |
+
### π₯ GroundNut
|
| 52 |
+
- Healthy
|
| 53 |
+
- Rust
|
| 54 |
+
|
| 55 |
+
## How to Use
|
| 56 |
+
|
| 57 |
+
1. Upload an image of a plant leaf
|
| 58 |
+
2. Click "Analyze Plant Disease"
|
| 59 |
+
3. Get instant results with:
|
| 60 |
+
- Disease identification
|
| 61 |
+
- Confidence score
|
| 62 |
+
- Treatment recommendations
|
| 63 |
+
|
| 64 |
+
## Model Details
|
| 65 |
+
|
| 66 |
+
- **Architecture**: TinyViT (WinKawaks/vit-tiny-patch16-224)
|
| 67 |
+
- **Parameters**: ~5.5M parameters
|
| 68 |
+
- **Input Size**: 224x224 RGB images
|
| 69 |
+
- **Training**: Knowledge distillation from larger teacher model
|
app.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import os
|
| 7 |
+
from transformers import ViTForImageClassification
|
| 8 |
+
|
| 9 |
+
# Disease class mappings - 19 classes from your trained model
|
| 10 |
+
DISEASE_CLASSES = [
|
| 11 |
+
"Chilli - Healthy", # 00
|
| 12 |
+
"Chilli - Leaf Curl Virus", # 01
|
| 13 |
+
"Pepper Bell - Bacterial Spot", # 02
|
| 14 |
+
"Pepper Bell - Healthy", # 03
|
| 15 |
+
"Potato - Early Blight", # 04
|
| 16 |
+
"Potato - Healthy", # 05
|
| 17 |
+
"Potato - Late Blight", # 06
|
| 18 |
+
"Tomato - Bacterial Spot", # 07
|
| 19 |
+
"Tomato - Early Blight", # 08
|
| 20 |
+
"Tomato - Healthy", # 09
|
| 21 |
+
"Tomato - Late Blight", # 10
|
| 22 |
+
"Tomato - Leaf Mold", # 11
|
| 23 |
+
"Tomato - Mosaic Virus", # 12
|
| 24 |
+
"Tomato - Septoria Leaf Spot", # 13
|
| 25 |
+
"Tomato - Target Spot", # 14
|
| 26 |
+
"Tomato - Two Spotted Spider Mite", # 15
|
| 27 |
+
"Tomato - Yellow Leaf Curl Virus", # 16
|
| 28 |
+
"GroundNut - Healthy", # 17
|
| 29 |
+
"GroundNut - Rust" # 18
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
# Disease information database (shortened for demo)
|
| 33 |
+
DISEASE_INFO = {
|
| 34 |
+
"Chilli - Healthy": {
|
| 35 |
+
"description": "The chilli plant appears healthy with no visible signs of disease.",
|
| 36 |
+
"treatment": "Continue good agricultural practices and regular monitoring."
|
| 37 |
+
},
|
| 38 |
+
"Chilli - Leaf Curl Virus": {
|
| 39 |
+
"description": "Leaf curl virus causes leaves to curl, wrinkle, and become distorted.",
|
| 40 |
+
"treatment": "Remove infected plants, control whiteflies with neem oil, use yellow sticky traps."
|
| 41 |
+
},
|
| 42 |
+
"Tomato - Early Blight": {
|
| 43 |
+
"description": "Early blight causes characteristic target-spot patterns on older leaves.",
|
| 44 |
+
"treatment": "Apply fungicides (chlorothalonil, mancozeb), remove infected leaves, improve air circulation."
|
| 45 |
+
},
|
| 46 |
+
"Potato - Late Blight": {
|
| 47 |
+
"description": "Late blight is a devastating disease that can destroy entire crops rapidly.",
|
| 48 |
+
"treatment": "Apply systemic fungicides immediately, destroy infected plants, improve air circulation."
|
| 49 |
+
},
|
| 50 |
+
# Add more as needed...
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# Global variables
|
| 54 |
+
model = None
|
| 55 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 56 |
+
|
| 57 |
+
# Image preprocessing
|
| 58 |
+
transform = transforms.Compose([
|
| 59 |
+
transforms.Resize((224, 224)),
|
| 60 |
+
transforms.ToTensor(),
|
| 61 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 62 |
+
])
|
| 63 |
+
|
| 64 |
+
def load_model():
|
| 65 |
+
"""Load the TinyViT student model"""
|
| 66 |
+
global model
|
| 67 |
+
|
| 68 |
+
# Look for model file
|
| 69 |
+
model_paths = [
|
| 70 |
+
"best_student_attention_kd.pth",
|
| 71 |
+
"model/best_student_attention_kd.pth"
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
model_path = None
|
| 75 |
+
for path in model_paths:
|
| 76 |
+
if os.path.exists(path):
|
| 77 |
+
model_path = path
|
| 78 |
+
break
|
| 79 |
+
|
| 80 |
+
if model_path is None:
|
| 81 |
+
raise FileNotFoundError("Model file not found")
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
print("Loading TinyViT student model...")
|
| 85 |
+
|
| 86 |
+
# Initialize TinyViT model architecture
|
| 87 |
+
model = ViTForImageClassification.from_pretrained(
|
| 88 |
+
"WinKawaks/vit-tiny-patch16-224",
|
| 89 |
+
num_labels=len(DISEASE_CLASSES),
|
| 90 |
+
ignore_mismatched_sizes=True
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Load trained weights
|
| 94 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 95 |
+
|
| 96 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 97 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 98 |
+
else:
|
| 99 |
+
model.load_state_dict(checkpoint)
|
| 100 |
+
|
| 101 |
+
model.to(device)
|
| 102 |
+
model.eval()
|
| 103 |
+
|
| 104 |
+
print(f"β Model loaded successfully on {device}")
|
| 105 |
+
return True
|
| 106 |
+
|
| 107 |
+
except Exception as e:
|
| 108 |
+
print(f"Error loading model: {e}")
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
def predict_disease(image):
|
| 112 |
+
"""Predict plant disease from image"""
|
| 113 |
+
if model is None:
|
| 114 |
+
return "β Model not loaded", "", ""
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
# Preprocess image
|
| 118 |
+
if image is None:
|
| 119 |
+
return "β No image provided", "", ""
|
| 120 |
+
|
| 121 |
+
# Convert to RGB if needed
|
| 122 |
+
if image.mode != 'RGB':
|
| 123 |
+
image = image.convert('RGB')
|
| 124 |
+
|
| 125 |
+
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 126 |
+
|
| 127 |
+
# Make prediction
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
outputs = model(image_tensor)
|
| 130 |
+
logits = outputs.logits
|
| 131 |
+
probabilities = torch.nn.functional.softmax(logits, dim=1)
|
| 132 |
+
confidence, predicted_idx = torch.max(probabilities, 1)
|
| 133 |
+
|
| 134 |
+
predicted_class = DISEASE_CLASSES[predicted_idx.item()]
|
| 135 |
+
confidence_score = confidence.item()
|
| 136 |
+
|
| 137 |
+
# Parse results
|
| 138 |
+
parts = predicted_class.split(" - ")
|
| 139 |
+
crop_type = parts[0] if len(parts) > 0 else "Unknown"
|
| 140 |
+
disease = parts[1] if len(parts) > 1 else predicted_class
|
| 141 |
+
status = "π’ Healthy" if "healthy" in disease.lower() else "οΏ½οΏ½οΏ½ Diseased"
|
| 142 |
+
|
| 143 |
+
# Get disease info
|
| 144 |
+
disease_info = DISEASE_INFO.get(predicted_class, {
|
| 145 |
+
"description": f"Information about {disease}.",
|
| 146 |
+
"treatment": "Consult with a plant pathologist or agricultural extension service."
|
| 147 |
+
})
|
| 148 |
+
|
| 149 |
+
# Format results
|
| 150 |
+
result = f"""
|
| 151 |
+
## π± **Crop Type:** {crop_type}
|
| 152 |
+
## π¦ **Disease:** {disease}
|
| 153 |
+
## π **Status:** {status}
|
| 154 |
+
## π― **Confidence:** {confidence_score:.2%}
|
| 155 |
+
|
| 156 |
+
### π **Description:**
|
| 157 |
+
{disease_info['description']}
|
| 158 |
+
|
| 159 |
+
### π **Treatment:**
|
| 160 |
+
{disease_info['treatment']}
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
return result, predicted_class, f"Confidence: {confidence_score:.2%}"
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
return f"β Error processing image: {str(e)}", "", ""
|
| 167 |
+
|
| 168 |
+
# Load model on startup
|
| 169 |
+
print("Initializing Plant Disease Detection Model...")
|
| 170 |
+
model_loaded = load_model()
|
| 171 |
+
|
| 172 |
+
if not model_loaded:
|
| 173 |
+
print("β οΈ Model failed to load - running in demo mode")
|
| 174 |
+
|
| 175 |
+
# Create Gradio interface
|
| 176 |
+
with gr.Blocks(title="π± Plant Disease Detection", theme=gr.themes.Soft()) as demo:
|
| 177 |
+
gr.Markdown("""
|
| 178 |
+
# π± Plant Disease Detection AI
|
| 179 |
+
|
| 180 |
+
Upload an image of a plant leaf to detect diseases and get treatment recommendations.
|
| 181 |
+
|
| 182 |
+
**Supported Plants:** Chilli, Pepper Bell, Potato, Tomato, GroundNut
|
| 183 |
+
**Supported Diseases:** 19 different disease classes including healthy plants
|
| 184 |
+
""")
|
| 185 |
+
|
| 186 |
+
with gr.Row():
|
| 187 |
+
with gr.Column():
|
| 188 |
+
image_input = gr.Image(
|
| 189 |
+
type="pil",
|
| 190 |
+
label="πΈ Upload Plant Image",
|
| 191 |
+
height=400
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
predict_btn = gr.Button(
|
| 195 |
+
"π Analyze Plant Disease",
|
| 196 |
+
variant="primary",
|
| 197 |
+
size="lg"
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
with gr.Column():
|
| 201 |
+
result_output = gr.Markdown(
|
| 202 |
+
label="π Analysis Results",
|
| 203 |
+
value="Upload an image and click 'Analyze' to get started!"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
with gr.Row():
|
| 207 |
+
with gr.Column():
|
| 208 |
+
predicted_class_output = gr.Textbox(
|
| 209 |
+
label="π·οΈ Predicted Class",
|
| 210 |
+
interactive=False
|
| 211 |
+
)
|
| 212 |
+
with gr.Column():
|
| 213 |
+
confidence_output = gr.Textbox(
|
| 214 |
+
label="π Confidence Score",
|
| 215 |
+
interactive=False
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Example images (you can add these)
|
| 219 |
+
gr.Markdown("### πΈ Example Images")
|
| 220 |
+
gr.Markdown("Try uploading images of plant leaves with various diseases or healthy plants.")
|
| 221 |
+
|
| 222 |
+
# Connect the prediction function
|
| 223 |
+
predict_btn.click(
|
| 224 |
+
fn=predict_disease,
|
| 225 |
+
inputs=[image_input],
|
| 226 |
+
outputs=[result_output, predicted_class_output, confidence_output]
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Launch the app
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
demo.launch(
|
| 232 |
+
server_name="0.0.0.0",
|
| 233 |
+
server_port=7860,
|
| 234 |
+
share=False
|
| 235 |
+
)
|
best_student_attention_kd.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64435d4efb8325520d8a7702176008344447566e9380d76458e6e68d2651d86d
|
| 3 |
+
size 22196075
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.0
|
| 2 |
+
torch==2.1.0
|
| 3 |
+
torchvision==0.16.0
|
| 4 |
+
transformers==4.35.2
|
| 5 |
+
Pillow==10.1.0
|