# AgriGuard / app.py
# Author: nafees369
# Commit: "Update app.py" (68315bf, verified)
"""
PlantCity - Three Models in One Space (PyTorch Version):
Tab 1 · Species Classifier → 12 plant species with OOD detection
Tab 2 · Disease Classifier → 52 disease classes with OOD detection
Tab 3 · Severity Estimator → 5 severity levels (None, Low, Medium, High, Critical)
"""
import gradio as gr
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import cv2
import traceback
import os
import requests
from pathlib import Path
import gdown
# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════
# Size every image is resized to before it is fed to a model.
IMG_SIZE = (224, 224)

# ── Model file names (PyTorch .pth format) ─────────────────────────────────────
SPECIES_MODEL_FILE = "plant_species_final.pth"
DISEASE_MODEL_FILE = "plant_disease_final.pth"
SEVERITY_MODEL_FILE = "severity_final.pth"

# ── Google Drive file IDs (from Hugging Face Secrets) ─────────────────────────
# Read from environment variables; an unset variable yields "" (no download).
SPECIES_MODEL_ID = os.getenv("SPECIES_MODEL_ID", "")
DISEASE_MODEL_ID = os.getenv("DISEASE_MODEL_ID", "")
SEVERITY_MODEL_ID = os.getenv("SEVERITY_MODEL_ID", "")

