# croppred2 / app.py
# Last commit by onlyshrey98: "Update app.py" (6204ab9, verified)
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import timm # for EfficientNet-B3
# -------------------------------
# Download & Load Model
# -------------------------------
def _unwrap_checkpoint(checkpoint):
    """Return a flat ``{parameter name: tensor}`` mapping from a ``torch.load`` result.

    Accepts a bare state dict, a dict wrapping one under ``"model_state_dict"``
    or ``"state_dict"``, or a pickled ``nn.Module``. If every key carries the
    ``"module."`` prefix left by ``nn.DataParallel``, that prefix is stripped
    so the names line up with a bare timm model. Anything else is returned
    unchanged and left for ``load_state_dict`` to reject.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        return checkpoint
    for wrapper_key in ("model_state_dict", "state_dict"):
        # A real state dict maps names to tensors, never to a nested dict.
        if isinstance(checkpoint.get(wrapper_key), dict):
            checkpoint = checkpoint[wrapper_key]
            break
    prefix = "module."
    if checkpoint and all(str(key).startswith(prefix) for key in checkpoint):
        checkpoint = {key[len(prefix):]: value for key, value in checkpoint.items()}
    return checkpoint


def _load_weights(net, weights):
    """Load ``weights`` into ``net`` non-strictly, but never silently.

    ``strict=False`` tolerates small key differences. On its own, though, it
    also accepts a checkpoint that matches nothing and leaves ``net`` with
    random weights. That case is rejected here, and any partial mismatch is
    reported as a ``RuntimeWarning``.

    Raises:
        RuntimeError: if the checkpoint shares no parameter names with ``net``.
    """
    import warnings

    if isinstance(weights, dict) and not set(weights) & set(net.state_dict()):
        raise RuntimeError(
            "Checkpoint shares no parameter names with the efficientnet_b3 "
            "model; refusing to serve predictions from random weights."
        )
    result = net.load_state_dict(weights, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint only partially matches the model: "
            f"{len(result.missing_keys)} missing key(s) "
            f"(e.g. {result.missing_keys[:3]}), "
            f"{len(result.unexpected_keys)} unexpected key(s) "
            f"(e.g. {result.unexpected_keys[:3]}).",
            RuntimeWarning,
        )


model_path = hf_hub_download(
    repo_id="VisionaryQuant/5_Crop_Disease_Detection",
    filename="best_crop_disease_model.pt",
)

# Recreate the EfficientNet-B3 architecture. No ImageNet weights are fetched;
# all weights come from the checkpoint. num_classes must equal the number of
# entries in idx2label (17).
model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=17)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted repo, and consider passing
# weights_only=True once the pinned torch version supports it.
state_dict = _unwrap_checkpoint(
    torch.load(model_path, map_location=torch.device("cpu"))
)
_load_weights(model, state_dict)
model.eval()  # inference mode for dropout / batch-norm layers
# -------------------------------
# Define Preprocessing
# -------------------------------
# NOTE(review): the 300x300 size and normalization should match what was used
# at training time. Confirm against the training script.
_preprocess_steps = [
    # An (h, w) tuple resizes to exactly this size, so aspect ratio is not kept.
    transforms.Resize((300, 300)),
    # PIL image in [0, 255] -> float CHW tensor in [0.0, 1.0].
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
# -------------------------------
# Class Index → Label Mapping
# -------------------------------
# Position in this list is the model's output index.
# NOTE(review): this order is not alphabetical (e.g. "Corn___Healthy" comes
# after "Corn___Northern_Leaf_Blight"). If training used torchvision's
# ImageFolder, which sorts class folders alphabetically, these indices would
# not line up. Confirm against the training label encoding.
_CLASS_NAMES = [
    "Corn___Common_Rust",
    "Corn___Gray_Leaf_Spot",
    "Corn___Northern_Leaf_Blight",
    "Corn___Healthy",
    "Potato___Early_Blight",
    "Potato___Late_Blight",
    "Potato___Healthy",
    "Rice___Brown_Spot",
    "Rice___Leaf_Blast",
    "Rice___Neck_Blast",
    "Rice___Healthy",
    "Wheat___Yellow_Rust",
    "Wheat___Brown_Rust",
    "Wheat___Healthy",
    "Sugarcane___Red_Rot",
    "Sugarcane___Bacterial_Blight",
    "Sugarcane___Healthy",
]
idx2label = dict(enumerate(_CLASS_NAMES))
# -------------------------------
# Inference Function
# -------------------------------
def predict(image):
    """Classify one crop-leaf image.

    Args:
        image: The uploaded image, either as a NumPy array (what
            ``gr.Image(type="numpy")`` delivers) or as a ``PIL.Image.Image``.
            ``None`` when nothing was uploaded.

    Returns:
        A ``(label, scores)`` tuple: the top-scoring class name from
        ``idx2label`` and a ``{class name: softmax probability}`` dict over
        every class in ``idx2label``. For ``None`` input, returns
        ``("No image uploaded", {})``.
    """
    if image is None:
        return "No image uploaded", {}

    # Accept PIL images directly (e.g. gr.Image(type="pil") or direct calls).
    # Image.fromarray only handles array input.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Drop alpha / expand grayscale so the 3-channel Normalize step applies.
    image = image.convert("RGB")

    batch = transform(image).unsqueeze(0)  # batch of one
    with torch.no_grad():
        logits = model(batch)
    # Take row 0 rather than squeeze(), so only the batch axis is removed.
    probs = torch.nn.functional.softmax(logits, dim=1)[0]

    predicted_label = idx2label[int(torch.argmax(probs))]
    probs_dict = {label: float(probs[i]) for i, label in idx2label.items()}
    return predicted_label, probs_dict
# -------------------------------
# Gradio Interface
# -------------------------------
title = "🌾 Crop Disease Detection"
description = """
Upload an image of a crop leaf to detect diseases.
Model: **EfficientNet-B3** trained on 13,000+ images across 5 crops (Corn, Potato, Rice, Wheat, Sugarcane).
**Accuracy:** 94.8% | **Precision:** 95.4% | **Recall:** 94.5%
"""

# One image in (as a numpy array, which predict() expects). Two Label widgets
# out, in the same order as predict()'s return tuple: the top class, then the
# per-class scores.
_image_input = gr.Image(type="numpy", label="Upload Crop Leaf Image")
_label_outputs = [
    gr.Label(label="Predicted Disease"),
    gr.Label(label="Confidence Scores"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_outputs,
    title=title,
    description=description,
)


def _main():
    """Start the Gradio app server."""
    demo.launch()


if __name__ == "__main__":
    _main()