import os
from collections import defaultdict

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# Import your model
from models import ResNet9

# Plant disease class names (PlantVillage labels). The model's output index i
# is mapped to CLASS_NAMES[i] in predict_disease, so the order must match the
# label order used in training -- do not reorder or insert entries.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]

# Path of the trained ResNet9 weights (a state dict), loaded on CPU.
MODEL_PATH = "plant-disease-model-state-dict.pth"

# Number of top predictions returned to / shown by the output Label.
TOP_K = 5

# Preprocessing for every input image; built once rather than per request.
TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Load model
# The loaded network; stays None until load_model() fully succeeds.
model = None


def load_model():
    """Build ResNet9 and load the trained weights into the global ``model``.

    The network is only published to ``model`` after the weights load
    successfully, so a failed load never leaves an untrained model behind.

    Returns:
        True if the model was loaded, False otherwise (the error is printed).
    """
    global model
    try:
        net = ResNet9(3, len(CLASS_NAMES))
        state_dict = torch.load(MODEL_PATH, map_location="cpu")
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        print(f"❌ Model load failed: {e}")
        return False
    model = net
    print("✅ Model loaded successfully")
    return True


def _display_name(label):
    """Turn a raw label (class, plant or condition) into human-readable text."""
    return label.replace('___', ' - ').replace('_', ' ').strip()


def predict_disease(image):
    """Predict plant disease from image.

    Args:
        image: PIL image from the Gradio input, or None when no image is set
            (e.g. the input was cleared, which also fires the change event).

    Returns:
        A dict mapping the display names of the top ``TOP_K`` classes to their
        softmax probabilities -- the format gr.Label renders as confidence
        bars -- or a plain message string when there is no image or an error
        occurs (gr.Label displays strings as-is).
    """
    if image is None:
        return "No image provided"
    if model is None and not load_model():
        return "Error: Model not available"

    try:
        # Uploaded PNGs may be RGBA, greyscale or palette images; the model
        # was built with 3 input channels, so normalise to RGB first.
        img_tensor = TRANSFORM(image.convert("RGB")).unsqueeze(0)

        with torch.no_grad():
            outputs = model(img_tensor)
        probabilities = F.softmax(outputs[0], dim=0)

        k = min(TOP_K, probabilities.numel())
        top_prob, top_indices = torch.topk(probabilities, k)

        return {
            _display_name(CLASS_NAMES[idx]): float(prob)
            for prob, idx in zip(top_prob.tolist(), top_indices.tolist())
        }
    except Exception as e:
        return f"Error: Prediction failed: {e}"


def format_class_info():
    """Build a Markdown list of supported plants and their conditions.

    Plants are sorted by their raw label; conditions keep CLASS_NAMES order.
    """
    plants = defaultdict(list)
    for class_name in CLASS_NAMES:
        plant, sep, condition = class_name.partition('___')
        if sep:
            plants[plant].append(_display_name(condition))

    parts = ["## Supported Plants and Conditions:\n\n"]
    parts.extend(
        f"**{_display_name(plant)}**: {', '.join(conditions)}\n\n"
        for plant, conditions in sorted(plants.items())
    )
    return "".join(parts)


# Load model on startup
load_model()

# Create Gradio interface
with gr.Blocks(title="🌱 CropGuard - Plant Disease Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🌱 CropGuard - Plant Disease Detection

    Upload an image of a plant leaf to detect diseases using our ResNet-9 model trained on the PlantVillage dataset.

    **Supported formats**: JPG, PNG, JPEG
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Plant Image",
                height=400,
            )
            predict_btn = gr.Button("🔍 Analyze Disease", variant="primary", size="lg")

        with gr.Column():
            output = gr.Label(
                label="Disease Prediction Results",
                num_top_classes=TOP_K,
                show_label=True,
            )

    # Example images (you can add these later)
    gr.Markdown("### 📋 Examples")
    gr.Markdown("Try uploading images of plant leaves to see the disease detection in action!")

    # Info section
    with gr.Accordion("ℹ️ Supported Plants & Diseases", open=False):
        gr.Markdown(format_class_info())

    # Event handlers
    predict_btn.click(
        fn=predict_disease,
        inputs=image_input,
        outputs=output,
    )

    # Also predict on image upload
    image_input.change(
        fn=predict_disease,
        inputs=image_input,
        outputs=output,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()