"""
PlantCity - Three Models in One Space (PyTorch Version):
    Tab 1 · Species Classifier — 12 plant species with OOD detection
    Tab 2 · Disease Classifier — 52 disease classes with OOD detection
    Tab 3 · Severity Estimator — 5 severity levels (None, Low, Medium, High, Critical)

Every tab first rejects images that do not look like a leaf (colour-ratio
heuristic), then runs an EfficientNet-B0 model that is loaded lazily on the
first request. "OOD detection" means predictions whose top softmax
probability falls below a configured threshold are reported as unknown.
"""
|
|
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| import torchvision.models as models |
| import cv2 |
| import traceback |
| import os |
| import requests |
| from pathlib import Path |
| import gdown |
|
|
| |
| |
| |
|
|
# Spatial size every input image is resized to before inference.
IMG_SIZE = (224, 224)

# Checkpoint file names. load_model looks for each one in the working
# directory first and then under /tmp (downloading it there if needed).
SPECIES_MODEL_FILE, DISEASE_MODEL_FILE, SEVERITY_MODEL_FILE = (
    "plant_species_final.pth",
    "plant_disease_final.pth",
    "severity_final.pth",
)

# Google Drive file IDs, read from environment variables of the same name;
# an empty string means "not configured" (download_model then refuses).
SPECIES_MODEL_ID, DISEASE_MODEL_ID, SEVERITY_MODEL_ID = (
    os.environ.get(env_name, '')
    for env_name in ('SPECIES_MODEL_ID', 'DISEASE_MODEL_ID', 'SEVERITY_MODEL_ID')
)

# Pixel-share thresholds for the leaf-detection heuristic.
LEAF_GREEN_THRESHOLD = LEAF_YELLOW_THRESHOLD = 0.08
|
|
| |
# Top-1 softmax probability below which predict_species reports the image
# as an unknown species instead of naming one.
SPECIES_OOD_THRESHOLD = 0.65

# Species names in classifier-output order: predict_species zips this list
# with the model's softmax vector, so index i labels logit i.
SPECIES_CLASSES = [
    "Apple", "Apricot", "Bean", "Cherry", "Corn", "Fig",
    "Grape", "Lokat", "Pear", "Walnut", "Persimmons", "Tomato"
]

# Display emoji per species (fallback in predict_species is a herb).
# NOTE(review): these were reconstructed from mis-encoded source text; the
# Fig and Lokat glyphs could not be recovered unambiguously — confirm.
SPECIES_EMOJI = {
    "Apple": "🍎", "Apricot": "🍑", "Bean": "🌱", "Cherry": "🍒",
    "Corn": "🌽", "Fig": "🍈", "Grape": "🍇", "Lokat": "🍊",
    "Pear": "🍐", "Walnut": "🌰", "Persimmons": "🍅", "Tomato": "🍅",
}
NUM_SPECIES = len(SPECIES_CLASSES)
|
|
| |
# Confidence cut-offs used by predict_disease: below DISEASE_OOD_THRESHOLD
# a prediction is treated as unknown; DISEASE_CONF_THRESHOLD marks
# low-confidence detections.
# NOTE(review): the OOD cut-off (0.60) is higher than the low-confidence
# cut-off (0.50) — confirm which tier each value was meant for.
DISEASE_OOD_THRESHOLD = 0.60
DISEASE_CONF_THRESHOLD = 0.50