# ── Leaf Detection Thresholds ───────────────────────────────────────────────────
# Used by detect_leaf() against the mean of the green/yellow pixel fractions.
LEAF_GREEN_THRESHOLD = 0.08
# NOTE(review): not referenced by detect_leaf() in this file.
LEAF_YELLOW_THRESHOLD = 0.08
# ── Model 1: Species (12 classes) ──────────────────────────────────────────────
# Predictions whose top softmax probability falls below this are reported as
# "Unknown Species".
SPECIES_OOD_THRESHOLD = 0.65
# Class names in model output-index order.
SPECIES_CLASSES = [
    "Apple", "Apricot", "Bean", "Cherry", "Corn", "Fig",
    "Grape", "Lokat", "Pear", "Walnut", "Persimmons", "Tomato",
]
# Emoji shown in front of the predicted species name. The previous values for
# Apricot, Cherry, Grape, Persimmons and Tomato were UTF-8 bytes mis-decoded as
# cp1253 (e.g. "πŸ‡"); they are restored to the intended characters here.
SPECIES_EMOJI = {
    "Apple": "🍎", "Apricot": "🍑", "Bean": "🌱", "Cherry": "🍒",
    "Corn": "🌽", "Fig": "🍈", "Grape": "🍇", "Lokat": "🍊",
    "Pear": "🍐", "Walnut": "🌰", "Persimmons": "🍅", "Tomato": "🍅",
}
NUM_SPECIES = len(SPECIES_CLASSES)
# ── Model 2: Disease (52 classes) ──────────────────────────────────────────────
# Top softmax probability below this → the disease prediction is "Unknown".
DISEASE_OOD_THRESHOLD = 0.60
# Threshold for the "Low Confidence" result in predict_disease().
# NOTE(review): this is lower than DISEASE_OOD_THRESHOLD, so every confidence
# under 0.50 has already been reported as "Unknown" and the low-confidence
# branch never fires; one of the two values is probably not what was intended.
DISEASE_CONF_THRESHOLD = 0.50
# Class names in model output-index order (the list is in sorted string order,
# which is why lower-case labels come after all capitalised ones).
DISEASE_CLASSES = [
    "Apple Brown_spot", "Apple Normal", "Apple black_spot",
    "Apricot Normal", "Apricot blight leaf disease", "Apricot shot_hole",
    "Bean Fungal_leaf disease", "Bean Normal leaf", "Bean bean rust image", "Bean shot_hole",
    "Cherry Leaf Scorch", "Cherry Normal leaf", "Cherry brown_spot", "Cherry purple leaf spot",
    "Cherry_shot hole disease",
    "Corn Fungal leaf", "Corn Normal leaf", "Corn gray leaf spot", "Corn holcus_ leaf spot",
    "Fig Blight_leaf disease", "Fig Brown spot", "Fig normal leaf", "Fig_rust leaf",
    "Grape Anthracnose leaf", "Grape Brown spot leaf", "Grape Downy mildew leaf",
    "Grape Mites_leaf disease", "Grape Normal_leaf", "Grape Powdery_mildew leaf",
    "Grape shot hole leaf disease",
    "Lokat Normal leaf",
    "Pear Black spot _ leaf disease", "Pear Normal _leaf", "Pear fire blight",
    "Walnut Anthracnose_leaf disease", "Walnut Blotch_leaf disease", "Walnut Normal_leaf",
    "Walnut Shot_hole", "Walnut leaf gall mite",
    # Lower-case labels sort after every capitalised one.
    "lokat Leaf_spot",
    "persimmons Brown_spot",
    "tomato Fusarium Wilt", "tomato spider mites", "tomato verticillium wilt",
    "tomato_bacterial_spot", "tomato_early_blight", "tomato_healthy_leaf", "tomato_late_blight",
    "tomato_leaf_curl", "tomato_leaf_miner", "tomato_leaf_mold", "tomato_septoria_leaf",
]
NUM_DISEASES = len(DISEASE_CLASSES)
# ── Model 3: Severity (5 levels) ───────────────────────────────────────────────
# Class names in model output-index order.
SEVERITY_CLASSES = ["None", "Low", "Medium", "High", "Critical"]
# Description shown in the DETAILS box. The "None" entry previously contained
# the mis-decoded em dash "β€”"; restored to "—".
SEVERITY_LEVELS = {
    "None": "Healthy — 0% affected",
    "Low": "1–25% affected",
    "Medium": "26–50% affected",
    "High": "51–75% affected",
    "Critical": ">75% affected (severe damage)",
}
# Icon prefixed to the PREDICTION headline. The None/High/Critical values were
# UTF-8 bytes mis-decoded as cp1253 ("βœ…", "πŸ”΄", ...); restored here.
SEVERITY_ICONS = {
    "None": "✅", "Low": "🟡", "Medium": "🟠", "High": "🔴", "Critical": "🔴🔥",
}
NUM_SEVERITY = len(SEVERITY_CLASSES)
# ═══════════════════════════════════════════════════════════════════════════════
# MODEL ARCHITECTURES
# ═══════════════════════════════════════════════════════════════════════════════
class EfficientNetSpecies(nn.Module):
    """EfficientNet-B0 species classifier with an auxiliary OOD head.

    ``forward`` returns ``(species_logits, ood_logits)``, both computed from the
    same globally pooled backbone features.
    """

    def __init__(self, num_classes=12):
        super().__init__()
        # Randomly initialised backbone; trained weights are loaded from a checkpoint.
        self.backbone = models.efficientnet_b0(weights=None)
        in_dim = self.backbone.classifier[1].in_features

        # Replacement classification head. Layer order (indices 0-5) must stay
        # as-is so state-dict keys "backbone.classifier.<i>.*" line up.
        species_layers = [
            nn.Dropout(0.3),
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*species_layers)

        # Single-logit head on the same pooled features.
        self.ood_head = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        feature_maps = self.backbone.features(x)
        pooled = torch.flatten(self.backbone.avgpool(feature_maps), 1)
        return self.backbone.classifier(pooled), self.ood_head(pooled)
class EfficientNetDisease(nn.Module):
    """EfficientNet-B0 disease classifier with a 256-unit MLP head."""

    def __init__(self, num_classes=52):
        super().__init__()
        self.backbone = models.efficientnet_b0(weights=None)
        in_dim = self.backbone.classifier[1].in_features
        # Layer order (indices 0-5) must match the checkpoint's
        # "backbone.classifier.<i>.*" state-dict keys.
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        # The backbone's own forward applies the replaced classifier head.
        logits = self.backbone(x)
        return logits
