# NOTE: the following lines are file-viewer page residue, not part of the module.
# Mert Yerlikaya
# Add feature-rich Gradio UI with mock model
# 505fc99
# raw
# history blame
# 14.2 kB
"""
Plant Disease Detection Gradio App
Main UI application with advanced features
"""
import gradio as gr
import torch
import sys
from pathlib import Path
import json
from datetime import datetime
# Add current directory to path
sys.path.append(str(Path(__file__).parent))
sys.path.append(str(Path(__file__).parent.parent))
import config
from model_loader import ModelLoader
from utils import (
preprocess_image,
postprocess_predictions,
format_class_name,
get_disease_info,
batch_preprocess_images
)
from models.mock_model import create_mock_predictions
class PlantDiseaseApp:
    """
    Main application class for Plant Disease Detection.

    Keeps the currently loaded model, runs single-image and batch
    predictions for the UI callbacks, and stores user-flagged predictions
    in memory.
    """

    # Lower-case file suffixes that count as example images.
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, use_mock=True):
        """
        Initialize the application.

        Args:
            use_mock: Whether to use mock model for development
        """
        self.use_mock = use_mock
        self.model_loader = ModelLoader(use_mock=use_mock)
        self.current_model_name = "CNN from Scratch"
        self.model = self.model_loader.load_model(self.current_model_name)
        self.flagged_predictions = []

    def _run_prediction(self, image, model_name, confidence_threshold):
        """
        Run the prediction pipeline for one image, letting errors propagate.

        Shared by ``predict`` and ``predict_batch`` so that each caller can
        decide how to report a failure.

        Args:
            image: Input image, passed unchanged to ``preprocess_image``
            model_name: Name of model to use
            confidence_threshold: Minimum confidence in percent (0-100)

        Returns:
            Tuple of (display predictions keyed by formatted class name,
            Markdown summary text, JSON dump of the filtered predictions)
        """
        # Switch model if needed
        if model_name != self.current_model_name:
            self.model = self.model_loader.load_model(model_name)
            self.current_model_name = model_name

        # Preprocess image (also done in mock mode, as in the original flow)
        tensor = preprocess_image(image)
        tensor = tensor.to(self.model_loader.device)

        # Get prediction
        with torch.no_grad():
            if self.use_mock:
                # Use mock predictions for development
                predictions = create_mock_predictions(config.CLASS_NAMES)
                logits = torch.tensor([list(predictions.values())])
            else:
                logits = self.model(tensor)

        # Postprocess; only the top-k mapping is used here
        top_predictions, _ = postprocess_predictions(
            logits, config.CLASS_NAMES, config.TOP_K_PREDICTIONS
        )

        # The threshold arrives in percent; predictions are compared as fractions
        min_probability = confidence_threshold / 100
        filtered_predictions = {
            name: prob
            for name, prob in top_predictions.items()
            if prob >= min_probability
        }

        if filtered_predictions:
            top_class = max(filtered_predictions, key=filtered_predictions.get)
            top_prob = filtered_predictions[top_class]
            disease_info = get_disease_info(top_class)
            status = "Healthy" if disease_info["is_healthy"] else "Disease Detected"
            result_text = (
                f"\n**Top Prediction:** {disease_info['formatted_name']}\n"
                f"**Confidence:** {top_prob*100:.2f}%\n"
                f"**Plant:** {disease_info['plant']}\n"
                f"**Status:** {status}\n"
            )
        else:
            result_text = "No predictions above confidence threshold"

        # Format for Gradio Label component
        display_predictions = {
            format_class_name(name): prob
            for name, prob in filtered_predictions.items()
        }
        return (
            display_predictions,
            result_text,
            json.dumps(filtered_predictions, indent=2),
        )

    def predict(self, image, model_name, confidence_threshold):
        """
        Make prediction on a single image.

        Args:
            image: Input image
            model_name: Name of model to use
            confidence_threshold: Minimum confidence to display

        Returns:
            Predictions, formatted info, and detailed results. On failure the
            predictions are None and the info text carries the error message.
        """
        if image is None:
            return None, "Please upload an image", ""
        try:
            return self._run_prediction(image, model_name, confidence_threshold)
        except Exception as e:
            # UI callback boundary: show the error instead of raising
            return None, f"Error during prediction: {str(e)}", ""

    def predict_batch(self, files, model_name, confidence_threshold):
        """
        Make predictions on multiple images.

        Each image is processed independently; a failure on one image is
        reported on its own line and does not stop the batch.

        Args:
            files: List of uploaded files
            model_name: Name of model to use
            confidence_threshold: Minimum confidence to display

        Returns:
            Markdown text with one result line per image
        """
        if not files:
            return "Please upload at least one image"

        # NOTE(review): the batch tab supplies file paths while the single
        # image tab supplies PIL images - confirm preprocess_image accepts both.
        results = []
        for number, file in enumerate(files, start=1):
            if file is None:
                results.append(f"**Image {number}:** No prediction")
                continue
            try:
                # Call the raising helper directly: predict() swallows
                # exceptions, which used to hide errors as "No prediction".
                preds, _, _ = self._run_prediction(
                    file, model_name, confidence_threshold
                )
            except Exception as e:
                results.append(f"**Image {number}:** Error - {str(e)}")
                continue
            if preds:
                top_class = max(preds, key=preds.get)
                top_prob = preds[top_class]
                results.append(
                    f"**Image {number}:** {top_class} ({top_prob*100:.2f}%)"
                )
            else:
                results.append(f"**Image {number}:** No prediction")
        return "\n\n".join(results)

    def flag_prediction(self, image, prediction, user_feedback):
        """
        Flag a prediction as incorrect.

        Args:
            image: The input image (only checked for presence)
            prediction: The model's prediction
            user_feedback: User's feedback text

        Returns:
            Confirmation message
        """
        if image is None:
            return "No image to flag"
        flag_entry = {
            "timestamp": datetime.now().isoformat(),
            "prediction": prediction,
            "feedback": user_feedback,
        }
        self.flagged_predictions.append(flag_entry)
        # In a real deployment, you would save this to a file or database
        # For now, we'll just keep it in memory
        return f"Thank you! Flagged prediction #{len(self.flagged_predictions)}"

    def get_example_images(self, examples_dir=None, max_examples=10):
        """
        Get list of example images from the examples directory.

        Suffixes are matched case-insensitively (so ``.JPG`` is found on
        case-sensitive filesystems) and results are sorted by file name so
        the same examples are returned on every start.

        Args:
            examples_dir: Directory to scan; defaults to the ``examples``
                folder next to this file
            max_examples: Maximum number of paths to return

        Returns:
            List of example image paths (strings); empty if the directory
            does not exist
        """
        if examples_dir is None:
            directory = Path(__file__).parent / "examples"
        else:
            directory = Path(examples_dir)
        if not directory.is_dir():
            return []
        images = sorted(
            (
                entry
                for entry in directory.iterdir()
                if entry.is_file()
                and entry.suffix.lower() in self._IMAGE_SUFFIXES
            ),
            # Case-insensitive order, with the exact name as a tie-breaker
            key=lambda entry: (entry.name.lower(), entry.name),
        )
        return [str(entry) for entry in images[:max_examples]]
