Spaces:
Runtime error
Runtime error
File size: 7,296 Bytes
07dd47d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import gradio as gr
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification, pipeline
from torchvision import transforms
from PIL import Image
import numpy as np
import os
# -------------------------------
# Model paths (local folders)
# -------------------------------
# Maps each crop shown in the UI to the local folder that is checked for its
# model files. Insertion order is the order the crops are listed in.
hf_model_names = {
    "Rice": "models/Rice-Leaf-Disease",
    "Sugarcane": "models/sugarcane-plant-diseases-classification",
    "Tomato": "models/tomato-leaf-disease-classification-resnet50",
    "Corn/Wheat": "models/crop_leaf_diseases_vit",
}
# -------------------------------
# Utility: Load model offline or online
# -------------------------------
def load_model_or_fallback(model_name, model_path, use_pipeline=False, skip_processor=False):
    """Load an image-classification model from disk, or from the Hub if absent.

    Args:
        model_name: Hugging Face Hub repository id, used when ``model_path``
            does not exist.
        model_path: Local path that is checked first with ``os.path.exists``.
        use_pipeline: When true, build an ``image-classification`` pipeline
            instead of loading the processor/model classes directly.
        skip_processor: When true (and ``use_pipeline`` is false), load only
            the model and return ``None`` in place of the processor.

    Returns:
        The pipeline object when ``use_pipeline`` is true; otherwise a
        ``(processor, model)`` tuple whose first item is ``None`` when
        ``skip_processor`` is true.
    """
    # Status lines are kept ASCII-only so they print on any console encoding.
    if os.path.exists(model_path):
        source = model_path
        print(f"Loading local model: {model_path}")
    else:
        source = model_name
        print(f"Model not found locally. Fetching from Hugging Face Hub: {model_name}")

    # From here on the local and Hub cases differ only in ``source``.
    if use_pipeline:
        return pipeline("image-classification", model=source)

    processor = None if skip_processor else AutoImageProcessor.from_pretrained(source)
    model = SiglipForImageClassification.from_pretrained(source)
    return processor, model
# -------------------------------
# Load models
# -------------------------------
hf_processors = {}
hf_models = {}

# Rice: image processor and model are both loaded and stored.
hf_processors["Rice"], hf_models["Rice"] = load_model_or_fallback(
    model_name="prithivMLmods/Rice-Leaf-Disease",
    model_path=hf_model_names["Rice"],
)

# Sugarcane: only the model is kept; no processor is loaded for it.
_, hf_models["Sugarcane"] = load_model_or_fallback(
    model_name="dwililiya/sugarcane-plant-diseases-classification",
    model_path=hf_model_names["Sugarcane"],
    skip_processor=True,
)

# Tomato: stored as an image-classification pipeline.
hf_models["Tomato"] = load_model_or_fallback(
    model_name="wellCh4n/tomato-leaf-disease-classification-resnet50",
    model_path=hf_model_names["Tomato"],
    use_pipeline=True,
)

# Corn/Wheat: stored as an image-classification pipeline.
hf_models["Corn/Wheat"] = load_model_or_fallback(
    model_name="wambugu71/crop_leaf_diseases_vit",
    model_path=hf_model_names["Corn/Wheat"],
    use_pipeline=True,
)
# -------------------------------
# Sugarcane manual preprocessing
# -------------------------------
# Applied to the PIL image before it is passed to the Sugarcane model:
# resize to 224x224, convert to a tensor, then normalize per channel.
sugarcane_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# -------------------------------
# Disease mapping
# -------------------------------
# Per-crop list of disease names. For crops classified from raw logits the
# predicted class index is used as the position in the list, so the order of
# each list must match the class order of the corresponding model.
disease_dict = {
    "Rice": [
        "Bacterial Blight",
        "Blast",
        "Brown Spot",
        "Healthy",
        "Tungro",
    ],
    "Sugarcane": [
        "Bacterial Blight",
        "Healthy",
        "Mosaic",
        "Red Rot",
        "Rust",
        "Yellow",
    ],
    "Tomato": [
        "Early Blight",
        "Late Blight",
        "Healthy",
    ],
    # Adjust based on your model labels
    "Corn/Wheat": [
        "Healthy",
        "Rust",
        "Blight",
        "Leaf Spot",
    ],
}
# Remedies mapping
# Disease name -> short recommended action shown next to the prediction.
# Lookups use the exact disease name as the key.
remedies = {
    "Early Blight": "Remove infected leaves, apply fungicide.",
    "Late Blight": "Use fungicides and remove infected plants.",
    "Bacterial Blight": "Use resistant varieties and avoid overhead watering.",
    "Blast": "Use balanced fertilizer, apply fungicide.",
    "Brown Spot": "Ensure proper field drainage and avoid overcrowding.",
    "Tungro": "Control green leafhoppers and remove infected plants.",
    "Mosaic": "Remove infected plants, avoid spread.",
    "Red Rot": "Remove infected plants, apply fungicide.",
    "Rust": "Use fungicide and resistant varieties.",
    "Yellow": "Monitor plant, apply preventive measures.",
    "Leaf Spot": "Remove affected leaves and apply fungicide.",
    "Blight": "Use disease-free seeds and apply fungicides.",
    "Healthy": "No action needed.",
}
# -------------------------------
# Prediction function
# -------------------------------
def _label_from_logits(crop, logits):
    """Translate classifier logits into a disease name for ``crop``.

    The class with the highest logit in the first row is used as an index
    into ``disease_dict[crop]``. If the index is beyond the end of that list,
    a placeholder name is returned instead of raising ``IndexError``.
    """
    # Softmax is monotonic, so the argmax of the raw logits selects the same
    # class as the argmax of the probabilities.
    class_idx = int(torch.argmax(logits, dim=1)[0])
    labels = disease_dict[crop]
    if class_idx >= len(labels):
        return f"Unknown (class {class_idx})"
    return labels[class_idx]


