k23064919's picture
quickfix
64978ec
"""
Plant Disease Detection Gradio App
Main UI application with advanced features
"""
import gradio as gr
import torch
import sys
from pathlib import Path
import json
from datetime import datetime
# Add current directory to path
sys.path.append(str(Path(__file__).parent))
sys.path.append(str(Path(__file__).parent.parent))
from model_loader import ModelLoader
import utils
from utils import *
import config
from config import *
class PlantDiseaseApp:
    """State and callbacks behind the Plant Disease Detection Gradio UI.

    Holds the currently selected model in ``self.model`` (loaded through
    ``ModelLoader.loadModel``) and swaps it when a different model name is
    requested. Whether ``ModelLoader`` caches previously loaded models is not
    visible here. Feedback submitted via ``flag_prediction`` is kept only in
    the in-memory list ``self.flagged_predictions``.
    """

    # Model the UI dropdown selects by default. Loading the same model here
    # avoids an extra model load on the first prediction. If this name is not
    # configured, the first entry of config.MODEL_CONFIGS is used instead.
    DEFAULT_MODEL_NAME = "intermediate model"

    def __init__(self):
        self.model_loader = ModelLoader()
        if self.DEFAULT_MODEL_NAME in config.MODEL_CONFIGS:
            self.current_modelName = self.DEFAULT_MODEL_NAME
        else:
            self.current_modelName = list(config.MODEL_CONFIGS.keys())[0]
        self.model = self.model_loader.loadModel(self.current_modelName)
        # In-memory feedback log; entries are dicts built by flag_prediction.
        self.flagged_predictions = []
        self.class_names = utils.get_class_names()

    @staticmethod
    def _format_result_text(disease_info, top_prob):
        """Build the Markdown summary shown for the top prediction.

        Args:
            disease_info: dict from ``get_disease_info`` with the keys
                ``formatted_name``, ``plant`` and ``is_healthy``.
            top_prob: probability of the top class, in [0, 1].

        Returns:
            str: Markdown text. Fields are separated by blank lines because a
            single newline is only a soft break in standard Markdown and would
            run all fields together on one line.
        """
        status = "Healthy" if disease_info["is_healthy"] else "Disease Detected"
        fields = [
            f"**Top Prediction:** {disease_info['formatted_name']}",
            f"**Confidence:** {top_prob * 100:.2f}%",
            f"**Plant:** {disease_info['plant']}",
            f"**Status:** {status}",
        ]
        return "\n\n".join(fields)

    def predict(self, image, modelName, confidence_threshold):
        """Predict plant disease from a single image.

        Args:
            image: PIL Image (or numpy array) from the Gradio upload, or None.
            modelName: key of ``config.MODEL_CONFIGS`` to run.
            confidence_threshold: float in 0-100 (percent); only classes whose
                probability is at least this value are reported.

        Returns:
            tuple of (display_predictions, result_text, raw_predictions):
            display_predictions: dict mapping formatted class name to
                probability, or None when there is no image or an error occurs.
            result_text: str, Markdown summary of the top prediction or a
                status/error message.
            raw_predictions: str, JSON of the unformatted class names to
                probabilities above the threshold ("" on no image / error).
        """
        if image is None:
            return None, "Please upload an image", ""
        try:
            # Reload only when the dropdown selection differs from the loaded model.
            if modelName != self.current_modelName:
                self.model = self.model_loader.loadModel(modelName)
                self.current_modelName = modelName

            # NOTE(review): assumes preprocess_image returns a batched tensor
            # holding a single image (batch dim first) — the [0] below relies on it.
            tensor = preprocess_image(image).to(self.model_loader.device)
            with torch.no_grad():
                logits = self.model(tensor)

            # Softmax over the class dimension, then take the only image in the batch.
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
            pred_id = probs.argmax().item()
            print(f"predicted index: {pred_id}")

            predictions = {name: float(prob) for name, prob in zip(self.class_names, probs)}
            # The slider is in percent while probabilities are in [0, 1].
            threshold = confidence_threshold / 100.0
            filtered_predictions = {k: v for k, v in predictions.items() if v >= threshold}

            if filtered_predictions:
                top_class = max(filtered_predictions, key=filtered_predictions.get)
                result_text = self._format_result_text(
                    get_disease_info(top_class), filtered_predictions[top_class]
                )
            else:
                result_text = "No predictions above confidence threshold"

            # Human-readable names for the gr.Label component.
            display_predictions = {
                format_class_name(k): v for k, v in filtered_predictions.items()
            }
            raw_predictions = json.dumps(filtered_predictions, indent=2)
            return display_predictions, result_text, raw_predictions
        except Exception as e:
            # UI boundary: report the error to the user, keep the traceback in server logs.
            import traceback

            traceback.print_exc()
            return None, f"Error during prediction: {str(e)}", ""

    def flag_prediction(self, image, result_info, feedback_text):
        """Record user feedback about the current prediction in memory.

        Args:
            image: the uploaded image; only checked for presence, not stored.
            result_info: the Markdown result text currently displayed.
            feedback_text: free-text feedback from the user (may be None).

        Returns:
            str: status message for the UI.
        """
        if image is None:
            return "No image uploaded."
        # Guard against None as well as blank/whitespace-only feedback.
        if not feedback_text or not feedback_text.strip():
            return "Please enter feedback before submitting."
        try:
            entry = {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "feedback": feedback_text,
                "model": self.current_modelName,
                "result_info": result_info,
            }
            self.flagged_predictions.append(entry)
        except Exception as e:
            return f"Error saving feedback: {str(e)}"
        return "Thanks! Your feedback has been recorded."
