# File size: 5,335 Bytes
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
# Import your model
from models import ResNet9
# Plant disease class names, indexed by the model's output position.
# Every entry has the form "<Plant>___<Condition>"; the triple-underscore
# separator is what the display helpers in this file split/replace on.
# The list is kept in sorted order, presumably matching the alphabetical
# class-folder order used at training time (e.g. torchvision ImageFolder) --
# verify against the training script before reordering or renaming entries.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# Module-level handle to the loaded network. It stays None until
# load_model() succeeds; predict_disease() checks it and retries the load
# while it is still None.
model = None


def load_model(model_path="plant-disease-model-state-dict.pth"):
    """Load the ResNet9 classifier weights into the module-level ``model``.

    The network is constructed as ``ResNet9(3, len(CLASS_NAMES))``
    (presumably 3 input channels and one output per class -- confirm against
    ``models.ResNet9``), the state dict is loaded onto the CPU, and the
    network is switched to eval mode. The global ``model`` is only replaced
    after every step has succeeded. A failed load therefore never leaves a
    randomly initialised network in place for predictions.

    Args:
        model_path: Path to the saved state dict. Defaults to the original
            file name, resolved relative to the current working directory.

    Returns:
        True if the weights were loaded, False otherwise (the error is
        printed, not raised).
    """
    global model
    try:
        net = ResNet9(3, len(CLASS_NAMES))
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        # Best-effort: report the failure and leave ``model`` untouched
        # (still None at startup), so callers can detect and retry.
        print(f"❌ Model load failed: {e}")
        return False
    model = net
    print("✅ Model loaded successfully")
    return True
def predict_disease(image):
    """Predict the plant disease shown in a leaf image.

    Args:
        image: A PIL image (the Gradio input is declared with ``type="pil"``),
            or None when no image is present.

    Returns:
        A dict mapping a human-readable class name to its softmax probability
        for the (up to) 5 most likely classes. On failure it returns a
        single-entry ``{"Error": message}`` dict instead. That happens when no
        image was given, the model cannot be loaded, or inference raises.
    """
    # Cheapest check first: no need to touch the model without an image.
    if image is None:
        return {"Error": "No image provided"}
    # Lazily (re)try loading if the startup load failed.
    if model is None and not load_model():
        return {"Error": "Model not available"}

    # Resize to 256x256 and convert to a float tensor.
    # NOTE(review): assumed to mirror the training preprocessing -- confirm.
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    try:
        # ToTensor keeps the image's own channels (4 for RGBA, 1 for L/P),
        # while the model is built with ResNet9(3, ...); normalise to RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        batch = preprocess(image).unsqueeze(0)  # add a batch dimension

        with torch.no_grad():
            logits = model(batch)
        probabilities = F.softmax(logits[0], dim=0)

        # Never ask topk for more entries than there are classes.
        k = min(5, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)

        results = {}
        for prob, idx in zip(top_probs, top_indices):
            raw_name = CLASS_NAMES[idx.item()]
            # "Plant___Condition_name" -> "Plant - Condition name"
            clean_name = raw_name.replace('___', ' - ').replace('_', ' ')
            results[clean_name] = float(prob)
        return results
    except Exception as e:
        # UI boundary: surface the failure in the output instead of crashing.
        return {"Error": f"Prediction failed: {str(e)}"}
def format_class_info():
    """Build the Markdown listing of supported plants and their conditions.

    Each ``CLASS_NAMES`` entry of the form ``"<Plant>___<Condition>"`` is
    grouped under its plant. Entries without the ``___`` separator are skipped.
    Plants are listed alphabetically and conditions in ``CLASS_NAMES`` order.
    Underscores are shown as spaces.

    Returns:
        A Markdown string: a level-2 heading followed by one
        ``**Plant**: condition, condition`` paragraph per plant.
    """
    by_plant = {}
    for entry in CLASS_NAMES:
        plant, separator, condition = entry.partition('___')
        if not separator:
            continue  # not in "<Plant>___<Condition>" form
        by_plant.setdefault(plant, []).append(condition.replace('_', ' '))

    paragraphs = [
        f"**{plant.replace('_', ' ')}**: {', '.join(conditions)}\n\n"
        for plant, conditions in sorted(by_plant.items())
    ]
    return "## Supported Plants and Conditions:\n\n" + "".join(paragraphs)
# Load the model once at import time so the first request does not pay the
# loading cost; predict_disease() retries the load if this attempt fails.
load_model()

# Build the Gradio interface.
with gr.Blocks(title="🌱 CropGuard - Plant Disease Detection", theme=gr.themes.Soft()) as demo:
    # Blank lines keep the heading, the description and the format note as
    # separate Markdown blocks instead of one merged paragraph.
    gr.Markdown("""
# 🌱 CropGuard - Plant Disease Detection

Upload an image of a plant leaf to detect diseases using our ResNet-9 model trained on the PlantVillage dataset.

**Supported formats**: JPG, PNG, JPEG
""")

    with gr.Row():
        # Left column: image input and the explicit "analyze" trigger.
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Plant Image",
                height=400,
            )
            predict_btn = gr.Button("🔍 Analyze Disease", variant="primary", size="lg")
        # Right column: top-5 class probabilities from predict_disease().
        with gr.Column():
            output = gr.Label(
                label="Disease Prediction Results",
                num_top_classes=5,
                show_label=True,
            )

    # Example section (no example images are bundled yet).
    gr.Markdown("### 📋 Examples")
    gr.Markdown("Try uploading images of plant leaves to see the disease detection in action!")

    # Collapsible list of every plant/condition pair the classifier knows.
    with gr.Accordion("ℹ️ Supported Plants & Diseases", open=False):
        gr.Markdown(format_class_info())

    # Run a prediction when the button is clicked...
    predict_btn.click(
        fn=predict_disease,
        inputs=image_input,
        outputs=output,
    )
    # ...and whenever the image value changes. NOTE(review): this presumably
    # also fires when the image is cleared (value None), which
    # predict_disease() answers with its "No image provided" entry.
    image_input.change(
        fn=predict_disease,
        inputs=image_input,
        outputs=output,
    )
# Launch the web app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()