class EfficientNetSeverity(nn.Module):
    """EfficientNet-B0 severity classifier with a 128-unit MLP head."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.backbone = models.efficientnet_b0(weights=None)
        in_dim = self.backbone.classifier[1].in_features
        # Layer order (indices 0-5) must match the checkpoint's
        # "backbone.classifier.<i>.*" state-dict keys.
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(in_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        # The backbone's own forward applies the replaced classifier head.
        logits = self.backbone(x)
        return logits
# ═══════════════════════════════════════════════════════════════════════════════
# MODEL DOWNLOAD AND LOADING
# ═══════════════════════════════════════════════════════════════════════════════
# Process-wide cache of loaded models, keyed by model type
# ('species' / 'disease' / 'severity').
_models = dict()
# Latest load/download status message per model type; shown to the user when
# a model is unavailable.
_model_status = dict()
def download_file_from_google_drive(file_id, destination):
    """Fetch a Google Drive file by ID into *destination* using gdown.

    Args:
        file_id: Google Drive file ID.
        destination: Local path the file is written to.

    Returns:
        Whether *destination* exists after the attempt; False if anything
        raised (the error is printed, not propagated).
    """
    try:
        drive_url = f"https://drive.google.com/uc?id={file_id}"
        gdown.download(drive_url, destination, quiet=False)
        downloaded = os.path.exists(destination)
    except Exception as err:
        print(f"Error downloading from Google Drive: {err}")
        downloaded = False
    return downloaded
def download_model(file_id, output_path):
    """Download a model checkpoint from Google Drive.

    Args:
        file_id: Google Drive file ID; an empty or None ID skips the download.
        output_path: Destination path; its parent directory is created if needed.

    Returns:
        True on success, False otherwise (errors are printed, not raised).
    """
    if not file_id:
        print(f"No file ID provided for {output_path}")
        return False

    try:
        os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
        print(f"Downloading model to {output_path}...")
        succeeded = bool(download_file_from_google_drive(file_id, output_path))
        if succeeded:
            print(f"Successfully downloaded {output_path}")
        else:
            print(f"Download failed for {output_path}")
        return succeeded
    except Exception as err:
        print(f"Error downloading model: {err}")
        return False
def _extract_state_dict(checkpoint):
    """Return a flat ``{parameter_name: tensor}`` dict from a loaded checkpoint.

    Accepts a bare state dict, a training checkpoint wrapping it under
    ``'model_state_dict'`` or ``'state_dict'``, or a whole pickled
    ``nn.Module``. A leading ``'module.'`` (added by ``nn.DataParallel``) is
    stripped from each key; occurrences elsewhere in a key are left alone.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }


def load_model(model_type, model_file, file_id, model_class, num_classes):
    """Load (and cache) a PyTorch model with its trained weights.

    The checkpoint is looked up as ``model_file`` (relative to the working
    directory), then as ``/tmp/<model_file>``. If neither exists it is
    downloaded from Google Drive into ``/tmp``. The result is cached in
    ``_models[model_type]``, and a status message is recorded in
    ``_model_status[model_type]``.

    Args:
        model_type: Cache key, e.g. ``'species'``.
        model_file: Checkpoint file name (``.pth``).
        file_id: Google Drive file ID used when no local copy exists.
        model_class: ``nn.Module`` subclass accepting ``num_classes=``.
        num_classes: Number of output classes passed to ``model_class``.

    Returns:
        The model in eval mode on CUDA (if available) or CPU, or None if it
        could not be downloaded or loaded.
    """
    # Reuse a model loaded earlier in this process.
    if model_type in _models:
        return _models[model_type]

    # Prefer a checkpoint next to the app, then a copy previously downloaded to /tmp.
    if os.path.exists(model_file):
        model_path = model_file
        print(f"Found local model: {model_path}")
    else:
        model_path = f"/tmp/{model_file}"
        if not os.path.exists(model_path):
            print("Model file not found. Downloading from Google Drive...")
            if not download_model(file_id, model_path):
                _model_status[model_type] = "Failed to download model."
                return None
            print(f"Model downloaded successfully to {model_path}")

    try:
        print(f"Loading model from {model_path}")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model_class(num_classes=num_classes)

        # SECURITY: torch.load unpickles the file, so it must only be pointed at
        # trusted checkpoints (here: the project's own files).
        checkpoint = torch.load(model_path, map_location=device)
        state_dict = _extract_state_dict(checkpoint)

        # strict=False tolerates extra or absent keys, but a naming mismatch
        # would otherwise silently leave layers at their random initialisation.
        # Report anything that did not line up, and refuse a checkpoint that
        # matched nothing at all.
        incompatible = model.load_state_dict(state_dict, strict=False)
        missing = list(incompatible.missing_keys)
        unexpected = list(incompatible.unexpected_keys)
        if unexpected:
            print(f"Warning: ignored {len(unexpected)} unexpected key(s) in the "
                  f"{model_type} checkpoint, e.g. {unexpected[:5]}")
        if missing:
            if len(missing) == len(model.state_dict()):
                raise RuntimeError(
                    f"no checkpoint keys match the {model_class.__name__} architecture")
            print(f"Warning: {len(missing)} {model_type} weight(s) missing from the "
                  f"checkpoint (left at random init), e.g. {missing[:5]}")

        model = model.to(device)
        model.eval()
        _models[model_type] = model
        _model_status[model_type] = (
            f"Loaded with {len(missing)} missing weight(s)" if missing else "Loaded successfully"
        )
        print(f"Model {model_type} loaded successfully on {device}")
        return model
    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        _model_status[model_type] = error_msg
        return None