# Disease class names in classifier-output order (predict_disease zips this
# list with the softmax vector). The list is in plain ASCII sort order
# (upper-case before lower-case, ' ' before '_'), which is presumably the
# training folder order — do NOT reorder or "clean up" these names.
DISEASE_CLASSES = [
    "Apple Brown_spot", "Apple Normal", "Apple black_spot",
    "Apricot Normal", "Apricot blight leaf disease", "Apricot shot_hole",
    "Bean Fungal_leaf disease", "Bean Normal leaf", "Bean bean rust image", "Bean shot_hole",
    "Cherry Leaf Scorch", "Cherry Normal leaf", "Cherry brown_spot",
    "Cherry purple leaf spot", "Cherry_shot hole disease",
    "Corn Fungal leaf", "Corn Normal leaf", "Corn gray leaf spot", "Corn holcus_ leaf spot",
    "Fig Blight_leaf disease", "Fig Brown spot", "Fig normal leaf", "Fig_rust leaf",
    "Grape Anthracnose leaf", "Grape Brown spot leaf", "Grape Downy mildew leaf",
    "Grape Mites_leaf disease", "Grape Normal_leaf", "Grape Powdery_mildew leaf",
    "Grape shot hole leaf disease",
    "Lokat Normal leaf",
    "Pear Black spot _ leaf disease", "Pear Normal _leaf", "Pear fire blight",
    "Walnut Anthracnose_leaf disease", "Walnut Blotch_leaf disease", "Walnut Normal_leaf",
    "Walnut Shot_hole", "Walnut leaf gall mite",
    "lokat Leaf_spot",
    "persimmons Brown_spot",
    "tomato Fusarium Wilt", "tomato spider mites", "tomato verticillium wilt",
    "tomato_bacterial_spot", "tomato_early_blight", "tomato_healthy_leaf",
    "tomato_late_blight", "tomato_leaf_curl", "tomato_leaf_miner", "tomato_leaf_mold",
    "tomato_septoria_leaf",
]
NUM_DISEASES = len(DISEASE_CLASSES)
|
|
| |
# Severity labels in classifier-output order (predict_severity zips this
# list with the softmax vector).
SEVERITY_CLASSES = ["None", "Low", "Medium", "High", "Critical"]

# Human-readable description of each severity label.
SEVERITY_LEVELS = {
    "None": "Healthy — 0% affected",
    "Low": "1–25% affected",
    "Medium": "26–50% affected",
    "High": "51–75% affected",
    "Critical": ">75% affected (severe damage)",
}

# Icon prefixed to the severity headline in the UI.
SEVERITY_ICONS = {
    "None": "✅", "Low": "🟡", "Medium": "🟠", "High": "🔴", "Critical": "🔴🔥",
}
NUM_SEVERITY = len(SEVERITY_CLASSES)
|
|
| |
| |
| |
|
|
class EfficientNetSpecies(nn.Module):
    """EfficientNet-B0 species classifier with an auxiliary OOD head.

    Both heads consume the same pooled, flattened backbone feature vector.
    ``forward`` returns ``(species_logits, ood_logits)`` with shapes
    ``(batch, num_classes)`` and ``(batch, 1)``.
    """

    def __init__(self, num_classes=12):
        super().__init__()
        # Randomly initialised; real weights come from the checkpoint.
        self.backbone = models.efficientnet_b0(weights=None)
        feat_dim = self.backbone.classifier[1].in_features

        # Attribute names and Sequential positions form the state_dict keys,
        # so the layer order here must stay aligned with saved checkpoints.
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feat_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

        # Single-logit head reading the same pooled features.
        self.ood_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        feature_map = self.backbone.features(x)
        pooled = torch.flatten(self.backbone.avgpool(feature_map), 1)
        return self.backbone.classifier(pooled), self.ood_head(pooled)
|
|
class EfficientNetDisease(nn.Module):
    """EfficientNet-B0 disease classifier with a single classification head.

    ``forward`` returns raw logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=52):
        super().__init__()
        backbone = models.efficientnet_b0(weights=None)
        in_dim = backbone.classifier[1].in_features
        # Replacement head; its Sequential positions are part of the
        # checkpoint's state_dict keys, so the layer order must not change.
        backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )
        self.backbone = backbone

    def forward(self, x):
        return self.backbone(x)
|
|
class EfficientNetSeverity(nn.Module):
    """EfficientNet-B0 severity classifier with a single classification head.

    ``forward`` returns raw logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        backbone = models.efficientnet_b0(weights=None)
        in_dim = backbone.classifier[1].in_features
        # Narrower (128-unit) head than the disease model; layer order is
        # part of the checkpoint's state_dict keys and must not change.
        backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )
        self.backbone = backbone

    def forward(self, x):
        return self.backbone(x)
