| import gradio as gr |
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import os |
|
|
| |
| from models import ResNet9 |
|
|
| |
# Output labels of the classifier, indexed by the network's logit position.
# Each label has the form "<Plant>___<Condition>" (three underscores); both
# predict_disease and format_class_info rely on that separator.
# The list is in sorted() order, i.e. the order torchvision's ImageFolder
# assigns class indices to dataset folders, so it must not be reordered.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
|
|
| |
# Global classifier used by predict_disease. It stays None until a checkpoint
# has been loaded successfully, so "model is None" means "not ready yet".
model = None


def load_model(checkpoint_path="plant-disease-model-state-dict.pth"):
    """Build the ResNet9 classifier and load its weights into the global ``model``.

    Args:
        checkpoint_path: Path to a saved ``state_dict``. Relative paths are
            resolved against the current working directory. The default is the
            file name that was previously hard-coded.

    Returns:
        True if the weights were loaded and ``model`` was replaced; False if
        anything failed. On failure the previous value of ``model`` is left
        untouched, so a half-initialised (untrained) network is never exposed.
    """
    global model
    try:
        candidate = ResNet9(3, len(CLASS_NAMES))
        # NOTE: torch.load unpickles the file; only point this at trusted
        # checkpoints. Weights are mapped to CPU so no GPU is required.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        candidate.load_state_dict(state_dict)
        candidate.eval()  # inference mode (e.g. fixes batch-norm statistics)
    except Exception as e:  # best-effort load: report and let callers decide
        print(f"❌ Model load failed: {e}")
        return False

    # Publish only a fully initialised model.
    model = candidate
    print("✅ Model loaded successfully")
    return True
|
|
def predict_disease(image):
    """Classify a plant-leaf image and return the most likely conditions.

    Args:
        image: The uploaded picture as a ``PIL.Image.Image`` (what
            ``gr.Image(type="pil")`` delivers), or ``None`` when nothing is
            uploaded or the input was cleared.

    Returns:
        On success, a ``{readable_label: probability}`` dict for the top
        ``min(5, number_of_classes)`` classes, which is the shape ``gr.Label``
        expects.
        On failure, an ``"Error: ..."`` string. ``gr.Label`` displays a plain
        string as a text label, whereas dict values must be numbers, so the
        error message cannot be packed into a ``{"Error": "<message>"}`` dict.
    """
    # Validate the input first so that clearing the image does not trigger a
    # (possibly slow) model-load attempt.
    if image is None:
        return "Error: No image provided"
    if model is None and not load_model():
        return "Error: Model not available"

    # Fixed 256x256 resize, then conversion to a float tensor.
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    try:
        # RGBA PNGs or grayscale images would otherwise yield 4- or 1-channel
        # tensors; the network is constructed with 3 input channels.
        batch = preprocess(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
        probabilities = F.softmax(logits[0], dim=0)
        # topk raises if k exceeds the number of classes.
        k = min(5, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)
    except Exception as e:  # UI boundary: surface the failure, don't crash
        return f"Error: Prediction failed: {e}"

    # "Plant___Some_condition" -> "Plant - Some condition"
    return {
        CLASS_NAMES[idx].replace('___', ' - ').replace('_', ' ').strip(): prob
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    }
|
|
def format_class_info(class_names=None):
    """Render the supported plants and their conditions as Markdown.

    Args:
        class_names: Iterable of ``"<Plant>___<Condition>"`` labels. Defaults
            to the module-level ``CLASS_NAMES``.

    Returns:
        A Markdown string: a heading, then one ``**Plant**: cond, cond`` line
        per plant. Plants are sorted by their raw name, and each plant's
        conditions keep their input order. Underscores are shown as spaces.
        Labels without the ``___`` separator are listed on a trailing
        ``**Other**`` line instead of being silently dropped.
    """
    if class_names is None:
        class_names = CLASS_NAMES

    plants = {}
    unparsed = []
    for class_name in class_names:
        plant, sep, condition = class_name.partition('___')
        if not sep:
            unparsed.append(class_name.replace('_', ' ').strip())
            continue
        # strip() drops the space left by trailing underscores (e.g. "rust_").
        plants.setdefault(plant, []).append(condition.replace('_', ' ').strip())

    parts = ["## Supported Plants and Conditions:\n\n"]
    parts.extend(
        f"**{plant.replace('_', ' ')}**: {', '.join(conditions)}\n\n"
        for plant, conditions in sorted(plants.items())
    )
    if unparsed:
        parts.append(f"**Other**: {', '.join(unparsed)}\n\n")
    return "".join(parts)
|
|
| |
# Warm the classifier when the module is imported so the first request does
# not pay the checkpoint-loading cost. predict_disease() retries on demand if
# this attempt fails.
load_model()


# --------------------------------------------------------------------------
# Gradio user interface
# --------------------------------------------------------------------------
with gr.Blocks(
    title="🌱 CropGuard - Plant Disease Detection",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("""
    # 🌱 CropGuard - Plant Disease Detection
    
    Upload an image of a plant leaf to detect diseases using our ResNet-9 model trained on the PlantVillage dataset.
    
    **Supported formats**: JPG, PNG, JPEG
    """)

    with gr.Row():
        # Left column: the uploaded leaf image and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Plant Image", height=400)
            predict_btn = gr.Button("🔍 Analyze Disease", variant="primary", size="lg")

        # Right column: the top-5 class probabilities.
        with gr.Column():
            output = gr.Label(
                label="Disease Prediction Results",
                num_top_classes=5,
                show_label=True,
            )

    gr.Markdown("### 📋 Examples")
    gr.Markdown("Try uploading images of plant leaves to see the disease detection in action!")

    with gr.Accordion("ℹ️ Supported Plants & Diseases", open=False):
        gr.Markdown(format_class_info())

    # Inference runs on an explicit button press and whenever the image input
    # changes. Listeners are registered in the same order: click, then change.
    for _register_listener in (predict_btn.click, image_input.change):
        _register_listener(fn=predict_disease, inputs=image_input, outputs=output)


if __name__ == "__main__":
    demo.launch()