Spaces:
Runtime error
Runtime error
import os
import traceback

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification
# Label for each output of the 19-way classifier head: entry i is the label
# reported for logit i, so this order must match the label order used when
# the checkpoint was trained. Do not sort or insert entries.
DISEASE_CLASSES = [
    # --- Chilli ---------------------------------------------------------
    "Chilli - Healthy",                   # index 0
    "Chilli - Leaf Curl Virus",           # index 1
    # --- Pepper Bell ----------------------------------------------------
    "Pepper Bell - Bacterial Spot",       # index 2
    "Pepper Bell - Healthy",              # index 3
    # --- Potato ---------------------------------------------------------
    "Potato - Early Blight",              # index 4
    "Potato - Healthy",                   # index 5
    "Potato - Late Blight",               # index 6
    # --- Tomato ---------------------------------------------------------
    "Tomato - Bacterial Spot",            # index 7
    "Tomato - Early Blight",              # index 8
    "Tomato - Healthy",                   # index 9
    "Tomato - Late Blight",               # index 10
    "Tomato - Leaf Mold",                 # index 11
    "Tomato - Mosaic Virus",              # index 12
    "Tomato - Septoria Leaf Spot",        # index 13
    "Tomato - Target Spot",               # index 14
    "Tomato - Two Spotted Spider Mite",   # index 15
    "Tomato - Yellow Leaf Curl Virus",    # index 16
    # --- GroundNut ------------------------------------------------------
    "GroundNut - Healthy",                # index 17
    "GroundNut - Rust",                   # index 18
]
# Human-readable description and treatment advice keyed by class label.
# Only some labels have an entry; lookups for the others are expected to
# supply their own default (e.g. via dict.get).
#
# Each row is (label, description, treatment).
_DISEASE_INFO_ROWS = (
    (
        "Chilli - Healthy",
        "The chilli plant appears healthy with no visible signs of disease.",
        "Continue good agricultural practices and regular monitoring.",
    ),
    (
        "Chilli - Leaf Curl Virus",
        "Leaf curl virus causes leaves to curl, wrinkle, and become distorted.",
        "Remove infected plants, control whiteflies with neem oil, use yellow sticky traps.",
    ),
    (
        "Tomato - Early Blight",
        "Early blight causes characteristic target-spot patterns on older leaves.",
        "Apply fungicides (chlorothalonil, mancozeb), remove infected leaves, improve air circulation.",
    ),
    (
        "Potato - Late Blight",
        "Late blight is a devastating disease that can destroy entire crops rapidly.",
        "Apply systemic fungicides immediately, destroy infected plants, improve air circulation.",
    ),
)

DISEASE_INFO = {
    label: {"description": description, "treatment": treatment}
    for label, description, treatment in _DISEASE_INFO_ROWS
}
# --- Module-level state --------------------------------------------------
# Classifier instance; assigned by load_model() and read by predict_disease().
model = None

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Preprocessing -------------------------------------------------------
# Applied to every uploaded image before inference:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert to a float tensor,
#   3. normalise each of the three channels with the given mean/std.
# NOTE(review): these mean/std values are the commonly used ImageNet channel
# statistics; confirm they match the normalisation used during training.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def load_model():
    """Load the fine-tuned student checkpoint into the global ``model``.

    The network is a ``ViTForImageClassification`` built from the
    ``WinKawaks/vit-tiny-patch16-224`` pretrained weights with a classifier
    head of ``len(DISEASE_CLASSES)`` outputs. Its weights are then replaced
    by the local checkpoint, which may be either a bare state dict or a dict
    holding one under ``"model_state_dict"``. (Older comments and the log
    message call this a "TinyViT" model.)

    The global ``model`` is assigned only after the checkpoint has been
    loaded successfully. A failed load therefore leaves it unchanged, so a
    network with a randomly initialised head is never exposed for inference.

    Returns:
        bool: True when the model is ready for inference. False when no
        checkpoint file was found or loading failed; the error is printed.
    """
    global model

    # Candidate checkpoint locations, relative to the current working directory.
    model_paths = [
        "best_student_attention_kd.pth",
        "model/best_student_attention_kd.pth",
    ]
    model_path = next((path for path in model_paths if os.path.exists(path)), None)
    if model_path is None:
        # Report instead of raising. The startup code treats False as
        # "continue without a model", whereas an uncaught exception here
        # aborts the whole app before the UI is built.
        print(
            "Error loading model: Model file not found "
            f"(searched: {', '.join(model_paths)})"
        )
        return False

    try:
        print("Loading TinyViT student model...")
        # ignore_mismatched_sizes tolerates a size mismatch between the
        # pretrained classifier head and the requested num_labels.
        net = ViTForImageClassification.from_pretrained(
            "WinKawaks/vit-tiny-patch16-224",
            num_labels=len(DISEASE_CLASSES),
            ignore_mismatched_sizes=True,
        )

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

        # Strict load: any missing/unexpected key or shape mismatch raises
        # and is reported below rather than silently producing a bad model.
        net.load_state_dict(state_dict)
        net.to(device)
        net.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return False

    # Publish the model only once it is fully initialised.
    model = net
    print(f"β Model loaded successfully on {device}")
    return True