# ═══════════════════════════════════════════════════════════════════════════════
# IMAGE PREPROCESSING
# ═══════════════════════════════════════════════════════════════════════════════
def get_transform():
    """Build the inference preprocessing pipeline.

    Resizes to ``IMG_SIZE``, converts to a tensor and normalises each channel
    with the mean/std below (presumably the statistics used at training time;
    confirm against the training code).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def detect_leaf(image: Image.Image) -> tuple:
    """Heuristically decide whether *image* shows a plant leaf.

    The decision is colour-based. The image counts as a leaf when the mean of
    the green-pixel and yellow-pixel fractions exceeds LEAF_GREEN_THRESHOLD, or
    when the green fraction alone exceeds 5%. An edge-density figure is
    computed for reporting only.

    Args:
        image: Input PIL image (any mode; converted to RGB).

    Returns:
        ``(is_leaf, details)``. ``details`` maps ``green_ratio``,
        ``yellow_ratio`` and ``edge_complexity`` to percentage strings.
    """
    rgb = np.array(image.convert("RGB"))
    height, width = rgb.shape[:2]
    pixels = rgb.astype(np.float32) / 255.0
    red = pixels[..., 0]
    green = pixels[..., 1]
    blue = pixels[..., 2]

    # Green: G clearly dominates both R and B, and is not too dark.
    is_green = (green > red + 0.05) & (green > blue + 0.05) & (green > 0.15)
    # Yellow: bright R and G of similar magnitude with little B.
    is_yellow = (red > 0.3) & (green > 0.3) & (blue < 0.35) & (np.abs(red - green) < 0.3)
    green_ratio = is_green.mean()
    yellow_ratio = is_yellow.mean()

    # Canny edge density on a blurred grayscale copy (reported, not used to decide).
    grayscale = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    edge_map = cv2.Canny(cv2.GaussianBlur(grayscale, (5, 5), 0), 50, 150)
    edge_complexity = np.sum(edge_map > 0) / (height * width)

    # NOTE(review): LEAF_YELLOW_THRESHOLD is not consulted; yellow pixels only
    # count through the averaged term below.
    combined_ok = (green_ratio + yellow_ratio) / 2 > LEAF_GREEN_THRESHOLD
    is_leaf = combined_ok or green_ratio > 0.05

    metrics = (
        ("green_ratio", green_ratio),
        ("yellow_ratio", yellow_ratio),
        ("edge_complexity", edge_complexity),
    )
    details = {name: f"{value:.2%}" for name, value in metrics}
    return is_leaf, details
def preprocess_image(image: Image.Image):
    """Turn a PIL image into a single-image batch tensor for the models.

    The image is converted to RGB, run through get_transform(), and given a
    leading batch dimension of size 1.
    """
    pipeline = get_transform()
    single = pipeline(image.convert("RGB"))
    return single[None, ...]
def parse_disease(disease_name):
    """Split a disease class label into ``(plant, condition, healthy)``.

    Labels mix spaces and underscores as separators (e.g. "Apple Brown_spot",
    "Cherry_shot hole disease", "tomato_early_blight"), so both are treated as
    word breaks. The first word is the plant; the remaining words form the
    condition. Splitting only on "_" (the previous behaviour) produced plants
    such as "Apple Brown" with condition "spot".

    Args:
        disease_name: A label from DISEASE_CLASSES.

    Returns:
        ``(plant, condition, healthy)``:
        - ``plant`` has its first letter capitalised ("tomato" -> "Tomato").
        - ``condition`` is "Unknown" when the label has only one word.
        - ``healthy`` is True when the label contains "normal" or "healthy"
          in any letter case; the old check missed "Fig normal leaf".
    """
    words = disease_name.replace('_', ' ').split()
    if not words:
        return "Unknown", "Unknown", False
    plant = words[0][:1].upper() + words[0][1:]
    condition = ' '.join(words[1:]) or "Unknown"
    lowered = disease_name.lower()
    healthy = "normal" in lowered or "healthy" in lowered
    return plant, condition, healthy
# ═══════════════════════════════════════════════════════════════════════════════
# PREDICTION FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════
def predict_species(image):
    """Identify the plant species in a leaf photo (Species tab handler).

    Args:
        image: PIL image from the Gradio input, or None when there is none.

    Returns:
        ``(headline, details, probabilities)`` for the PREDICTION textbox, the
        DETAILS textbox and the CLASS PROBABILITIES label. Early exits return
        an empty probability dict.
    """
    if image is None:
        return ("⚠️ No Image", "Please upload a leaf image to get started.", {})

    model = load_model('species', SPECIES_MODEL_FILE, SPECIES_MODEL_ID, EfficientNetSpecies, NUM_SPECIES)
    if model is None:
        return ("❌ Model Error", f"Species model not available.\n\n{_model_status.get('species', 'Unknown')}", {})

    # Cheap colour heuristic to reject obvious non-leaf uploads.
    is_leaf, details = detect_leaf(image)
    if not is_leaf:
        leaf_details = "\n".join([
            "",
            "**Not a Leaf Image**",
            "This image does not appear to be a plant leaf.",
            "**Detection Metrics:**",
            f"- Green ratio: {details['green_ratio']}",
            f"- Yellow ratio: {details['yellow_ratio']}",
            f"- Edge complexity: {details['edge_complexity']}",
            "Please upload a clear photo of a plant leaf.",
            "",
        ])
        return ("🚫 Not a Leaf", leaf_details, {})

    try:
        device = next(model.parameters()).device
        img_tensor = preprocess_image(image).to(device)
        with torch.no_grad():
            # NOTE(review): the OOD head's logit is not used; the "unknown"
            # decision below relies only on the softmax confidence.
            species_logits, _ood_logits = model(img_tensor)
            species_probs = torch.softmax(species_logits, dim=1)
        top_conf, top_idx = torch.max(species_probs, dim=1)
        top_conf = top_conf.item()
        top_species = SPECIES_CLASSES[top_idx.item()]
        conf_dict = {c: float(p) for c, p in zip(SPECIES_CLASSES, species_probs[0].cpu().numpy())}

        if top_conf < SPECIES_OOD_THRESHOLD:
            return (
                "🔍 Unknown Species",
                f"Confidence {top_conf:.1%} below threshold ({SPECIES_OOD_THRESHOLD:.0%}).",
                conf_dict,
            )

        emoji = SPECIES_EMOJI.get(top_species, "🌿")
        result_text = f"**Species:** {top_species}\n**Confidence:** {top_conf:.1%}\n**Leaf Detection:** ✓"
        return (f"{emoji} {top_species}", result_text, conf_dict)
    except Exception as e:
        return ("❌ Error", f"{type(e).__name__}: {str(e)}", {})
def predict_disease(image):
    """Classify the disease (or healthy state) of a leaf photo (Disease tab handler).

    Args:
        image: PIL image from the Gradio input, or None when there is none.

    Returns:
        ``(headline, details, probabilities)`` for the three output components.
        Early exits return an empty probability dict.
    """
    if image is None:
        return ("⚠️ No Image", "Please upload a leaf image.", {})

    model = load_model('disease', DISEASE_MODEL_FILE, DISEASE_MODEL_ID, EfficientNetDisease, NUM_DISEASES)
    if model is None:
        return ("❌ Model Error", f"Disease model not available.\n\n{_model_status.get('disease', 'Unknown')}", {})

    # Cheap colour heuristic to reject obvious non-leaf uploads.
    is_leaf, details = detect_leaf(image)
    if not is_leaf:
        return ("🚫 Not a Leaf", f"Green ratio: {details['green_ratio']}", {})

    try:
        device = next(model.parameters()).device
        img_tensor = preprocess_image(image).to(device)
        with torch.no_grad():
            disease_logits = model(img_tensor)
            disease_probs = torch.softmax(disease_logits, dim=1)
        top_conf, top_idx = torch.max(disease_probs, dim=1)
        top_conf = top_conf.item()
        top_disease = DISEASE_CLASSES[top_idx.item()]
        conf_dict = {c: float(p) for c, p in zip(DISEASE_CLASSES, disease_probs[0].cpu().numpy())}

        if top_conf < DISEASE_OOD_THRESHOLD:
            return ("🔍 Unknown", f"Confidence {top_conf:.1%} below threshold.", conf_dict)

        plant, condition, healthy = parse_disease(top_disease)
        if healthy:
            return (f"✅ {plant} — Healthy", f"Healthy leaf. Confidence: {top_conf:.1%}", conf_dict)
        # NOTE(review): only reachable when DISEASE_CONF_THRESHOLD is above
        # DISEASE_OOD_THRESHOLD. With the configured 0.50 / 0.60, every
        # confidence below 0.50 has already returned "Unknown" above.
        if top_conf < DISEASE_CONF_THRESHOLD:
            return (f"⚠️ {plant} — Low Confidence", f"Possible: {condition}\nConfidence: {top_conf:.1%}", conf_dict)
        return (f"🦠 {plant} — {condition}", f"Detected: {condition}\nConfidence: {top_conf:.1%}", conf_dict)
    except Exception as e:
        return ("❌ Error", f"{type(e).__name__}: {str(e)}", {})
def predict_severity(image):
    """Estimate how severely a leaf is affected (Severity tab handler).

    Args:
        image: PIL image from the Gradio input, or None when there is none.

    Returns:
        ``(headline, details, probabilities)`` for the three output components.
        Early exits and errors return an empty probability dict.
    """
    if image is None:
        return ("⚠️ No Image", "Please upload a leaf image.", {})

    model = load_model('severity', SEVERITY_MODEL_FILE, SEVERITY_MODEL_ID, EfficientNetSeverity, NUM_SEVERITY)
    if model is None:
        status = _model_status.get('severity', 'Unknown')
        return ("❌ Model Error", f"Severity model not available.\n\n{status}", {})

    leaf_found, leaf_metrics = detect_leaf(image)
    if not leaf_found:
        return ("🚫 Not a Leaf", f"Green ratio: {leaf_metrics['green_ratio']}", {})

    try:
        target_device = next(model.parameters()).device
        batch = preprocess_image(image).to(target_device)
        with torch.no_grad():
            probs = torch.softmax(model(batch), dim=1)
        best_prob, best_idx = torch.max(probs, dim=1)
        confidence = best_prob.item()
        level = SEVERITY_CLASSES[best_idx.item()]
        distribution = dict(zip(SEVERITY_CLASSES, map(float, probs[0].cpu().numpy())))
        badge = SEVERITY_ICONS[level]
        summary = "\n".join([
            f"**Severity:** {level}",
            f"**Description:** {SEVERITY_LEVELS[level]}",
            f"**Confidence:** {confidence:.1%}",
        ])
        return (f"{badge} {level}", summary, distribution)
    except Exception as e:
        return ("❌ Error", f"{type(e).__name__}: {str(e)}", {})
# ═══════════════════════════════════════════════════════════════════════════════
# BUILD INTERFACE
# ═══════════════════════════════════════════════════════════════════════════════
# Custom dark-green styling passed to gr.Blocks(css=...).
CSS = "\n".join([
    "",
    "body, .gradio-container { background: #0f1a14 !important; }",
    ".header { background: linear-gradient(135deg, #0a1f14, #1a4028); padding: 2rem; border-radius: 18px; text-align: center; margin-bottom: 1rem; }",
    ".header h1 { color: #4ade80; font-size: 2.4rem; margin: 0; }",
    ".tab-nav button { background: #16261d !important; color: #86efac !important; }",
    ".tab-nav button.selected { background: #166534 !important; color: #4ade80 !important; }",
    "button.primary { background: linear-gradient(135deg, #166534, #15803d) !important; color: white !important; }",
    "",
])
# Startup banner: reports whether each Drive ID is configured (never the ID itself).
print("\n".join([
    "=" * 50,
    "Starting PlantCity AgriGuard (PyTorch Version)",
    *(
        f"{label} Model ID: {'SET' if drive_id else 'NOT SET'}"
        for label, drive_id in (
            ("Species", SPECIES_MODEL_ID),
            ("Disease", DISEASE_MODEL_ID),
            ("Severity", SEVERITY_MODEL_ID),
        )
    ),
    "=" * 50,
]))
# Three tabs, one per model. Several labels previously contained UTF-8 mis-decoded
# as cp1253 ("Β·", "πŸ”", ...); they are restored to the intended characters.
with gr.Blocks(title="PlantCity - AgriGuard", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.HTML("""
    <div class="header">
        <h1>🌿AgriGuard</h1>
        <p>Species Identification · Disease Detection · Severity Assessment</p>
        <p style="font-size: 0.85rem;">12 Species | 52 Diseases | 5 Severity Levels</p>
    </div>
    """)
    with gr.Tabs():
        # Each tab: image input + button on the left, three outputs on the right.
        # The handler runs on button click and whenever the image changes.
        with gr.TabItem("🌿 Species Identification"):
            with gr.Row():
                with gr.Column(scale=1):
                    species_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    species_btn = gr.Button("🔍 Identify Species", variant="primary")
                with gr.Column(scale=1):
                    species_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    species_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    species_probs = gr.Label(label="CLASS PROBABILITIES", num_top_classes=8)
            species_btn.click(predict_species, species_img, [species_label, species_details, species_probs])
            species_img.change(predict_species, species_img, [species_label, species_details, species_probs])

        with gr.TabItem("🦠 Disease Detection"):
            with gr.Row():
                with gr.Column(scale=1):
                    disease_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    disease_btn = gr.Button("🔬 Detect Disease", variant="primary")
                with gr.Column(scale=1):
                    disease_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    disease_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    disease_probs = gr.Label(label="CLASS PROBABILITIES", num_top_classes=8)
            disease_btn.click(predict_disease, disease_img, [disease_label, disease_details, disease_probs])
            disease_img.change(predict_disease, disease_img, [disease_label, disease_details, disease_probs])

        with gr.TabItem("📊 Severity Assessment"):
            with gr.Row():
                with gr.Column(scale=1):
                    severity_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    severity_btn = gr.Button("📊 Assess Severity", variant="primary")
                with gr.Column(scale=1):
                    severity_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    severity_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    severity_probs = gr.Label(label="SEVERITY PROBABILITIES", num_top_classes=5)
            severity_btn.click(predict_severity, severity_img, [severity_label, severity_details, severity_probs])
            severity_img.change(predict_severity, severity_img, [severity_label, severity_details, severity_probs])
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")