|
|
| |
| |
| |
|
|
# Successfully loaded models keyed by model type ('species', 'disease',
# 'severity'); populated by load_model and reused on later calls.
_models = {}
# Latest load outcome per model type ("Loaded successfully" or an error
# message); the predict_* functions show it when a model is unavailable.
_model_status = {}
|
|
def download_file_from_google_drive(file_id, destination):
    """Fetch the Google Drive file *file_id* into *destination* via gdown.

    Returns:
        True if *destination* exists once gdown returns, False otherwise.
        Any exception raised by gdown is printed and reported as False.
    """
    try:
        gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)
    except Exception as err:
        print(f"Error downloading from Google Drive: {err}")
        return False
    return os.path.exists(destination)
|
|
def download_model(file_id, output_path):
    """Download a model checkpoint from Google Drive to *output_path*.

    Creates the parent directory of *output_path* first (the current
    directory when it has none).

    Returns:
        True on success; False when *file_id* is empty/None, when the
        download fails, or when any error is raised along the way.
    """
    if not file_id:
        print(f"No file ID provided for {output_path}")
        return False

    try:
        os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
        print(f"Downloading model to {output_path}...")

        downloaded = download_file_from_google_drive(file_id, output_path)
        if not downloaded:
            print(f"Download failed for {output_path}")
            return False

        print(f"Successfully downloaded {output_path}")
        return True

    except Exception as err:
        print(f"Error downloading model: {err}")
        return False
|
|
def _extract_state_dict(checkpoint):
    """Unwrap a loaded checkpoint into a plain ``{parameter_name: tensor}`` dict.

    Accepts either a bare state dict or a dict that wraps one under
    ``'model_state_dict'`` or ``'state_dict'``, and strips a leading
    ``'module.'`` prefix (as produced when an ``nn.DataParallel``-wrapped
    model is saved) from each key.

    Raises:
        TypeError: if the checkpoint is not a dict (e.g. a pickled module).
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"expected a state-dict checkpoint, got {type(checkpoint).__name__}"
        )
    if 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    prefix = 'module.'
    # Only strip the prefix at the start of a key, never in the middle.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }


def load_model(model_type, model_file, file_id, model_class, num_classes):
    """Return the model cached under *model_type*, loading it on first use.

    The checkpoint is taken from *model_file* in the working directory if it
    exists, else from ``/tmp/<model_file>``, which is downloaded from Google
    Drive (using *file_id*) when missing. The outcome is recorded in
    ``_model_status[model_type]``; only successful loads are cached in
    ``_models``, so a failed load is retried on the next call.

    Args:
        model_type: Cache key, e.g. ``'species'``.
        model_file: Checkpoint file name.
        file_id: Google Drive file ID used when the file is not present.
        model_class: ``nn.Module`` subclass accepting ``num_classes=``.
        num_classes: Number of output classes passed to *model_class*.

    Returns:
        The model in eval mode on CUDA when available (otherwise CPU), or
        ``None`` if the checkpoint could not be obtained or loaded.
    """
    if model_type in _models:
        return _models[model_type]

    if os.path.exists(model_file):
        model_path = model_file
        print(f"Found local model: {model_path}")
    else:
        model_path = f"/tmp/{model_file}"
        if not os.path.exists(model_path):
            print("Model file not found. Downloading from Google Drive...")
            if not download_model(file_id, model_path):
                _model_status[model_type] = "Failed to download model."
                return None
            print(f"Model downloaded successfully to {model_path}")

    try:
        print(f"Loading model from {model_path}")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model_class(num_classes=num_classes)

        # NOTE(security): torch.load unpickles the file. On torch versions
        # whose default is weights_only=False, a tampered download could
        # execute code — consider weights_only=True once checkpoints allow it.
        checkpoint = torch.load(model_path, map_location=device)
        state_dict = _extract_state_dict(checkpoint)

        # strict=False would happily "load" a checkpoint that matches none of
        # the model's keys, leaving random weights; refuse that case outright.
        if not set(model.state_dict()) & state_dict.keys():
            raise RuntimeError(
                "checkpoint contains no parameters matching the model architecture"
            )

        # Partial checkpoints are tolerated, but mismatches are reported.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            print(f"Warning: {len(result.missing_keys)} model keys missing from "
                  f"checkpoint, e.g. {result.missing_keys[:5]}")
        if result.unexpected_keys:
            print(f"Warning: {len(result.unexpected_keys)} unexpected checkpoint "
                  f"keys ignored, e.g. {result.unexpected_keys[:5]}")

        model = model.to(device)
        model.eval()
    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        _model_status[model_type] = error_msg
        return None

    _models[model_type] = model
    if result.missing_keys or result.unexpected_keys:
        _model_status[model_type] = (
            f"Loaded with {len(result.missing_keys)} missing and "
            f"{len(result.unexpected_keys)} unexpected checkpoint keys"
        )
    else:
        _model_status[model_type] = "Loaded successfully"
    print(f"Model {model_type} loaded successfully on {device}")
    return model
|
|
| |
| |
| |
|
|
def get_transform():
    """Build the preprocessing pipeline shared by all three models.

    Resizes to IMG_SIZE, converts to a float tensor and normalises each
    channel with the standard ImageNet mean/std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
