Spaces:
Sleeping
Sleeping
| """ | |
| Plant Disease Detection Gradio App | |
| Main UI application with advanced features | |
| """ | |
| import gradio as gr | |
| import torch | |
| import sys | |
| from pathlib import Path | |
| import json | |
| from datetime import datetime | |
| # Add current directory to path | |
| sys.path.append(str(Path(__file__).parent)) | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from model_loader import ModelLoader | |
| import utils | |
| from utils import * | |
| import config | |
| from config import * | |
class PlantDiseaseApp:
    """State and callbacks behind the Gradio plant-disease UI.

    Attributes:
        model_loader: ``ModelLoader`` used to load models by name; its
            ``device`` attribute is where input tensors are moved.
        current_modelName: key of ``config.MODEL_CONFIGS`` naming the model
            currently held in ``model``.
        model: the loaded model; called on a preprocessed image tensor to
            produce logits.
        flagged_predictions: in-memory list of feedback dicts appended by
            ``flag_prediction``; nothing in this class persists it.
        class_names: labels from ``utils.get_class_names()``; presumably in the
            same order as the model's output logits — TODO confirm.
    """

    def __init__(self):
        self.model_loader = ModelLoader()
        # Start with the first configured model; predict() loads a different
        # one on demand when the UI selects it.
        self.current_modelName = list(config.MODEL_CONFIGS.keys())[0]
        self.model = self.model_loader.loadModel(self.current_modelName)
        self.flagged_predictions = []
        self.class_names = utils.get_class_names()

    def _ensure_model(self, modelName):
        """Load ``modelName`` if it is not the model currently held."""
        if modelName != self.current_modelName:
            self.model = self.model_loader.loadModel(modelName)
            # Record the switch only after loading succeeded, so a failed load
            # leaves the previous model/name pair consistent.
            self.current_modelName = modelName

    @staticmethod
    def _format_result_text(disease_info, top_prob):
        """Render the top prediction as Markdown, one paragraph per field.

        Markdown joins lines separated by a single newline into one paragraph,
        so the fields are separated by blank lines to keep each on its own line.

        Args:
            disease_info: dict with ``formatted_name``, ``plant`` and
                ``is_healthy`` keys (as returned by ``get_disease_info``).
            top_prob: probability of the top class, in [0, 1].

        Returns:
            str: Markdown text for the "Detailed Results" component.
        """
        status = "Healthy" if disease_info["is_healthy"] else "Disease Detected"
        fields = [
            f"**Top Prediction:** {disease_info['formatted_name']}",
            f"**Confidence:** {top_prob * 100:.2f}%",
            f"**Plant:** {disease_info['plant']}",
            f"**Status:** {status}",
        ]
        return "\n\n".join(fields)

    def predict(self, image, modelName, confidence_threshold):
        """Predict plant disease from a single image.

        Args:
            image: PIL Image or numpy array from the Gradio upload, or None.
            modelName: name of the model to use; loaded on demand if it differs
                from the current one.
            confidence_threshold: float (0-100); only classes whose probability
                is at or above this percentage are reported.

        Returns:
            tuple:
                display_predictions: dict, formatted class name -> probability
                    (None on missing image or error).
                result_text: str, Markdown summary of the top prediction, or a
                    status/error message.
                raw_predictions: str, JSON of the filtered predictions keyed by
                    raw class name ("" on missing image or error).
        """
        if image is None:
            return None, "Please upload an image", ""
        try:
            self._ensure_model(modelName)
            tensor = preprocess_image(image).to(self.model_loader.device)
            with torch.no_grad():
                logits = self.model(tensor)
            # Softmax over dim 1, then [0] selects the first (only) image.
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
            # NOTE(review): zip truncates silently if class_names and the model's
            # output size disagree.
            predictions = {
                name: float(prob) for name, prob in zip(self.class_names, probs)
            }
            threshold = confidence_threshold / 100.0
            filtered_predictions = {
                name: prob for name, prob in predictions.items() if prob >= threshold
            }
            if filtered_predictions:
                top_class = max(filtered_predictions, key=filtered_predictions.get)
                result_text = self._format_result_text(
                    get_disease_info(top_class), filtered_predictions[top_class]
                )
            else:
                result_text = "No predictions above confidence threshold"
            # gr.Label expects a {label: confidence} mapping.
            display_predictions = {
                format_class_name(name): prob
                for name, prob in filtered_predictions.items()
            }
            raw_predictions = json.dumps(filtered_predictions, indent=2)
            return display_predictions, result_text, raw_predictions
        except Exception as e:
            # UI boundary: surface the error in the result panel instead of
            # crashing the Gradio callback.
            return None, f"Error during prediction: {str(e)}", ""

    def flag_prediction(self, image, result_info, feedback_text):
        """Record user feedback about the current prediction in memory.

        The image is only checked for presence; it is not stored. Each entry
        holds a timestamp, the feedback text, the current model name and the
        displayed result text.

        Args:
            image: the uploaded image, or None.
            result_info: the Markdown result text shown to the user.
            feedback_text: the user's feedback; None or blank is rejected.

        Returns:
            str: status message for the UI.
        """
        if image is None:
            return "No image uploaded."
        # Guard against None as well as whitespace-only text.
        if not feedback_text or not feedback_text.strip():
            return "Please enter feedback before submitting."
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            entry = {
                "timestamp": timestamp,
                "feedback": feedback_text,
                "model": self.current_modelName,
                "result_info": result_info,
            }
            self.flagged_predictions.append(entry)
            return "Thanks! Your feedback has been recorded."
        except Exception as e:
            return f"Error saving feedback: {str(e)}"
