"""
Plant Disease Detection Gradio App

Main UI application with advanced features.
"""

import gradio as gr
import torch
import sys
from pathlib import Path
import json
import logging
from datetime import datetime

# Make this file's directory and its parent importable, so the local modules
# below (model_loader, utils, config) resolve regardless of the working
# directory the app is launched from.
sys.path.append(str(Path(__file__).parent))
sys.path.append(str(Path(__file__).parent.parent))

from model_loader import ModelLoader
import utils
from utils import *
import config
from config import *

logger = logging.getLogger(__name__)

# Model shown by default in the UI when it exists in config.MODEL_CONFIGS.
_PREFERRED_MODEL = "intermediate model"

_CUSTOM_CSS = """
.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    padding: 2rem;
    border-radius: 10px;
    color: white;
    margin-bottom: 2rem;
}
.prediction-box {
    border: 2px solid #667eea;
    border-radius: 10px;
    padding: 1rem;
    background: #f8f9fa;
}
"""

# Kept at module level (unindented) so Markdown never mistakes leading
# indentation for a code block.
_HEADER_MD = """
# Plant Disease Detection System

Upload a plant leaf image to detect diseases using AI
"""

# NOTE(review): the Features list mentions Batch Processing and an Example
# Gallery, but create_interface() below only builds the "Single Image" and
# "About" tabs — confirm these features exist elsewhere or update this text.
_ABOUT_MD = """
## About This Application

This Plant Disease Detection system was developed as part of the 5CCSAGAP Artificial Intelligence Group Project at King's College London.

### Features
- **Single Image Prediction**: Upload and classify individual plant images
- **Multiple Models**: Switch between different trained models
- **Batch Processing**: Classify multiple images at once
- **Example Gallery**: Try pre-loaded example images
- **Flagging System**: Report incorrect predictions to help improve the model
- **Confidence Threshold**: Filter predictions by confidence level

### Dataset
The model is trained on the PlantVillage dataset, which contains 55,400 images across 39 different plant disease categories.

### Model Architecture
- **Basic CNN**: Custom convolutional neural network
- **Transfer Learning**: Fine-tuned ResNet18 (if available)

### Technology Stack
- **PyTorch**: Model training and inference
- **Gradio**: User interface
- **ClearML**: Experiment tracking
- **Hugging Face Spaces**: Deployment platform

### Team
[Add your team members' names here]

### Links
- [GitHub Repository](https://github.kcl.ac.uk/K23064919/smallGroupProject)
- [ClearML Dashboard](https://5ccsagap.er.kcl.ac.uk/)
"""

_FOOTER_MD = """
---

**Note:** This is an AI-powered system and predictions should be verified by experts.

Built with love by KCL AI Students
"""