|
|
def detect_leaf(image: Image.Image) -> tuple:
    """Heuristically decide whether *image* looks like a plant leaf.

    Measures the share of green-dominant pixels and of yellowish pixels,
    plus Canny edge density. Edge density is only reported; it does not
    influence the decision.

    Returns:
        ``(is_leaf, details)``. ``is_leaf`` is truthy when the mean of the
        green and yellow shares exceeds LEAF_GREEN_THRESHOLD, or when the
        green share alone exceeds 5%. ``details`` maps ``"green_ratio"``,
        ``"yellow_ratio"`` and ``"edge_complexity"`` to percentage strings.
    """
    pixels = np.array(image.convert("RGB"))
    height, width = pixels.shape[:2]

    # Channel planes scaled to [0, 1].
    unit = pixels.astype(np.float32) / 255.0
    red, green, blue = np.moveaxis(unit, -1, 0)

    # Green-dominant: G beats both R and B by 0.05 and exceeds 0.15.
    green_share = ((green > red + 0.05) & (green > blue + 0.05) & (green > 0.15)).mean()
    # Yellowish: bright R and G, low B, and R close to G.
    yellow_share = (
        (red > 0.3) & (green > 0.3) & (blue < 0.35) & (np.abs(red - green) < 0.3)
    ).mean()

    grayscale = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    edge_map = cv2.Canny(cv2.GaussianBlur(grayscale, (5, 5), 0), 50, 150)
    edge_share = np.sum(edge_map > 0) / (height * width)

    # NOTE(review): LEAF_YELLOW_THRESHOLD is never consulted; the yellow
    # share only contributes through the average below — confirm intent.
    averaged_share = (green_share + yellow_share) / 2
    is_leaf = averaged_share > LEAF_GREEN_THRESHOLD or green_share > 0.05

    return is_leaf, {
        "green_ratio": f"{green_share:.2%}",
        "yellow_ratio": f"{yellow_share:.2%}",
        "edge_complexity": f"{edge_share:.2%}",
    }
|
|
def preprocess_image(image: Image.Image):
    """Turn a PIL image into a single-image batch tensor for the models.

    The image is converted to RGB, passed through ``get_transform()`` and
    given a leading batch dimension of size 1.
    """
    rgb_image = image.convert("RGB")
    single = get_transform()(rgb_image)
    return single[None, ...]
