"""Gradio app that classifies crop-leaf images with an EfficientNet-B3 model.

The checkpoint is downloaded from the Hugging Face Hub at import time, loaded
into a timm ``efficientnet_b3`` and served through a ``gr.Interface``.
"""

import warnings

import gradio as gr
import timm  # for EfficientNet-B3
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# -------------------------------
# Class Index → Label Mapping
# -------------------------------
# NOTE(review): this order must match the class indices used at training time;
# that cannot be verified from this file — confirm against the training code.
idx2label = {
    0: "Corn___Common_Rust",
    1: "Corn___Gray_Leaf_Spot",
    2: "Corn___Northern_Leaf_Blight",
    3: "Corn___Healthy",
    4: "Potato___Early_Blight",
    5: "Potato___Late_Blight",
    6: "Potato___Healthy",
    7: "Rice___Brown_Spot",
    8: "Rice___Leaf_Blast",
    9: "Rice___Neck_Blast",
    10: "Rice___Healthy",
    11: "Wheat___Yellow_Rust",
    12: "Wheat___Brown_Rust",
    13: "Wheat___Healthy",
    14: "Sugarcane___Red_Rot",
    15: "Sugarcane___Bacterial_Blight",
    16: "Sugarcane___Healthy",
}

# -------------------------------
# Download & Load Model
# -------------------------------
model_path = hf_hub_download(
    repo_id="VisionaryQuant/5_Crop_Disease_Detection",
    filename="best_crop_disease_model.pt",
)


def _extract_state_dict(checkpoint):
    """Return a plain ``{param_name: tensor}`` mapping from a loaded checkpoint.

    Accepts a bare state dict, a whole pickled ``nn.Module``, or a training
    checkpoint dict that nests the weights under ``"state_dict"`` or
    ``"model_state_dict"``. A ``"module."`` prefix (added when a model is
    wrapped in ``nn.DataParallel``) is stripped when every key carries it.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for nested_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint.get(nested_key), dict):
                checkpoint = checkpoint[nested_key]
                break
    prefix = "module."
    if checkpoint and all(
        isinstance(k, str) and k.startswith(prefix) for k in checkpoint
    ):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


# Recreate EfficientNet-B3 architecture; the head size follows the label map so
# the two cannot drift apart.
model = timm.create_model(
    "efficientnet_b3", pretrained=False, num_classes=len(idx2label)
)
# SECURITY NOTE(review): torch.load unpickles the downloaded file. Only load
# checkpoints from a trusted repo; consider weights_only=True once it is
# confirmed the checkpoint holds plain tensors.
state_dict = _extract_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"))
)
# strict=False keeps loading best-effort, but mismatches are reported instead
# of silently ignored: missing keys leave those weights randomly initialised,
# which makes every prediction unreliable.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match efficientnet_b3: "
        f"missing keys={list(load_result.missing_keys)}, "
        f"unexpected keys={list(load_result.unexpected_keys)}",
        RuntimeWarning,
        stacklevel=1,
    )
model.eval()

# -------------------------------
# Define Preprocessing
# -------------------------------
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    # ImageNet channel statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# -------------------------------
# Inference Function
# -------------------------------
def predict(image):
    """Classify a single crop-leaf image.

    Args:
        image: Image as a NumPy array, as delivered by ``gr.Image(type="numpy")``,
            or ``None`` when nothing was uploaded. It is converted to RGB before
            preprocessing.

    Returns:
        A pair ``(predicted_label, probs_dict)``: the top-1 class name and a
        mapping from every class name to its softmax probability. When
        ``image`` is ``None`` the pair is ``("No image uploaded", {})``.
    """
    if image is None:
        return "No image uploaded", {}

    pil_image = Image.fromarray(image).convert("RGB")
    input_tensor = transform(pil_image).unsqueeze(0)  # add a batch dim of 1

    with torch.no_grad():
        logits = model(input_tensor)

    # Take the single batch row explicitly; a bare squeeze() could also drop
    # the class dimension.
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    predicted_idx = int(torch.argmax(probs))
    predicted_label = idx2label[predicted_idx]
    probs_dict = {idx2label[i]: p for i, p in enumerate(probs.tolist())}
    return predicted_label, probs_dict


# -------------------------------
# Gradio Interface
# -------------------------------
title = "🌾 Crop Disease Detection"
description = """
Upload an image of a crop leaf to detect diseases.
Model: **EfficientNet-B3** trained on 13,000+ images across 5 crops (Corn, Potato, Rice, Wheat, Sugarcane).

**Accuracy:** 94.8% | **Precision:** 95.4% | **Recall:** 94.5%
"""

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Crop Leaf Image"),
    outputs=[
        gr.Label(label="Predicted Disease"),
        gr.Label(label="Confidence Scores"),
    ],
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()