# CropGuard / app.py
# Jude Joseph Agustino
# Initial commit: CropGuard disease detection app
# 6313719
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import os
# Import your model
from models import ResNet9
# Plant disease class names, indexed by the model's output position:
# predict_disease() maps a top-k index straight into this list, and the model is
# built with len(CLASS_NAMES) outputs. The list is in ascending sort order (the
# way torchvision's ImageFolder assigns label indices) — presumably the order
# used at training time, so do NOT reorder it.
# Every entry has the form "<Plant>___<Condition>"; the triple underscore is the
# separator that predict_disease() and format_class_info() rely on.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# Module-level model handle; stays None until load_model() succeeds.
model = None


def load_model(weights_path="plant-disease-model-state-dict.pth"):
    """Load the ResNet9 weights and publish the network as the global ``model``.

    The network is constructed and populated in a local variable and only
    assigned to the global ``model`` after every step succeeded. A failed load
    therefore never leaves a half-initialised (randomly weighted) network in
    ``model``, which predict_disease() would otherwise use without retrying.

    Args:
        weights_path: Path to the saved state dict. A relative path is resolved
            against the current working directory.

    Returns:
        True if the model was loaded, False otherwise (the error is printed and
        the global ``model`` is left unchanged).
    """
    global model
    try:
        # Presumably (in_channels, num_classes) — TODO confirm against models.ResNet9.
        net = ResNet9(3, len(CLASS_NAMES))
        # map_location="cpu" allows loading weights saved on a GPU on a CPU-only host.
        state_dict = torch.load(weights_path, map_location="cpu")
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        # Best-effort at startup: report and let the caller decide what to do.
        print(f"❌ Model load failed: {e}")
        return False
    model = net
    print("✅ Model loaded successfully")
    return True
def _preprocess(image):
    """Turn a PIL image into a (1, 3, 256, 256) batch tensor for the model."""
    # ToTensor keeps the image's own channel count, so RGBA PNGs (4 channels) or
    # grayscale images (1 channel) would not match the model, which is built
    # with ResNet9(3, ...) — presumably 3 input channels. Normalise to RGB first.
    rgb = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    return transform(rgb).unsqueeze(0)


def predict_disease(image):
    """Predict the plant disease shown in a leaf image.

    Args:
        image: The image delivered by ``gr.Image(type="pil")``, or None when the
            input is empty or was cleared.

    Returns:
        A dict mapping a display name (``"Plant - Condition"``) to its softmax
        probability for the top-5 classes. On failure, returns a single-entry
        ``{"Error": message}`` dict instead. Failures are: no image, the model
        cannot be loaded, or inference raised.
    """
    # Check the cheap precondition first so an empty/cleared input does not
    # trigger a model load attempt.
    if image is None:
        return {"Error": "No image provided"}
    # Lazily (re)load the model if the startup load did not succeed.
    if model is None and not load_model():
        return {"Error": "Model not available"}
    try:
        img_tensor = _preprocess(image)
        with torch.no_grad():
            outputs = model(img_tensor)
        # outputs[0] selects the single image in the batch of one.
        probabilities = F.softmax(outputs[0], dim=0)
        # Never ask topk for more classes than the model produces.
        k = min(5, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)
        results = {}
        for prob, idx in zip(top_probs, top_indices):
            class_name = CLASS_NAMES[idx.item()]
            # "Plant___Some_condition" -> "Plant - Some condition"
            clean_name = class_name.replace('___', ' - ').replace('_', ' ')
            results[clean_name] = float(prob)
        return results
    except Exception as e:
        return {"Error": f"Prediction failed: {str(e)}"}
def format_class_info(class_names=None):
    """Build the Markdown summary of supported plants and their conditions.

    Each class name is expected to look like ``"<Plant>___<Condition>"``. Names
    are grouped by plant, with plants sorted and conditions kept in input
    order. Underscores are rendered as spaces.

    Args:
        class_names: Class-name strings to summarise. Defaults to the
            module-level ``CLASS_NAMES``.

    Returns:
        A Markdown string: a heading followed by one ``**Plant**: cond, cond``
        paragraph per plant. Names lacking the ``___`` separator are listed in a
        final ``**Other**`` paragraph instead of being silently dropped.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    plants = {}
    ungrouped = []
    for class_name in class_names:
        plant, sep, condition = class_name.partition('___')
        if not sep:
            # Keep malformed names visible rather than hiding them from users.
            ungrouped.append(class_name.replace('_', ' '))
            continue
        plants.setdefault(plant, []).append(condition.replace('_', ' '))
    parts = ["## Supported Plants and Conditions:\n\n"]
    for plant, conditions in sorted(plants.items()):
        parts.append(f"**{plant.replace('_', ' ')}**: {', '.join(conditions)}\n\n")
    if ungrouped:
        parts.append(f"**Other**: {', '.join(ungrouped)}\n\n")
    return "".join(parts)
# Try to load the weights once at startup; load_model() prints its own
# success/failure message and returns a bool that is not needed here.
load_model()

# Gradio UI: upload column on the left, ranked predictions on the right,
# followed by an examples note and a collapsible list of supported classes.
with gr.Blocks(title="🌱 CropGuard - Plant Disease Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🌱 CropGuard - Plant Disease Detection
Upload an image of a plant leaf to detect diseases using our ResNet-9 model trained on the PlantVillage dataset.
**Supported formats**: JPG, PNG, JPEG
""")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Plant Image", type="pil", height=400)
            predict_btn = gr.Button("🔍 Analyze Disease", size="lg", variant="primary")
        with gr.Column():
            output = gr.Label(show_label=True, num_top_classes=5, label="Disease Prediction Results")

    # Examples section (placeholder text; no example images are bundled yet).
    for note in (
        "### 📋 Examples",
        "Try uploading images of plant leaves to see the disease detection in action!",
    ):
        gr.Markdown(note)

    with gr.Accordion("ℹ️ Supported Plants & Diseases", open=False):
        gr.Markdown(format_class_info())

    # Run inference on an explicit button click and whenever the image changes.
    for trigger in (predict_btn.click, image_input.change):
        trigger(fn=predict_disease, inputs=image_input, outputs=output)

if __name__ == "__main__":
    demo.launch()