|
|
def parse_disease(disease_name):
    """Split a DISEASE_CLASSES label into plant, condition and health flag.

    Class labels mix spaces and underscores as separators (e.g.
    ``"Apple Brown_spot"``, ``"Cherry_shot hole disease"``,
    ``"tomato_early_blight"``). Underscores are therefore treated like
    spaces: the first word is the plant (first letter upper-cased), and the
    remaining words are the condition.

    Args:
        disease_name: A label such as those in DISEASE_CLASSES.

    Returns:
        ``(plant, condition, healthy)``. ``condition`` is ``"Unknown"`` when
        the label has only one word. ``healthy`` is True when the label
        mentions "normal" or "healthy" in any letter case.
    """
    words = disease_name.replace('_', ' ').split()
    if not words:
        return "", "Unknown", False

    plant = words[0][:1].upper() + words[0][1:]
    condition = ' '.join(words[1:]) if len(words) > 1 else "Unknown"

    # Case-insensitive so labels like "Fig normal leaf" count as healthy.
    lowered = disease_name.lower()
    healthy = "normal" in lowered or "healthy" in lowered
    return plant, condition, healthy
|
|
| |
| |
| |
|
|
def predict_species(image):
    """Identify the plant species shown in a leaf image.

    Args:
        image: PIL image from the Gradio input, or None when it is cleared.

    Returns:
        ``(headline, details, probabilities)`` for the prediction Textbox,
        details Textbox and Label components. ``probabilities`` maps every
        species to its softmax probability, and is empty when no prediction
        was made (no image, model error, non-leaf image, or runtime error).
    """
    if image is None:
        return ("⚠️ No Image", "Please upload a leaf image to get started.", {})

    model = load_model('species', SPECIES_MODEL_FILE, SPECIES_MODEL_ID, EfficientNetSpecies, NUM_SPECIES)
    if model is None:
        return ("❌ Model Error", f"Species model not available.\n\n{_model_status.get('species', 'Unknown')}", {})

    is_leaf, details = detect_leaf(image)
    if not is_leaf:
        leaf_details = (
            "**Not a Leaf Image**\n\n"
            "This image does not appear to be a plant leaf.\n\n"
            "**Detection Metrics:**\n"
            f"- Green ratio: {details['green_ratio']}\n"
            f"- Yellow ratio: {details['yellow_ratio']}\n"
            f"- Edge complexity: {details['edge_complexity']}\n\n"
            "Please upload a clear photo of a plant leaf."
        )
        return ("🚫 Not a Leaf", leaf_details, {})

    try:
        device = next(model.parameters()).device
        img_tensor = preprocess_image(image).to(device)

        with torch.no_grad():
            # NOTE(review): the OOD head's logits are not used in this
            # decision; only the softmax threshold rejects unknown species.
            species_logits, _ood_logits = model(img_tensor)
            species_probs = torch.softmax(species_logits, dim=1)

        top_conf, top_idx = torch.max(species_probs, dim=1)
        top_conf = top_conf.item()
        top_species = SPECIES_CLASSES[top_idx.item()]

        conf_dict = {c: float(p) for c, p in zip(SPECIES_CLASSES, species_probs[0].cpu().tolist())}

        if top_conf < SPECIES_OOD_THRESHOLD:
            return (
                "🔍 Unknown Species",
                f"Confidence {top_conf:.1%} below threshold ({SPECIES_OOD_THRESHOLD:.0%}).",
                conf_dict,
            )

        emoji = SPECIES_EMOJI.get(top_species, "🌿")
        result_text = f"**Species:** {top_species}\n**Confidence:** {top_conf:.1%}\n**Leaf Detection:** ✓"
        return (f"{emoji} {top_species}", result_text, conf_dict)

    except Exception as e:
        return ("❌ Error", f"{type(e).__name__}: {str(e)}", {})