def predict_disease(image):
    """Classify a leaf image and build the three Gradio output values.

    Args:
        image: The uploaded picture as a PIL image (the interface uses
            ``gr.Image(type="pil")``), or None when nothing was uploaded.

    Returns:
        tuple[str, str, str]: ``(markdown_report, predicted_class,
        "Confidence: NN.NN%")``. When the model is not loaded, no image is
        given, or inference fails, the first element is an error message and
        the other two are empty strings.
    """
    # NOTE(review): the emoji in the UI strings below look mis-encoded
    # (mojibake); restore the intended glyphs once confirmed.
    if model is None:
        return "β Model not loaded", "", ""
    if image is None:
        return "β No image provided", "", ""

    try:
        # `transform` normalises three channels, so force RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")
        batch = transform(image).unsqueeze(0).to(device)  # batch of one image

        with torch.no_grad():
            logits = model(batch).logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)

        predicted_class = DISEASE_CLASSES[predicted_idx.item()]
        confidence_score = confidence.item()

        # Labels look like "<Crop> - <Condition>". Without the separator,
        # the whole label doubles as the crop and the condition.
        crop_type, separator, disease = predicted_class.partition(" - ")
        if not separator:
            disease = predicted_class
        status = "π’ Healthy" if "healthy" in disease.lower() else "π΄ Diseased"

        # Not every label has curated text; fall back to generic advice.
        disease_info = DISEASE_INFO.get(predicted_class, {
            "description": f"Information about {disease}.",
            "treatment": "Consult with a plant pathologist or agricultural extension service.",
        })

        result = "\n".join([
            "",
            f"## π± **Crop Type:** {crop_type}",
            f"## π¦ **Disease:** {disease}",
            f"## π **Status:** {status}",
            f"## π― **Confidence:** {confidence_score:.2%}",
            "### π **Description:**",
            disease_info["description"],
            "### π **Treatment:**",
            disease_info["treatment"],
            "",
        ])
        return result, predicted_class, f"Confidence: {confidence_score:.2%}"
    except Exception as e:
        # UI boundary: keep the interface responsive, but leave the full
        # traceback in the server logs instead of discarding it.
        traceback.print_exc()
        return f"β Error processing image: {str(e)}", "", ""
# --- Startup -------------------------------------------------------------
# Load the classifier once at import time. load_model() returns a bool; on
# False a warning is printed and the interface below is still built.
print("Initializing Plant Disease Detection Model...")
model_loaded = load_model()
if not model_loaded:
    print("β οΈ Model failed to load - running in demo mode")

# --- Gradio interface ----------------------------------------------------
# One PIL image in; three outputs matching predict_disease()'s return tuple:
# the Markdown report, the predicted label, and the confidence text.
demo = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type="pil", label="πΈ Upload Plant Image"),
    outputs=[
        gr.Markdown(label="π Analysis Results"),
        gr.Textbox(label="π·οΈ Predicted Class"),
        gr.Textbox(label="π Confidence Score"),
    ],
    title="π± Plant Disease Detection AI",
    description=(
        "\n"
        "Upload an image of a plant leaf to detect diseases and get treatment recommendations.\n"
        "**Supported Plants:** Chilli, Pepper Bell, Potato, Tomato, GroundNut\n"
        "**Supported Diseases:** 19 different disease categories including healthy plants\n"
    ),
    examples=None,  # no example images bundled yet
    cache_examples=False,
)

# --- Entry point ---------------------------------------------------------
if __name__ == "__main__":
    # Bind on all interfaces, port 7860, without creating a public share link.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)