class PlantDiseaseApp:
    """Hold the currently loaded model and serve predictions and feedback for the UI."""

    def __init__(self, model_name=None):
        """
        Load the startup model and the class-name list.

        Args:
            model_name: key of ``config.MODEL_CONFIGS`` to load at startup.
                When ``None`` (the default), the first configured model is used.
        """
        self.model_loader = ModelLoader()
        if model_name is None:
            model_name = list(config.MODEL_CONFIGS.keys())[0]
        self.current_modelName = model_name
        self.model = self.model_loader.loadModel(self.current_modelName)
        # In-memory only: nothing in this file persists these entries, so they
        # are lost when the process restarts.
        self.flagged_predictions = []
        self.class_names = utils.get_class_names()

    def predict(self, image, modelName, confidence_threshold):
        """
        Predict plant disease from a single image.

        Args:
            image: PIL Image or numpy array from Gradio upload.
            modelName: Name of the model to use (key of ``config.MODEL_CONFIGS``).
            confidence_threshold: float (0-100); only predictions at or above
                this confidence are shown.

        Returns:
            display_predictions: dict, formatted class name -> probability
                (``None`` on missing image or error).
            result_text: str, Markdown describing the top prediction, or a
                status/error message.
            raw_predictions: str, JSON of the retained predictions sorted by
                descending probability ("" on missing image or error).
        """
        if image is None:
            return None, "Please upload an image", ""

        try:
            # Reload only when the dropdown selection differs from the loaded model.
            if modelName != self.current_modelName:
                self.model = self.model_loader.loadModel(modelName)
                self.current_modelName = modelName

            tensor = preprocess_image(image).to(self.model_loader.device)

            with torch.no_grad():
                logits = self.model(tensor)

            # Softmax over the class dimension; only the first batch item is
            # used (assumes preprocess_image yields a batch of one — confirm).
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]

            # zip() would silently truncate and mislabel classes on a mismatch.
            if len(probs) != len(self.class_names):
                raise ValueError(
                    f"model produced {len(probs)} scores but "
                    f"{len(self.class_names)} class names are configured"
                )
            logger.info("predicted index: %d", int(probs.argmax()))

            threshold = confidence_threshold / 100.0
            # Stable sort keeps class order among ties, so the first entry is
            # the same top class max() would pick.
            ranked = sorted(
                zip(self.class_names, (float(p) for p in probs)),
                key=lambda item: item[1],
                reverse=True,
            )
            filtered_predictions = {name: prob for name, prob in ranked if prob >= threshold}

            if filtered_predictions:
                top_class, top_prob = next(iter(filtered_predictions.items()))
                disease_info = get_disease_info(top_class)
                status = "Healthy" if disease_info["is_healthy"] else "Disease Detected"
                # Blank lines between fields so Markdown renders them as
                # separate lines rather than one run-on paragraph.
                result_text = (
                    f"**Top Prediction:** {disease_info['formatted_name']}\n\n"
                    f"**Confidence:** {top_prob * 100:.2f}%\n\n"
                    f"**Plant:** {disease_info['plant']}\n\n"
                    f"**Status:** {status}\n"
                )
            else:
                result_text = "No predictions above confidence threshold"

            # Gradio Label expects {label: confidence}.
            display_predictions = {
                format_class_name(name): prob for name, prob in filtered_predictions.items()
            }
            raw_predictions = json.dumps(filtered_predictions, indent=2)
            return display_predictions, result_text, raw_predictions
        except Exception as e:
            # UI boundary: report the error to the user, keep the traceback in logs.
            logger.exception("Prediction failed")
            return None, f"Error during prediction: {str(e)}", ""

    def flag_prediction(self, image, result_info, feedback_text):
        """
        Record user feedback about a prediction.

        Args:
            image: the uploaded image (only checked for presence; not stored).
            result_info: the Markdown result text currently displayed.
            feedback_text: free-text feedback entered by the user.

        Returns:
            str: status message shown in the UI.
        """
        if image is None:
            return "No image uploaded."
        # Guard against None as well as blank/whitespace-only input.
        if not feedback_text or not feedback_text.strip():
            return "Please enter feedback before submitting."

        try:
            entry = {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "feedback": feedback_text,
                "model": self.current_modelName,
                "result_info": result_info,
            }
            self.flagged_predictions.append(entry)
            return "Thanks! Your feedback has been recorded."
        except Exception as e:
            logger.exception("Saving feedback failed")
            return f"Error saving feedback: {str(e)}"


def create_interface():
    """Build and return the Gradio Blocks UI for the plant disease detector."""
    # Start with the preferred model when it is configured. The dropdown below
    # shows whatever model is actually loaded, so it never displays a value
    # missing from its choices, and no second load happens on the first click.
    preferred = _PREFERRED_MODEL if _PREFERRED_MODEL in config.MODEL_CONFIGS else None
    app = PlantDiseaseApp(preferred)

    with gr.Blocks(css=_CUSTOM_CSS, title="Plant Disease Detection") as demo:
        # Header (styled by the .main-header CSS rule)
        gr.Markdown(_HEADER_MD, elem_classes="main-header")

        # Model selection (available to all tabs)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(config.MODEL_CONFIGS.keys()),
                value=app.current_modelName,
                label="Select Model",
                info="Choose which model to use for predictions",
            )
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=1,
                step=1,
                label="Confidence Threshold (%)",
                info="Only show predictions above this confidence",
            )

        with gr.Tabs():
            # Tab 1: Single Image Prediction
            with gr.Tab("Single Image"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(label="Upload Plant Leaf Image", type="pil")
                        predict_btn = gr.Button("Predict Disease", variant="primary", size="lg")

                        with gr.Accordion("Flag Incorrect Prediction", open=False):
                            feedback_text = gr.Textbox(
                                label="Your Feedback",
                                placeholder="What should the correct classification be?",
                                lines=2,
                            )
                            flag_btn = gr.Button("Submit Flag")
                            flag_output = gr.Textbox(label="Status", interactive=False)

                    with gr.Column(scale=1):
                        prediction_output = gr.Label(
                            label="Top Predictions",
                            num_top_classes=10,
                            elem_classes="prediction-box",
                        )
                        result_info = gr.Markdown(label="Detailed Results")

                        with gr.Accordion("Advanced: View Raw Predictions", open=False):
                            raw_predictions = gr.Textbox(
                                label="Raw JSON Output", lines=10, interactive=False
                            )

                predict_btn.click(
                    fn=app.predict,
                    inputs=[image_input, model_selector, confidence_slider],
                    outputs=[prediction_output, result_info, raw_predictions],
                )
                flag_btn.click(
                    fn=app.flag_prediction,
                    inputs=[image_input, result_info, feedback_text],
                    outputs=flag_output,
                )

            with gr.Tab("About"):
                gr.Markdown(_ABOUT_MD)

        gr.Markdown(_FOOTER_MD)

    return demo


if __name__ == "__main__":
    # Surface logger.info output (e.g. predicted index) on the console.
    logging.basicConfig(level=logging.INFO)
    print("Starting Plant Disease Detection App...")
    demo = create_interface()
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)