# Model pre-selected in the dropdown when it is configured.
_PREFERRED_DEFAULT_MODEL = "intermediate model"

# UI text lives at column 0 so source indentation cannot leak into the
# rendered Markdown (4+ leading spaces would turn a line into a code block).
_CUSTOM_CSS = """
.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    padding: 2rem;
    border-radius: 10px;
    color: white;
    margin-bottom: 2rem;
}
.prediction-box {
    border: 2px solid #667eea;
    border-radius: 10px;
    padding: 1rem;
    background: #f8f9fa;
}
"""

_HEADER_HTML = """
<div class="main-header">
    <h1>Plant Disease Detection System</h1>
    <p>Upload a plant leaf image to detect diseases using AI</p>
</div>
"""

_ABOUT_MARKDOWN = """
## About This Application

This Plant Disease Detection system was developed as part of the
5CCSAGAP Artificial Intelligence Group Project at King's College London.

### Features

- **Single Image Prediction**: Upload and classify individual plant images
- **Multiple Models**: Switch between different trained models
- **Flagging System**: Report incorrect predictions to help improve the model
- **Confidence Threshold**: Filter predictions by confidence level

### Dataset

The model is trained on the PlantVillage dataset, which contains 55,400 images
across 39 different plant disease categories.

### Model Architecture

- **Basic CNN**: Custom convolutional neural network
- **Transfer Learning**: Fine-tuned ResNet18 (if available)

### Technology Stack

- **PyTorch**: Model training and inference
- **Gradio**: User interface
- **ClearML**: Experiment tracking
- **Hugging Face Spaces**: Deployment platform

### Team

[Add your team members' names here]

### Links

- [GitHub Repository](https://github.kcl.ac.uk/K23064919/smallGroupProject)
- [ClearML Dashboard](https://5ccsagap.er.kcl.ac.uk/)
"""

_FOOTER_MARKDOWN = """
---
**Note:** This is an AI-powered system and predictions should be verified by experts.
Built with love by KCL AI Students
"""


def create_interface():
    """Build the Gradio Blocks UI for the plant disease detector.

    Instantiates a ``PlantDiseaseApp`` (which loads the initial model) and
    wires its ``predict`` and ``flag_prediction`` callbacks to the UI.

    Returns:
        gr.Blocks: the assembled, not yet launched, interface.
    """
    app = PlantDiseaseApp()

    model_choices = list(config.MODEL_CONFIGS.keys())
    # Fall back to the model the app actually loaded, so the dropdown default
    # is always one of its choices even if the preferred name is not configured.
    if _PREFERRED_DEFAULT_MODEL in model_choices:
        default_model = _PREFERRED_DEFAULT_MODEL
    else:
        default_model = app.current_modelName

    with gr.Blocks(css=_CUSTOM_CSS, title="Plant Disease Detection") as demo:
        gr.Markdown(_HEADER_HTML)

        # Model selection and threshold are shared by all tabs.
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=model_choices,
                value=default_model,
                label="Select Model",
                info="Choose which model to use for predictions",
            )
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=1,
                step=1,
                label="Confidence Threshold (%)",
                info="Only show predictions above this confidence",
            )

        with gr.Tabs():
            with gr.Tab("Single Image"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(
                            label="Upload Plant Leaf Image",
                            type="pil",
                        )
                        predict_btn = gr.Button(
                            "Predict Disease", variant="primary", size="lg"
                        )
                        with gr.Accordion("Flag Incorrect Prediction", open=False):
                            feedback_text = gr.Textbox(
                                label="Your Feedback",
                                placeholder="What should the correct classification be?",
                                lines=2,
                            )
                            flag_btn = gr.Button("Submit Flag")
                            flag_output = gr.Textbox(label="Status", interactive=False)
                    with gr.Column(scale=1):
                        prediction_output = gr.Label(
                            label="Top Predictions",
                            num_top_classes=10,
                        )
                        result_info = gr.Markdown(label="Detailed Results")
                        with gr.Accordion("Advanced: View Raw Predictions", open=False):
                            raw_predictions = gr.Textbox(
                                label="Raw JSON Output",
                                lines=10,
                                interactive=False,
                            )

                predict_btn.click(
                    fn=app.predict,
                    inputs=[image_input, model_selector, confidence_slider],
                    outputs=[prediction_output, result_info, raw_predictions],
                )
                # The displayed result text is fed back in so the flag entry
                # records what the user saw.
                flag_btn.click(
                    fn=app.flag_prediction,
                    inputs=[image_input, result_info, feedback_text],
                    outputs=flag_output,
                )

            with gr.Tab("About"):
                gr.Markdown(_ABOUT_MARKDOWN)

        gr.Markdown(_FOOTER_MARKDOWN)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it locally on port 7860,
    # listening on all interfaces, with no public share link.
    print("Starting Plant Disease Detection App...")
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo = create_interface()
    demo.launch(**launch_options)