# Static UI text, kept at module level so the layout code below stays readable.
_CUSTOM_CSS = """
.main-header {
text-align: center;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
border-radius: 10px;
color: white;
margin-bottom: 2rem;
}
.prediction-box {
border: 2px solid #667eea;
border-radius: 10px;
padding: 1rem;
background: #f8f9fa;
}
"""

_HEADER_HTML = """
<div class="main-header">
<h1>🌱 Plant Disease Detection System</h1>
<p>Upload a plant leaf image to detect diseases using AI</p>
</div>
"""

_NO_EXAMPLES_MD = """
**No example images found.**
To add example images:
1. Create a folder: `ui/examples/`
2. Add plant leaf images (.jpg, .png) to this folder
3. Restart the app
"""

_ABOUT_MD = """
## About This Application
This Plant Disease Detection system was developed as part of the
5CCSAGAP Artificial Intelligence Group Project at King's College London.
### Features
- **Single Image Prediction**: Upload and classify individual plant images
- **Multiple Models**: Switch between different trained models
- **Batch Processing**: Classify multiple images at once
- **Example Gallery**: Try pre-loaded example images
- **Flagging System**: Report incorrect predictions to help improve the model
- **Confidence Threshold**: Filter predictions by confidence level
### Dataset
The model is trained on the PlantVillage dataset, which contains 55,400 images
across 39 different plant disease categories.
### Model Architecture
- **CNN from Scratch**: Custom convolutional neural network
- **Transfer Learning**: Fine-tuned ResNet18 (if available)
### Technology Stack
- **PyTorch**: Model training and inference
- **Gradio**: User interface
- **ClearML**: Experiment tracking
- **Hugging Face Spaces**: Deployment platform
### Team
[Add your team members' names here]
### Links
- [GitHub Repository](https://github.kcl.ac.uk/K23064919/smallGroupProject)
- [ClearML Dashboard](https://5ccsagap.er.kcl.ac.uk/)
"""

