Spaces:
Sleeping
Sleeping
File size: 10,035 Bytes
505fc99 c95878c d669912 da16aef 505fc99 d2037f7 7acdd63 d2037f7 505fc99 e346658 505fc99 d2037f7 8bd7feb 505fc99 8bd7feb d2037f7 fd94b96 d2037f7 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 d2037f7 505fc99 8bd7feb 685281d 72ce591 685281d 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 76d360d 505fc99 76d360d 505fc99 d2037f7 505fc99 d2037f7 505fc99 64978ec 505fc99 aefbb6d 505fc99 d2037f7 505fc99 d2037f7 505fc99 d2037f7 505fc99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 |
"""
Plant Disease Detection Gradio App
Gradio UI for classifying plant-leaf images, with model selection, a confidence-threshold filter, and feedback flagging
"""
import json
import logging
import sys
from datetime import datetime
from pathlib import Path

import gradio as gr
import torch

# Add this file's directory and its parent to the import path
sys.path.append(str(Path(__file__).parent))
sys.path.append(str(Path(__file__).parent.parent))

from model_loader import ModelLoader
import utils
from utils import *
import config
from config import *
class PlantDiseaseApp:
    """Hold the loaded model and serve the prediction / flagging callbacks for the UI.

    Attributes:
        model_loader: ModelLoader used to load models and to read the target device.
        current_modelName: Key into ``config.MODEL_CONFIGS`` of the model held in ``model``.
        model: The currently loaded model, called on a preprocessed tensor.
        flagged_predictions: Feedback entries recorded by ``flag_prediction``. They are
            kept only in this in-memory list; nothing in this class writes them to disk.
        class_names: Names from ``utils.get_class_names()``. Their order is assumed to
            match the model's output indices (TODO confirm per model).
    """

    # Debug-level logger for per-request diagnostics (replaces an unconditional print).
    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.model_loader = ModelLoader()
        # Start with the first configured model (dict insertion order).
        self.current_modelName = list(config.MODEL_CONFIGS.keys())[0]
        self.model = self.model_loader.loadModel(self.current_modelName)
        self.flagged_predictions = []
        self.class_names = utils.get_class_names()

    def predict(self, image, modelName, confidence_threshold):
        """Predict plant disease from a single image.

        Args:
            image: PIL Image from the Gradio upload, or None when nothing was uploaded.
            modelName: Key of the model to use. It is loaded if it differs from the
                model currently held.
            confidence_threshold: Percentage in [0, 100]. Classes whose probability
                is below ``confidence_threshold / 100`` are dropped.

        Returns:
            A 3-tuple ``(display_predictions, result_text, raw_predictions)``:
            formatted class name -> probability (for ``gr.Label``), a Markdown
            summary of the top prediction, and the filtered predictions as indented
            JSON. On missing input or any error, returns ``(None, message, "")``.
        """
        if image is None:
            return None, "Please upload an image", ""
        try:
            # Load first; the name is only updated once the load has succeeded.
            if modelName != self.current_modelName:
                self.model = self.model_loader.loadModel(modelName)
                self.current_modelName = modelName

            tensor = preprocess_image(image).to(self.model_loader.device)
            with torch.no_grad():
                logits = self.model(tensor)
            # Softmax over dim 1, then take batch item 0 (presumably a batch of one).
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]

            # zip() below would silently truncate and mislabel classes if these differ.
            if len(probs) != len(self.class_names):
                raise ValueError(
                    f"model returned {len(probs)} scores but "
                    f"{len(self.class_names)} class names are configured"
                )
            self._logger.debug("predicted index: %d", int(probs.argmax()))

            predictions = {name: float(prob) for name, prob in zip(self.class_names, probs)}
            cutoff = confidence_threshold / 100.0
            filtered_predictions = {k: v for k, v in predictions.items() if v >= cutoff}

            if filtered_predictions:
                top_class = max(filtered_predictions, key=filtered_predictions.get)
                top_prob = filtered_predictions[top_class]
                disease_info = get_disease_info(top_class)
                status = "Healthy" if disease_info["is_healthy"] else "Disease Detected"
                # Built line by line so source indentation never leaks into the Markdown.
                result_text = (
                    "\n"
                    f"**Top Prediction:** {disease_info['formatted_name']}\n"
                    f"**Confidence:** {top_prob * 100:.2f}%\n"
                    f"**Plant:** {disease_info['plant']}\n"
                    f"**Status:** {status}\n"
                )
            else:
                result_text = "No predictions above confidence threshold"

            # Human-readable keys for the gr.Label component.
            display_predictions = {
                format_class_name(k): v for k, v in filtered_predictions.items()
            }
            raw_predictions = json.dumps(filtered_predictions, indent=2)
            return display_predictions, result_text, raw_predictions
        except Exception as e:
            # UI boundary: show the error text instead of failing the callback.
            return None, f"Error during prediction: {str(e)}", ""

    def flag_prediction(self, image, result_info, feedback_text):
        """Record user feedback about the current prediction.

        Args:
            image: The uploaded image; only checked for presence.
            result_info: The Markdown result text currently shown to the user.
            feedback_text: Free-text feedback; may be None or blank.

        Returns:
            A status message for the UI.
        """
        if image is None:
            return "No image uploaded."
        # Guard against None as well as whitespace-only input.
        if not feedback_text or not feedback_text.strip():
            return "Please enter feedback before submitting."
        entry = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "feedback": feedback_text,
            "model": self.current_modelName,
            "result_info": result_info,
        }
        self.flagged_predictions.append(entry)
        return "Thanks! Your feedback has been recorded."