|
|
def predict_disease(image):
    """Classify the disease (or healthy state) shown in a leaf image.

    A prediction whose top softmax probability is below the lower of
    DISEASE_OOD_THRESHOLD / DISEASE_CONF_THRESHOLD is reported as unknown.
    A non-healthy prediction below the higher of the two is flagged as low
    confidence.

    Args:
        image: PIL image from the Gradio input, or None when it is cleared.

    Returns:
        ``(headline, details, probabilities)`` for the Gradio outputs.
        ``probabilities`` maps every disease class to its softmax
        probability, and is empty when no prediction was made.
    """
    if image is None:
        return ("⚠️ No Image", "Please upload a leaf image.", {})

    model = load_model('disease', DISEASE_MODEL_FILE, DISEASE_MODEL_ID, EfficientNetDisease, NUM_DISEASES)
    if model is None:
        return ("❌ Model Error", f"Disease model not available.\n\n{_model_status.get('disease', 'Unknown')}", {})

    is_leaf, details = detect_leaf(image)
    if not is_leaf:
        return ("🚫 Not a Leaf", f"Green ratio: {details['green_ratio']}", {})

    # Order the two cut-offs so the low-confidence band between them is
    # reachable whichever constant is larger. Checked in sequence, an OOD
    # threshold above the confidence threshold would hide every
    # low-confidence result.
    unknown_below = min(DISEASE_OOD_THRESHOLD, DISEASE_CONF_THRESHOLD)
    low_conf_below = max(DISEASE_OOD_THRESHOLD, DISEASE_CONF_THRESHOLD)

    try:
        device = next(model.parameters()).device
        img_tensor = preprocess_image(image).to(device)

        with torch.no_grad():
            disease_probs = torch.softmax(model(img_tensor), dim=1)

        top_conf, top_idx = torch.max(disease_probs, dim=1)
        top_conf = top_conf.item()
        top_disease = DISEASE_CLASSES[top_idx.item()]

        conf_dict = {c: float(p) for c, p in zip(DISEASE_CLASSES, disease_probs[0].cpu().tolist())}

        if top_conf < unknown_below:
            return ("🔍 Unknown", f"Confidence {top_conf:.1%} below threshold ({unknown_below:.0%}).", conf_dict)

        plant, condition, healthy = parse_disease(top_disease)

        if healthy:
            return (f"✅ {plant} — Healthy", f"Healthy leaf. Confidence: {top_conf:.1%}", conf_dict)

        if top_conf < low_conf_below:
            return (f"⚠️ {plant} — Low Confidence", f"Possible: {condition}\nConfidence: {top_conf:.1%}", conf_dict)

        return (f"🦠 {plant} — {condition}", f"Detected: {condition}\nConfidence: {top_conf:.1%}", conf_dict)

    except Exception as e:
        return ("❌ Error", f"{type(e).__name__}: {str(e)}", {})
|
|
def predict_severity(image):
    """Estimate the disease severity level of a leaf image.

    Args:
        image: PIL image from the Gradio input, or None when it is cleared.

    Returns:
        ``(headline, details, probabilities)`` for the Gradio outputs.
        ``probabilities`` maps every SEVERITY_CLASSES label to its softmax
        probability, and is empty when no prediction was made.
    """
    if image is None:
        return ("⚠️ No Image", "Please upload a leaf image.", {})

    model = load_model('severity', SEVERITY_MODEL_FILE, SEVERITY_MODEL_ID, EfficientNetSeverity, NUM_SEVERITY)
    if model is None:
        return ("❌ Model Error", f"Severity model not available.\n\n{_model_status.get('severity', 'Unknown')}", {})

    is_leaf, details = detect_leaf(image)
    if not is_leaf:
        return ("🚫 Not a Leaf", f"Green ratio: {details['green_ratio']}", {})

    try:
        device = next(model.parameters()).device
        img_tensor = preprocess_image(image).to(device)

        with torch.no_grad():
            severity_probs = torch.softmax(model(img_tensor), dim=1)

        top_conf, top_idx = torch.max(severity_probs, dim=1)
        top_conf = top_conf.item()
        top_severity = SEVERITY_CLASSES[top_idx.item()]

        conf_dict = {c: float(p) for c, p in zip(SEVERITY_CLASSES, severity_probs[0].cpu().tolist())}

        icon = SEVERITY_ICONS[top_severity]
        desc = SEVERITY_LEVELS[top_severity]
        result_text = f"**Severity:** {top_severity}\n**Description:** {desc}\n**Confidence:** {top_conf:.1%}"

        return (f"{icon} {top_severity}", result_text, conf_dict)

    except Exception as e:
        return ("❌ Error", f"{type(e).__name__}: {str(e)}", {})
