Plant Disease Detection System
Upload a plant leaf image to detect diseases using AI
""" Plant Disease Detection Gradio App Main UI application with advanced features """ import gradio as gr import torch import sys from pathlib import Path import json from datetime import datetime # Add current directory to path sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent.parent)) from model_loader import ModelLoader import utils from utils import * import config from config import * class PlantDiseaseApp: def __init__(self): self.model_loader = ModelLoader() self.current_modelName = list(config.MODEL_CONFIGS.keys())[0] self.model = self.model_loader.loadModel(self.current_modelName) self.flagged_predictions = [] self.class_names = utils.get_class_names() def predict(self, image, modelName, confidence_threshold): """ Predict plant disease from a single image. Args: image: PIL Image or numpy array from Gradio upload modelName: Name of the model to use confidence_threshold: float (0-100), only show predictions above this confidence Returns: display_predictions: dict, class_name -> probability result_text: str, formatted top prediction info raw_predictions: str, JSON-formatted top predictions """ if image is None: return None, "Please upload an image", "" try: # Load model if needed if modelName != self.current_modelName: self.model = self.model_loader.loadModel(modelName) self.current_modelName = modelName # Preprocess image tensor = preprocess_image(image).to(self.model_loader.device) # Model inference with torch.no_grad(): logits = self.model(tensor) # Convert logits to probabilities probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0] predID = probs.argmax().item() print("predicted index: " + str(predID)) # Map to class names predictions = {name: float(prob) for name, prob in zip(self.class_names, probs)} # Filter by confidence threshold filtered_predictions = {k: v for k, v in predictions.items() if v >= confidence_threshold / 100.0} # Top prediction info if filtered_predictions: top_class = max(filtered_predictions.items(), key=lambda x: x[1])[0] top_prob = filtered_predictions[top_class] disease_info = get_disease_info(top_class) result_text = f""" **Top Prediction:** {disease_info['formatted_name']} **Confidence:** {top_prob*100:.2f}% **Plant:** {disease_info['plant']} **Status:** {'Healthy' if disease_info['is_healthy'] else 'Disease Detected'} """ else: result_text = "No predictions above confidence threshold" # Format for Gradio Label component display_predictions = {format_class_name(k): v for k, v in filtered_predictions.items()} # Raw JSON output import json raw_predictions = json.dumps(filtered_predictions, indent=2) return display_predictions, result_text, raw_predictions except Exception as e: return None, f"Error during prediction: {str(e)}", "" def flag_prediction(self, image, result_info, feedback_text): if image is None: return "No image uploaded." if not feedback_text.strip(): return "Please enter feedback before submitting." try: timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") entry = { "timestamp": timestamp, "feedback": feedback_text, "model": self.current_modelName, "result_info": result_info } self.flagged_predictions.append(entry) return "Thanks! Your feedback has been recorded." except Exception as e: return f"Error saving feedback: {str(e)}" def create_interface(): app = PlantDiseaseApp() custom_css = """ .main-header { text-align: center; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); padding: 2rem; border-radius: 10px; color: white; margin-bottom: 2rem; } .prediction-box { border: 2px solid #667eea; border-radius: 10px; padding: 1rem; background: #f8f9fa; } """ with gr.Blocks(css=custom_css, title="Plant Disease Detection") as demo: # Header gr.Markdown( """
Upload a plant leaf image to detect diseases using AI