# Hugging Face Space — Crop Disease Detection (Gradio app)
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import timm  # for EfficientNet-B3

# -------------------------------
# Download & Load Model
# -------------------------------
# Fetch the fine-tuned checkpoint from the Hugging Face Hub
# (cached locally after the first download).
model_path = hf_hub_download(
    repo_id="VisionaryQuant/5_Crop_Disease_Detection",
    filename="best_crop_disease_model.pt",
)

# Recreate the EfficientNet-B3 architecture; weights come from the checkpoint,
# so pretrained=False. num_classes=17 must match the idx2label mapping below.
model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=17)

# weights_only=True restricts unpickling to tensors/primitives: the checkpoint
# is downloaded (untrusted) data, and torch.load otherwise executes arbitrary
# pickle bytecode. A plain state_dict loads fine under this restriction.
state_dict = torch.load(
    model_path, map_location=torch.device("cpu"), weights_only=True
)

# strict=False tolerates missing/renamed keys, but it also silently ignores
# genuine mismatches — NOTE(review): confirm the checkpoint really is an
# efficientnet_b3 state_dict so nothing important is skipped.
model.load_state_dict(state_dict, strict=False)
model.eval()  # inference mode: disable dropout, use running BatchNorm stats
# -------------------------------
# Define Preprocessing
# -------------------------------
# Mirror the training pipeline: EfficientNet-B3's native 300x300 input,
# normalized with the standard ImageNet channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# -------------------------------
# Class Index → Label Mapping
# -------------------------------
# Order is significant: it must match the label encoding used at training
# time (crop groups of corn, potato, rice, wheat, sugarcane; 17 classes).
_CLASS_NAMES = [
    "Corn___Common_Rust",
    "Corn___Gray_Leaf_Spot",
    "Corn___Northern_Leaf_Blight",
    "Corn___Healthy",
    "Potato___Early_Blight",
    "Potato___Late_Blight",
    "Potato___Healthy",
    "Rice___Brown_Spot",
    "Rice___Leaf_Blast",
    "Rice___Neck_Blast",
    "Rice___Healthy",
    "Wheat___Yellow_Rust",
    "Wheat___Brown_Rust",
    "Wheat___Healthy",
    "Sugarcane___Red_Rot",
    "Sugarcane___Bacterial_Blight",
    "Sugarcane___Healthy",
]

# dict(enumerate(...)) yields exactly {0: name0, 1: name1, ...}.
idx2label = dict(enumerate(_CLASS_NAMES))
# -------------------------------
# Inference Function
# -------------------------------
def predict(image):
    """Classify a crop-leaf image.

    Args:
        image: HxWxC uint8 numpy array supplied by the Gradio Image
            component (type="numpy"), or None when nothing was uploaded.

    Returns:
        tuple: (predicted label string, {label: probability} dict) — the
        formats the two gr.Label outputs expect.
    """
    if image is None:
        return "No image uploaded", {}

    # convert("RGB") drops any alpha channel / expands grayscale:
    # the model expects exactly 3 input channels.
    pil_image = Image.fromarray(image).convert("RGB")
    input_tensor = transform(pil_image).unsqueeze(0)  # add batch dim -> (1, 3, 300, 300)

    # inference_mode is a stricter, slightly faster no_grad for pure inference.
    with torch.inference_mode():
        logits = model(input_tensor)
        # squeeze(0) (not squeeze()) so only the batch dim is removed.
        probs = torch.nn.functional.softmax(logits, dim=1).squeeze(0)

    predicted_idx = int(torch.argmax(probs).item())
    # Iterate the mapping itself instead of range(len(...)) — no assumption
    # of contiguous integer keys.
    probs_dict = {label: float(probs[i]) for i, label in idx2label.items()}
    return idx2label[predicted_idx], probs_dict
# -------------------------------
# Gradio Interface
# -------------------------------
title = "🌾 Crop Disease Detection"
description = """
Upload an image of a crop leaf to detect diseases.
Model: **EfficientNet-B3** trained on 13,000+ images across 5 crops (Corn, Potato, Rice, Wheat, Sugarcane).
**Accuracy:** 94.8% | **Precision:** 95.4% | **Recall:** 94.5%
"""

# Single numpy-image input; two outputs: a headline prediction plus the full
# per-class breakdown (gr.Label renders a dict as confidence bars).
_image_input = gr.Image(type="numpy", label="Upload Crop Leaf Image")
_label_outputs = [
    gr.Label(label="Predicted Disease"),
    gr.Label(label="Confidence Scores"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_outputs,
    title=title,
    description=description,
)

# Launch only when run as a script (Spaces imports the module and serves demo).
if __name__ == "__main__":
    demo.launch()