""" PlantCity - Three Models in One Space (PyTorch Version): Tab 1 · Species Classifier → 12 plant species with OOD detection Tab 2 · Disease Classifier → 52 disease classes with OOD detection Tab 3 · Severity Estimator → 5 severity levels (None, Low, Medium, High, Critical) """ import gradio as gr import numpy as np from PIL import Image import torch import torch.nn as nn import torchvision.transforms as transforms import torchvision.models as models import cv2 import traceback import os import requests from pathlib import Path import gdown # ═══════════════════════════════════════════════════════════════════════════════ # CONFIGURATION # ═══════════════════════════════════════════════════════════════════════════════ IMG_SIZE = (224, 224) # ── Model file names (PyTorch .pth format) ───────────────────────────────────── SPECIES_MODEL_FILE = "plant_species_final.pth" DISEASE_MODEL_FILE = "plant_disease_final.pth" SEVERITY_MODEL_FILE = "severity_final.pth" # ── Google Drive file IDs (from Hugging Face Secrets) ───────────────────────── SPECIES_MODEL_ID = os.environ.get('SPECIES_MODEL_ID', '') DISEASE_MODEL_ID = os.environ.get('DISEASE_MODEL_ID', '') SEVERITY_MODEL_ID = os.environ.get('SEVERITY_MODEL_ID', '') # ── Leaf Detection Thresholds ─────────────────────────────────────────────────── LEAF_GREEN_THRESHOLD = 0.08 LEAF_YELLOW_THRESHOLD = 0.08 # ── Model 1: Species (12 classes) ────────────────────────────────────────────── SPECIES_OOD_THRESHOLD = 0.65 SPECIES_CLASSES = [ "Apple", "Apricot", "Bean", "Cherry", "Corn", "Fig", "Grape", "Lokat", "Pear", "Walnut", "Persimmons", "Tomato" ] SPECIES_EMOJI = { "Apple": "🍎", "Apricot": "🍑", "Bean": "🌱", "Cherry": "🍒", "Corn": "🌽", "Fig": "🍈", "Grape": "🍇", "Lokat": "🍊", "Pear": "🍐", "Walnut": "🌰", "Persimmons": "🍅", "Tomato": "🍅" } NUM_SPECIES = len(SPECIES_CLASSES) # ── Model 2: Disease (52 classes) ────────────────────────────────────────────── DISEASE_OOD_THRESHOLD = 0.60 DISEASE_CONF_THRESHOLD = 0.50 DISEASE_CLASSES = [ 
"Apple Brown_spot", "Apple Normal", "Apple black_spot", "Apricot Normal", "Apricot blight leaf disease", "Apricot shot_hole", "Bean Fungal_leaf disease", "Bean Normal leaf", "Bean bean rust image", "Bean shot_hole", "Cherry Leaf Scorch", "Cherry Normal leaf", "Cherry brown_spot", "Cherry purple leaf spot", "Cherry_shot hole disease", "Corn Fungal leaf", "Corn Normal leaf", "Corn gray leaf spot", "Corn holcus_ leaf spot", "Fig Blight_leaf disease", "Fig Brown spot", "Fig normal leaf", "Fig_rust leaf", "Grape Anthracnose leaf", "Grape Brown spot leaf", "Grape Downy mildew leaf", "Grape Mites_leaf disease", "Grape Normal_leaf", "Grape Powdery_mildew leaf", "Grape shot hole leaf disease", "Lokat Normal leaf", "Pear Black spot _ leaf disease", "Pear Normal _leaf", "Pear fire blight", "Walnut Anthracnose_leaf disease", "Walnut Blotch_leaf disease", "Walnut Normal_leaf", "Walnut Shot_hole", "Walnut leaf gall mite", "lokat Leaf_spot", "persimmons Brown_spot", "tomato Fusarium Wilt", "tomato spider mites", "tomato verticillium wilt", "tomato_bacterial_spot", "tomato_early_blight", "tomato_healthy_leaf", "tomato_late_blight", "tomato_leaf_curl", "tomato_leaf_miner", "tomato_leaf_mold", "tomato_septoria_leaf" ] NUM_DISEASES = len(DISEASE_CLASSES) # ── Model 3: Severity (5 levels) ─────────────────────────────────────────────── SEVERITY_CLASSES = ["None", "Low", "Medium", "High", "Critical"] SEVERITY_LEVELS = { "None": "Healthy — 0% affected", "Low": "1–25% affected", "Medium": "26–50% affected", "High": "51–75% affected", "Critical": ">75% affected (severe damage)" } SEVERITY_ICONS = { "None": "✅", "Low": "🟡", "Medium": "🟠", "High": "🔴", "Critical": "🔴🔥" } NUM_SEVERITY = len(SEVERITY_CLASSES) # ═══════════════════════════════════════════════════════════════════════════════ # MODEL ARCHITECTURES # ═══════════════════════════════════════════════════════════════════════════════ class EfficientNetSpecies(nn.Module): """Species identification model with OOD head""" def 
__init__(self, num_classes=12): super(EfficientNetSpecies, self).__init__() # Use EfficientNet-B0 as backbone self.backbone = models.efficientnet_b0(weights=None) num_features = self.backbone.classifier[1].in_features # Species classifier self.backbone.classifier = nn.Sequential( nn.Dropout(0.3), nn.Linear(num_features, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes) ) # OOD head self.ood_head = nn.Sequential( nn.Linear(num_features, 128), nn.ReLU(), nn.Dropout(0.2), nn.Linear(128, 1) ) def forward(self, x): features = self.backbone.features(x) features = self.backbone.avgpool(features) features = torch.flatten(features, 1) species_logits = self.backbone.classifier(features) ood_logits = self.ood_head(features) return species_logits, ood_logits class EfficientNetDisease(nn.Module): """Disease classification model""" def __init__(self, num_classes=52): super(EfficientNetDisease, self).__init__() self.backbone = models.efficientnet_b0(weights=None) num_features = self.backbone.classifier[1].in_features self.backbone.classifier = nn.Sequential( nn.Dropout(0.3), nn.Linear(num_features, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, num_classes) ) def forward(self, x): return self.backbone(x) class EfficientNetSeverity(nn.Module): """Severity assessment model""" def __init__(self, num_classes=5): super(EfficientNetSeverity, self).__init__() self.backbone = models.efficientnet_b0(weights=None) num_features = self.backbone.classifier[1].in_features self.backbone.classifier = nn.Sequential( nn.Dropout(0.3), nn.Linear(num_features, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2), nn.Linear(128, num_classes) ) def forward(self, x): return self.backbone(x) # ═══════════════════════════════════════════════════════════════════════════════ # MODEL DOWNLOAD AND LOADING # ═══════════════════════════════════════════════════════════════════════════════ _models = {} _model_status = {} def 
download_file_from_google_drive(file_id, destination): """Download file from Google Drive""" try: url = f"https://drive.google.com/uc?id={file_id}" gdown.download(url, destination, quiet=False) return os.path.exists(destination) except Exception as e: print(f"Error downloading from Google Drive: {e}") return False def download_model(file_id, output_path): """Download model from Google Drive using file ID""" if not file_id or file_id == '': print(f"No file ID provided for {output_path}") return False try: os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True) print(f"Downloading model to {output_path}...") if download_file_from_google_drive(file_id, output_path): print(f"Successfully downloaded {output_path}") return True else: print(f"Download failed for {output_path}") return False except Exception as e: print(f"Error downloading model: {e}") return False def load_model(model_type, model_file, file_id, model_class, num_classes): """Load PyTorch model with weights""" global _models, _model_status # Check if already loaded if model_type in _models: return _models[model_type] # Check local file first local_path = model_file if os.path.exists(local_path): model_path = local_path print(f"Found local model: {model_path}") else: model_path = f"/tmp/{model_file}" if not os.path.exists(model_path): print(f"Model file not found. Downloading from Google Drive...") if download_model(file_id, model_path): print(f"Model downloaded successfully to {model_path}") else: _model_status[model_type] = f"Failed to download model." 
return None # Load the model try: print(f"Loading model from {model_path}") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Create model instance model = model_class(num_classes=num_classes) # Load state dict state_dict = torch.load(model_path, map_location=device) # Handle different state dict formats if 'model_state_dict' in state_dict: state_dict = state_dict['model_state_dict'] elif 'state_dict' in state_dict: state_dict = state_dict['state_dict'] # Remove 'module.' prefix if present (from DataParallel) new_state_dict = {} for k, v in state_dict.items(): name = k.replace('module.', '') new_state_dict[name] = v model.load_state_dict(new_state_dict, strict=False) model = model.to(device) model.eval() _models[model_type] = model _model_status[model_type] = "Loaded successfully" print(f"Model {model_type} loaded successfully on {device}") return model except Exception as e: error_msg = f"Error loading model: {str(e)}" print(error_msg) _model_status[model_type] = error_msg return None # ═══════════════════════════════════════════════════════════════════════════════ # IMAGE PREPROCESSING # ═══════════════════════════════════════════════════════════════════════════════ def get_transform(): """Get image transformation pipeline""" return transforms.Compose([ transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def detect_leaf(image: Image.Image) -> tuple: """Detect if image contains a leaf using green-yellow pixel ratio""" img_array = np.array(image.convert("RGB")) h, w = img_array.shape[:2] # Normalize to [0, 1] img_float = img_array.astype(np.float32) / 255.0 r, g, b = img_float[..., 0], img_float[..., 1], img_float[..., 2] # Green pixel detection green_mask = (g > r + 0.05) & (g > b + 0.05) & (g > 0.15) green_ratio = green_mask.mean() # Yellow pixel detection yellow_mask = (r > 0.3) & (g > 0.3) & (b < 0.35) & (np.abs(r - g) < 0.3) yellow_ratio = yellow_mask.mean() # 
Edge detection gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) blurred = cv2.GaussianBlur(gray, (5, 5), 0) edges = cv2.Canny(blurred, 50, 150) edge_complexity = np.sum(edges > 0) / (h * w) # Determine if leaf is_leaf = (green_ratio + yellow_ratio) / 2 > LEAF_GREEN_THRESHOLD or green_ratio > 0.05 details = { "green_ratio": f"{green_ratio:.2%}", "yellow_ratio": f"{yellow_ratio:.2%}", "edge_complexity": f"{edge_complexity:.2%}" } return is_leaf, details def preprocess_image(image: Image.Image): """Preprocess image for PyTorch model""" transform = get_transform() img_tensor = transform(image.convert("RGB")) return img_tensor.unsqueeze(0) # Add batch dimension def parse_disease(disease_name): """Parse disease name to extract plant and condition""" parts = disease_name.split('_') plant = parts[0] if len(parts) > 1: condition = ' '.join(parts[1:]) else: condition = "Unknown" healthy = "Normal" in disease_name or "healthy" in disease_name.lower() return plant, condition, healthy # ═══════════════════════════════════════════════════════════════════════════════ # PREDICTION FUNCTIONS # ═══════════════════════════════════════════════════════════════════════════════ def predict_species(image): """Predict plant species from leaf image""" if image is None: return ("⚠️ No Image", "Please upload a leaf image to get started.", {}) # Load model model = load_model('species', SPECIES_MODEL_FILE, SPECIES_MODEL_ID, EfficientNetSpecies, NUM_SPECIES) if model is None: return ("❌ Model Error", f"Species model not available.\n\n{_model_status.get('species', 'Unknown')}", {}) # Check if it's a leaf is_leaf, details = detect_leaf(image) if not is_leaf: leaf_details = f""" **Not a Leaf Image** This image does not appear to be a plant leaf. **Detection Metrics:** - Green ratio: {details['green_ratio']} - Yellow ratio: {details['yellow_ratio']} - Edge complexity: {details['edge_complexity']} Please upload a clear photo of a plant leaf. 
""" return ("🚫 Not a Leaf", leaf_details, {}) try: device = next(model.parameters()).device img_tensor = preprocess_image(image).to(device) with torch.no_grad(): species_logits, ood_logits = model(img_tensor) species_probs = torch.softmax(species_logits, dim=1) ood_score = torch.sigmoid(ood_logits).item() top_conf, top_idx = torch.max(species_probs, dim=1) top_conf = top_conf.item() top_idx = top_idx.item() top_species = SPECIES_CLASSES[top_idx] conf_dict = {c: float(p) for c, p in zip(SPECIES_CLASSES, species_probs[0].cpu().numpy())} if top_conf < SPECIES_OOD_THRESHOLD: return ( "🔍 Unknown Species", f"Confidence {top_conf:.1%} below threshold ({SPECIES_OOD_THRESHOLD:.0%}).", conf_dict ) emoji = SPECIES_EMOJI.get(top_species, "🌿") result_text = f"**Species:** {top_species}\n**Confidence:** {top_conf:.1%}\n**Leaf Detection:** ✓" return (f"{emoji} {top_species}", result_text, conf_dict) except Exception as e: return ("❌ Error", f"{type(e).__name__}: {str(e)}", {}) def predict_disease(image): """Predict disease from leaf image""" if image is None: return ("⚠️ No Image", "Please upload a leaf image.", {}) # Load model model = load_model('disease', DISEASE_MODEL_FILE, DISEASE_MODEL_ID, EfficientNetDisease, NUM_DISEASES) if model is None: return ("❌ Model Error", f"Disease model not available.\n\n{_model_status.get('disease', 'Unknown')}", {}) # Check if it's a leaf is_leaf, details = detect_leaf(image) if not is_leaf: return ("🚫 Not a Leaf", f"Green ratio: {details['green_ratio']}", {}) try: device = next(model.parameters()).device img_tensor = preprocess_image(image).to(device) with torch.no_grad(): disease_logits = model(img_tensor) disease_probs = torch.softmax(disease_logits, dim=1) top_conf, top_idx = torch.max(disease_probs, dim=1) top_conf = top_conf.item() top_idx = top_idx.item() top_disease = DISEASE_CLASSES[top_idx] conf_dict = {c: float(p) for c, p in zip(DISEASE_CLASSES, disease_probs[0].cpu().numpy())} if top_conf < DISEASE_OOD_THRESHOLD: return ("🔍 
Unknown", f"Confidence {top_conf:.1%} below threshold.", conf_dict) plant, condition, healthy = parse_disease(top_disease) if healthy: return (f"✅ {plant} — Healthy", f"Healthy leaf. Confidence: {top_conf:.1%}", conf_dict) if top_conf < DISEASE_CONF_THRESHOLD: return (f"⚠️ {plant} — Low Confidence", f"Possible: {condition}\nConfidence: {top_conf:.1%}", conf_dict) return (f"🦠 {plant} — {condition}", f"Detected: {condition}\nConfidence: {top_conf:.1%}", conf_dict) except Exception as e: return ("❌ Error", f"{type(e).__name__}: {str(e)}", {}) def predict_severity(image): """Predict disease severity level""" if image is None: return ("⚠️ No Image", "Please upload a leaf image.", {}) # Load model model = load_model('severity', SEVERITY_MODEL_FILE, SEVERITY_MODEL_ID, EfficientNetSeverity, NUM_SEVERITY) if model is None: return ("❌ Model Error", f"Severity model not available.\n\n{_model_status.get('severity', 'Unknown')}", {}) # Check if it's a leaf is_leaf, details = detect_leaf(image) if not is_leaf: return ("🚫 Not a Leaf", f"Green ratio: {details['green_ratio']}", {}) try: device = next(model.parameters()).device img_tensor = preprocess_image(image).to(device) with torch.no_grad(): severity_logits = model(img_tensor) severity_probs = torch.softmax(severity_logits, dim=1) top_conf, top_idx = torch.max(severity_probs, dim=1) top_conf = top_conf.item() top_idx = top_idx.item() top_severity = SEVERITY_CLASSES[top_idx] conf_dict = {c: float(p) for c, p in zip(SEVERITY_CLASSES, severity_probs[0].cpu().numpy())} icon = SEVERITY_ICONS[top_severity] desc = SEVERITY_LEVELS[top_severity] result_text = f"**Severity:** {top_severity}\n**Description:** {desc}\n**Confidence:** {top_conf:.1%}" return (f"{icon} {top_severity}", result_text, conf_dict) except Exception as e: return ("❌ Error", f"{type(e).__name__}: {str(e)}", {}) # ═══════════════════════════════════════════════════════════════════════════════ # BUILD INTERFACE # 
# ═══════════════════════════════════════════════════════════════════════════════
# BUILD INTERFACE
# ═══════════════════════════════════════════════════════════════════════════════
# Dark-green overrides applied on top of gr.themes.Soft().
CSS = (
    " body, .gradio-container { background: #0f1a14 !important; }"
    " .header { background: linear-gradient(135deg, #0a1f14, #1a4028);"
    " padding: 2rem; border-radius: 18px; text-align: center;"
    " margin-bottom: 1rem; }"
    " .header h1 { color: #4ade80; font-size: 2.4rem; margin: 0; }"
    " .tab-nav button { background: #16261d !important;"
    " color: #86efac !important; }"
    " .tab-nav button.selected { background: #166534 !important;"
    " color: #4ade80 !important; }"
    " button.primary {"
    " background: linear-gradient(135deg, #166534, #15803d) !important;"
    " color: white !important; } "
)

# Banner text rendered above the tabs.
HEADER_HTML = "\n".join([
    "",
    "",
    "🌿AgriGuard",
    "",
    "Species Identification · Disease Detection · Severity Assessment",
    "",
    "12 Species | 52 Diseases | 5 Severity Levels",
    "",
    "",
])

# Startup log: report whether each Drive file ID secret is present (never its value).
_rule = "=" * 50
print(_rule)
print("Starting PlantCity AgriGuard (PyTorch Version)")
for _model_name, _drive_id in (
    ("Species", SPECIES_MODEL_ID),
    ("Disease", DISEASE_MODEL_ID),
    ("Severity", SEVERITY_MODEL_ID),
):
    print(f"{_model_name} Model ID: {'SET' if _drive_id else 'NOT SET'}")
print(_rule)

with gr.Blocks(title="PlantCity - AgriGuard", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.HTML(HEADER_HTML)

    with gr.Tabs():
        # ── Tab 1: species identification ──────────────────────────────────
        with gr.TabItem("🌿 Species Identification"):
            with gr.Row():
                with gr.Column(scale=1):
                    species_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    species_btn = gr.Button("🔍 Identify Species", variant="primary")
                with gr.Column(scale=1):
                    species_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    species_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    species_probs = gr.Label(label="CLASS PROBABILITIES", num_top_classes=8)
            species_outputs = [species_label, species_details, species_probs]
            # Run on button press and whenever the image changes (upload or clear).
            species_btn.click(predict_species, species_img, species_outputs)
            species_img.change(predict_species, species_img, species_outputs)

        # ── Tab 2: disease detection ───────────────────────────────────────
        with gr.TabItem("🦠 Disease Detection"):
            with gr.Row():
                with gr.Column(scale=1):
                    disease_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    disease_btn = gr.Button("🔬 Detect Disease", variant="primary")
                with gr.Column(scale=1):
                    disease_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    disease_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    disease_probs = gr.Label(label="CLASS PROBABILITIES", num_top_classes=8)
            disease_outputs = [disease_label, disease_details, disease_probs]
            disease_btn.click(predict_disease, disease_img, disease_outputs)
            disease_img.change(predict_disease, disease_img, disease_outputs)

        # ── Tab 3: severity assessment ─────────────────────────────────────
        with gr.TabItem("📊 Severity Assessment"):
            with gr.Row():
                with gr.Column(scale=1):
                    severity_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    severity_btn = gr.Button("📊 Assess Severity", variant="primary")
                with gr.Column(scale=1):
                    severity_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    severity_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    severity_probs = gr.Label(label="SEVERITY PROBABILITIES", num_top_classes=5)
            severity_outputs = [severity_label, severity_details, severity_probs]
            severity_btn.click(predict_severity, severity_img, severity_outputs)
            severity_img.change(predict_severity, severity_img, severity_outputs)

if __name__ == "__main__":
    # Bind on all interfaces at the conventional Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)