_FOOTER_MD = """
---
**Note:** This is an AI-powered system and predictions should be verified by experts.
Built with ❤️ by KCL AI Students
"""


def _build_single_image_tab(app, model_selector, confidence_slider):
    """Build the "Single Image" tab and return its image input component."""
    with gr.Tab("Single Image"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Plant Leaf Image",
                    type="pil",
                )
                predict_btn = gr.Button(
                    "Predict Disease", variant="primary", size="lg"
                )
                with gr.Accordion("Flag Incorrect Prediction", open=False):
                    feedback_text = gr.Textbox(
                        label="Your Feedback",
                        placeholder="What should the correct classification be?",
                        lines=2,
                    )
                    flag_btn = gr.Button("Submit Flag")
                    flag_output = gr.Textbox(label="Status", interactive=False)
            with gr.Column(scale=1):
                prediction_output = gr.Label(
                    label="Top Predictions",
                    num_top_classes=10,
                )
                result_info = gr.Markdown(label="Detailed Results")
                with gr.Accordion("Advanced: View Raw Predictions", open=False):
                    raw_predictions = gr.Textbox(
                        label="Raw JSON Output",
                        lines=10,
                        interactive=False,
                    )
        # Connect buttons
        predict_btn.click(
            fn=app.predict,
            inputs=[image_input, model_selector, confidence_slider],
            outputs=[prediction_output, result_info, raw_predictions],
        )
        # The rendered result text is what gets stored as the flagged prediction
        flag_btn.click(
            fn=app.flag_prediction,
            inputs=[image_input, result_info, feedback_text],
            outputs=flag_output,
        )
    return image_input


def _build_examples_tab(app, image_input):
    """Build the "Example Images" tab; examples load into ``image_input``."""
    with gr.Tab("Example Images"):
        gr.Markdown("### Try these example plant images")
        gr.Markdown("Click on an example below to load it into the predictor")
        example_images = app.get_example_images()
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=image_input,
                label="Example Plant Disease Images",
            )
        else:
            gr.Markdown(_NO_EXAMPLES_MD)


def _build_batch_tab(app, model_selector, confidence_slider):
    """Build the "Batch Processing" tab wired to ``app.predict_batch``."""
    with gr.Tab("Batch Processing"):
        gr.Markdown("### Upload multiple images for batch processing")
        batch_input = gr.File(
            label="Upload Multiple Images",
            file_count="multiple",
            type="filepath",
        )
        batch_predict_btn = gr.Button("Predict All", variant="primary")
        batch_output = gr.Markdown(label="Batch Results")
        batch_predict_btn.click(
            fn=app.predict_batch,
            inputs=[batch_input, model_selector, confidence_slider],
            outputs=batch_output,
        )


def _build_about_tab():
    """Build the static "About" tab."""
    with gr.Tab("About"):
        gr.Markdown(_ABOUT_MD)


def create_interface(use_mock=True):
    """
    Create the Gradio interface.

    Builds one PlantDiseaseApp and lays out the header, the shared model and
    confidence controls, the four feature tabs, and the footer.

    Args:
        use_mock: Whether to use mock model

    Returns:
        Gradio Blocks interface
    """
    app = PlantDiseaseApp(use_mock=use_mock)
    with gr.Blocks(css=_CUSTOM_CSS, title="Plant Disease Detection") as demo:
        # Header
        gr.Markdown(_HEADER_HTML)

        # Model selection (available to all tabs)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(config.MODEL_CONFIGS),
                # Use the model the app actually loaded so the default shown
                # in the UI cannot drift from the loaded model.
                value=app.current_model_name,
                label="Select Model",
                info="Choose which model to use for predictions",
            )
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=1,
                step=1,
                label="Confidence Threshold (%)",
                info="Only show predictions above this confidence",
            )

        # Tabs for different features; the examples tab needs the image
        # input created by the single-image tab, so that tab is built first.
        with gr.Tabs():
            image_input = _build_single_image_tab(
                app, model_selector, confidence_slider
            )
            _build_examples_tab(app, image_input)
            _build_batch_tab(app, model_selector, confidence_slider)
            _build_about_tab()

        # Footer
        gr.Markdown(_FOOTER_MD)
    return demo
if __name__ == "__main__":
    # Script entry point: announce start-up, build the UI, then serve it.
    print("Starting Plant Disease Detection App...")

    _launch_kwargs = {
        "share": False,  # Set to True to create a public link
        "server_name": "0.0.0.0",  # Makes it accessible on your network
        "server_port": 7860,
    }

    # use_mock=True is for development; switch to False when you have real models
    demo = create_interface(use_mock=True)
    demo.launch(**_launch_kwargs)