def predict_disease(crop, img):
    """Predict the disease shown in a leaf image and suggest a remedy.

    Args:
        crop: Crop name selected in the UI; one of "Rice", "Sugarcane",
            "Tomato" or "Corn/Wheat".
        img: Uploaded image as a numpy array, or ``None`` when nothing was
            uploaded.

    Returns:
        A ``(disease, advice)`` tuple of strings. When no image is given or
        the crop is not supported, the tuple carries a status and an
        explanatory message instead.
    """
    if img is None:
        return "No image uploaded", "Please upload a leaf image."
    # Reject unsupported crops before doing any image conversion work.
    if crop not in ("Rice", "Sugarcane", "Tomato", "Corn/Wheat"):
        return "Error", f"Model for {crop} is not available."

    img_pil = Image.fromarray(img).convert("RGB")

    if crop in ("Tomato", "Corn/Wheat"):
        # These entries are pipelines; the first result's label is used as-is.
        disease = hf_models[crop](img_pil)[0]["label"]
    elif crop == "Rice":
        inputs = hf_processors[crop](images=img_pil, return_tensors="pt")
        with torch.no_grad():
            outputs = hf_models[crop](**inputs)
        disease = _label_from_logits(crop, outputs.logits)
    else:
        # Sugarcane has no processor, so the image is preprocessed manually
        # and given a leading batch dimension.
        img_tensor = sugarcane_transform(img_pil).unsqueeze(0)
        with torch.no_grad():
            outputs = hf_models[crop](img_tensor)
        disease = _label_from_logits(crop, outputs.logits)

    return disease, remedies.get(disease, "No advice available.")
# -------------------------------
# Gradio Interface
# -------------------------------
# Page styling: full-page background photo with translucent white panels.
custom_css = """
body, .gradio-container {
background-image: url('https://media.istockphoto.com/id/1328004520/photo/healthy-young-soybean-crop-in-field-at-dawn.jpg?s=612x612&w=0&k=20&c=XRw20PArfhkh6LLgFrgvycPLm0Uy9y7lu9U7fLqabVY=');
background-size: cover;
background-repeat: no-repeat;
background-attachment: fixed;
min-height: 100vh !important;
}
.gradio-container > * {
background-color: rgba(255, 255, 255, 0.88) !important;
border-radius: 15px;
padding: 20px;
}
"""

with gr.Blocks(css=custom_css) as app:
    gr.Markdown("## 🌿 Crop Disease Detector")
    gr.Markdown("Upload a leaf image of your crop and get AI-based disease prediction with remedies.")
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            crop_input = gr.Dropdown(list(hf_model_names), label="Select Crop")
            # type="numpy" means the callback receives the image as an array.
            img_input = gr.Image(type="numpy", label="Upload Leaf Image")
            predict_btn = gr.Button("Predict Disease")
        # Right column: prediction and advice.
        with gr.Column():
            disease_output = gr.Textbox(label="Predicted Disease")
            advice_output = gr.Textbox(label="Recommended Action")
    # The two values returned by predict_disease fill the two text boxes.
    predict_btn.click(
        predict_disease,
        inputs=[crop_input, img_input],
        outputs=[disease_output, advice_output],
    )

# Launch
app.launch(server_name="127.0.0.1", server_port=7860, share=True)
|