|
|
| |
| |
| |
|
|
# Dark-green overrides layered on top of the gr.themes.Soft() theme.
CSS = """
body, .gradio-container { background: #0f1a14 !important; }
.header { background: linear-gradient(135deg, #0a1f14, #1a4028); padding: 2rem; border-radius: 18px; text-align: center; margin-bottom: 1rem; }
.header h1 { color: #4ade80; font-size: 2.4rem; margin: 0; }
.tab-nav button { background: #16261d !important; color: #86efac !important; }
.tab-nav button.selected { background: #166534 !important; color: #4ade80 !important; }
button.primary { background: linear-gradient(135deg, #166534, #15803d) !important; color: white !important; }
"""


# Startup banner: reports whether each model's Drive file ID is configured
# without printing the ID itself.
print("=" * 50)
print("Starting PlantCity AgriGuard (PyTorch Version)")
print(f"Species Model ID: {'SET' if SPECIES_MODEL_ID else 'NOT SET'}")
print(f"Disease Model ID: {'SET' if DISEASE_MODEL_ID else 'NOT SET'}")
print(f"Severity Model ID: {'SET' if SEVERITY_MODEL_ID else 'NOT SET'}")
print("=" * 50)


with gr.Blocks(title="PlantCity - AgriGuard", theme=gr.themes.Soft(), css=CSS) as demo:

    # Counts come from the class lists so the header cannot drift from them.
    gr.HTML(f"""
    <div class="header">
        <h1>🌿AgriGuard</h1>
        <p>Species Identification · Disease Detection · Severity Assessment</p>
        <p style="font-size: 0.85rem;">{NUM_SPECIES} Species | {NUM_DISEASES} Diseases | {NUM_SEVERITY} Severity Levels</p>
    </div>
    """)

    with gr.Tabs():

        with gr.TabItem("🌿 Species Identification"):
            with gr.Row():
                with gr.Column(scale=1):
                    species_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    species_btn = gr.Button("🔍 Identify Species", variant="primary")

                with gr.Column(scale=1):
                    species_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    species_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    species_probs = gr.Label(label="CLASS PROBABILITIES", num_top_classes=8)

            # Predict on button click and whenever the image changes.
            species_btn.click(predict_species, species_img, [species_label, species_details, species_probs])
            species_img.change(predict_species, species_img, [species_label, species_details, species_probs])

        with gr.TabItem("🦠 Disease Detection"):
            with gr.Row():
                with gr.Column(scale=1):
                    disease_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    disease_btn = gr.Button("🔬 Detect Disease", variant="primary")

                with gr.Column(scale=1):
                    disease_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    disease_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    disease_probs = gr.Label(label="CLASS PROBABILITIES", num_top_classes=8)

            disease_btn.click(predict_disease, disease_img, [disease_label, disease_details, disease_probs])
            disease_img.change(predict_disease, disease_img, [disease_label, disease_details, disease_probs])

        with gr.TabItem("📊 Severity Assessment"):
            with gr.Row():
                with gr.Column(scale=1):
                    severity_img = gr.Image(type="pil", label="Upload Leaf Image", height=300)
                    severity_btn = gr.Button("📈 Assess Severity", variant="primary")

                with gr.Column(scale=1):
                    severity_label = gr.Textbox(label="PREDICTION", interactive=False, lines=1)
                    severity_details = gr.Textbox(label="DETAILS", interactive=False, lines=6)
                    severity_probs = gr.Label(label="SEVERITY PROBABILITIES", num_top_classes=5)

            severity_btn.click(predict_severity, severity_img, [severity_label, severity_details, severity_probs])
            severity_img.change(predict_severity, severity_img, [severity_label, severity_details, severity_probs])


if __name__ == "__main__":
    # Bind on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)