def create_interface():
    """Build the Gradio Blocks UI around a single PlantDiseaseApp instance.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for calling ``launch()``.
    """
    app = PlantDiseaseApp()

    # Prefer the intended "intermediate model" default, but only if it is actually
    # configured; otherwise fall back to the model the app already loaded so the
    # dropdown's initial value is always one of its choices.
    model_choices = list(config.MODEL_CONFIGS.keys())
    if "intermediate model" in model_choices:
        default_model = "intermediate model"
    else:
        default_model = app.current_modelName

    custom_css = """
    .main-header {
        text-align: center;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
        border-radius: 10px;
        color: white;
        margin-bottom: 2rem;
    }
    .prediction-box {
        border: 2px solid #667eea;
        border-radius: 10px;
        padding: 1rem;
        background: #f8f9fa;
    }
    """

    with gr.Blocks(css=custom_css, title="Plant Disease Detection") as demo:
        # Header
        gr.Markdown(
            """
            <div class="main-header">
            <h1>Plant Disease Detection System</h1>
            <p>Upload a plant leaf image to detect diseases using AI</p>
            </div>
            """
        )

        # Model selection (available to all tabs)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=model_choices,
                value=default_model,
                label="Select Model",
                info="Choose which model to use for predictions",
            )
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=1,
                step=1,
                label="Confidence Threshold (%)",
                info="Only show predictions above this confidence",
            )

        with gr.Tabs():
            # Tab 1: Single Image Prediction
            with gr.Tab("Single Image"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(
                            label="Upload Plant Leaf Image",
                            type="pil",
                        )
                        predict_btn = gr.Button("Predict Disease", variant="primary", size="lg")
                        with gr.Accordion("Flag Incorrect Prediction", open=False):
                            feedback_text = gr.Textbox(
                                label="Your Feedback",
                                placeholder="What should the correct classification be?",
                                lines=2,
                            )
                            flag_btn = gr.Button("Submit Flag")
                            flag_output = gr.Textbox(label="Status", interactive=False)
                    with gr.Column(scale=1):
                        prediction_output = gr.Label(
                            label="Top Predictions",
                            num_top_classes=10,
                        )
                        result_info = gr.Markdown(label="Detailed Results")
                        with gr.Accordion("Advanced: View Raw Predictions", open=False):
                            raw_predictions = gr.Textbox(
                                label="Raw JSON Output",
                                lines=10,
                                interactive=False,
                            )

                # Connect buttons
                predict_btn.click(
                    fn=app.predict,
                    inputs=[image_input, model_selector, confidence_slider],
                    outputs=[prediction_output, result_info, raw_predictions],
                )
                # The Markdown result text is passed back in so it is stored with the flag.
                flag_btn.click(
                    fn=app.flag_prediction,
                    inputs=[image_input, result_info, feedback_text],
                    outputs=flag_output,
                )

            with gr.Tab("About"):
                # Feature list limited to what this UI actually provides.
                gr.Markdown(
                    """
                    ## About This Application
                    This Plant Disease Detection system was developed as part of the
                    5CCSAGAP Artificial Intelligence Group Project at King's College London.
                    ### Features
                    - **Single Image Prediction**: Upload and classify individual plant images
                    - **Multiple Models**: Switch between different trained models
                    - **Flagging System**: Report incorrect predictions to help improve the model
                    - **Confidence Threshold**: Filter predictions by confidence level
                    ### Dataset
                    The model is trained on the PlantVillage dataset, which contains 55,400 images
                    across 39 different plant disease categories.
                    ### Model Architecture
                    - **Basic CNN**: Custom convolutional neural network
                    - **Transfer Learning**: Fine-tuned ResNet18 (if available)
                    ### Technology Stack
                    - **PyTorch**: Model training and inference
                    - **Gradio**: User interface
                    - **ClearML**: Experiment tracking
                    - **Hugging Face Spaces**: Deployment platform
                    ### Team
                    [Add your team members' names here]
                    ### Links
                    - [GitHub Repository](https://github.kcl.ac.uk/K23064919/smallGroupProject)
                    - [ClearML Dashboard](https://5ccsagap.er.kcl.ac.uk/)
                    """
                )

        gr.Markdown(
            """
            ---
            **Note:** This is an AI-powered system and predictions should be verified by experts.
            Built with love by KCL AI Students
            """
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on every interface, port 7860,
    # without creating a public share link.
    print("Starting Plant Disease Detection App...")
    demo = create_interface()
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_kwargs)
|