def create_interface():
    """Build the Gradio Blocks UI for the plant disease detector.

    Creates a ``PlantDiseaseApp`` (which loads a model immediately) and wires
    its ``predict`` and ``flag_prediction`` callbacks to the widgets.

    Returns:
        gr.Blocks: the assembled demo; call ``.launch()`` on it to serve it.
    """
    app = PlantDiseaseApp()
    custom_css = """
    .main-header {
        text-align: center;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
        border-radius: 10px;
        color: white;
        margin-bottom: 2rem;
    }
    .prediction-box {
        border: 2px solid #667eea;
        border-radius: 10px;
        padding: 1rem;
        background: #f8f9fa;
    }
    """
    model_choices = list(config.MODEL_CONFIGS.keys())
    # Prefer "intermediate model", but fall back to the model the app actually
    # loaded when that key is not configured, so the dropdown's value is
    # always one of its choices.
    default_model = (
        "intermediate model"
        if "intermediate model" in model_choices
        else app.current_modelName
    )

    with gr.Blocks(css=custom_css, title="Plant Disease Detection") as demo:
        # Header
        gr.Markdown(
            """
            <div class="main-header">
            <h1>Plant Disease Detection System</h1>
            <p>Upload a plant leaf image to detect diseases using AI</p>
            </div>
            """
        )

        # Model selection and threshold, shared by all tabs.
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=model_choices,
                value=default_model,
                label="Select Model",
                info="Choose which model to use for predictions"
            )
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=1,
                step=1,
                label="Confidence Threshold (%)",
                info="Only show predictions above this confidence"
            )

        with gr.Tabs():
            # Tab 1: Single Image Prediction
            with gr.Tab("Single Image"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(
                            label="Upload Plant Leaf Image",
                            type="pil"
                        )
                        predict_btn = gr.Button("Predict Disease", variant="primary", size="lg")
                        with gr.Accordion("Flag Incorrect Prediction", open=False):
                            feedback_text = gr.Textbox(
                                label="Your Feedback",
                                placeholder="What should the correct classification be?",
                                lines=2
                            )
                            flag_btn = gr.Button("Submit Flag")
                            flag_output = gr.Textbox(label="Status", interactive=False)
                    with gr.Column(scale=1):
                        prediction_output = gr.Label(
                            label="Top Predictions",
                            num_top_classes=10
                        )
                        result_info = gr.Markdown(label="Detailed Results")
                        with gr.Accordion("Advanced: View Raw Predictions", open=False):
                            raw_predictions = gr.Textbox(
                                label="Raw JSON Output",
                                lines=10,
                                interactive=False
                            )

                # Connect buttons
                predict_btn.click(
                    fn=app.predict,
                    inputs=[image_input, model_selector, confidence_slider],
                    outputs=[prediction_output, result_info, raw_predictions]
                )
                # The displayed Markdown result is passed back in so the
                # feedback entry records what the user saw.
                flag_btn.click(
                    fn=app.flag_prediction,
                    inputs=[image_input, result_info, feedback_text],
                    outputs=flag_output
                )

            # NOTE(review): the feature list below mentions "Batch Processing"
            # and "Example Gallery", but this function builds no such tabs —
            # either add them or update the text.
            with gr.Tab("About"):
                gr.Markdown(
                    """
                    ## About This Application
                    This Plant Disease Detection system was developed as part of the
                    5CCSAGAP Artificial Intelligence Group Project at King's College London.
                    ### Features
                    - **Single Image Prediction**: Upload and classify individual plant images
                    - **Multiple Models**: Switch between different trained models
                    - **Batch Processing**: Classify multiple images at once
                    - **Example Gallery**: Try pre-loaded example images
                    - **Flagging System**: Report incorrect predictions to help improve the model
                    - **Confidence Threshold**: Filter predictions by confidence level
                    ### Dataset
                    The model is trained on the PlantVillage dataset, which contains 55,400 images
                    across 39 different plant disease categories.
                    ### Model Architecture
                    - **Basic CNN**: Custom convolutional neural network
                    - **Transfer Learning**: Fine-tuned ResNet18 (if available)
                    ### Technology Stack
                    - **PyTorch**: Model training and inference
                    - **Gradio**: User interface
                    - **ClearML**: Experiment tracking
                    - **Hugging Face Spaces**: Deployment platform
                    ### Team
                    [Add your team members' names here]
                    ### Links
                    - [GitHub Repository](https://github.kcl.ac.uk/K23064919/smallGroupProject)
                    - [ClearML Dashboard](https://5ccsagap.er.kcl.ac.uk/)
                    """
                )

        gr.Markdown(
            """
            ---
            **Note:** This is an AI-powered system and predictions should be verified by experts.
            Built with love by KCL AI Students
            """
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it.
    print("Starting Plant Disease Detection App...")
    demo = create_interface()
    # Listen on all network interfaces (0.0.0.0) at port 7860, without a
    # public Gradio share link.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)