File size: 10,035 Bytes
505fc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c95878c
 
d669912
da16aef
505fc99
 
 
d2037f7
 
7acdd63
d2037f7
505fc99
e346658
505fc99
d2037f7
8bd7feb
 
 
 
 
 
 
 
 
 
 
 
 
505fc99
 
 
 
8bd7feb
d2037f7
fd94b96
d2037f7
505fc99
 
8bd7feb
505fc99
8bd7feb
505fc99
d2037f7
505fc99
8bd7feb
 
 
685281d
72ce591
 
685281d
8bd7feb
 
505fc99
 
8bd7feb
505fc99
8bd7feb
505fc99
 
 
 
 
 
8bd7feb
 
 
 
 
505fc99
 
 
 
8bd7feb
 
 
 
 
505fc99
8bd7feb
505fc99
 
 
 
8bd7feb
76d360d
 
 
 
 
 
505fc99
76d360d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505fc99
d2037f7
 
505fc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2037f7
505fc99
 
 
 
 
 
 
 
 
64978ec
505fc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aefbb6d
505fc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2037f7
505fc99
 
 
 
 
 
 
 
 
d2037f7
505fc99
 
d2037f7
 
505fc99
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
"""
Plant Disease Detection Gradio App
Main UI application with advanced features
"""

import gradio as gr
import torch
import sys
from pathlib import Path
import json
from datetime import datetime
# Add current directory to path
sys.path.append(str(Path(__file__).parent))
sys.path.append(str(Path(__file__).parent.parent))

from model_loader import ModelLoader
import utils
from utils import *
import config
from config import *


class PlantDiseaseApp:
    """
    State and callbacks behind the Gradio UI.

    Holds the currently loaded model (switched lazily when the user picks a
    different one) and an in-memory list of user feedback entries.
    """

    def __init__(self):
        self.model_loader = ModelLoader()
        # Start with the first configured model; predict() reloads on demand if
        # the UI asks for a different one.
        self.current_modelName = list(config.MODEL_CONFIGS.keys())[0]
        self.model = self.model_loader.loadModel(self.current_modelName)
        # Feedback entries appended by flag_prediction(); kept only in this list
        # (nothing here writes them to disk).
        self.flagged_predictions = []
        self.class_names = utils.get_class_names()

    def predict(self, image, modelName, confidence_threshold):
        """
        Predict plant disease from a single image.

        Args:
            image: PIL Image or numpy array from Gradio upload; None if nothing uploaded
            modelName: Name of the model to use (reloaded only if it differs from the current one)
            confidence_threshold: float (0-100), only show predictions at or above this confidence

        Returns:
            display_predictions: dict, formatted class name -> probability (None on error)
            result_text: str, Markdown summary of the top prediction, or a status/error message
            raw_predictions: str, JSON of the filtered predictions, highest probability first
        """
        if image is None:
            return None, "Please upload an image", ""

        try:
            # Lazily switch models only when the selection changed.
            if modelName != self.current_modelName:
                self.model = self.model_loader.loadModel(modelName)
                self.current_modelName = modelName

            tensor = preprocess_image(image).to(self.model_loader.device)

            with torch.no_grad():
                logits = self.model(tensor)

            # Row 0 of the softmax output; presumably a batch of one image —
            # TODO confirm preprocess_image always adds a batch dimension of 1.
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]

            pred_id = probs.argmax().item()
            print("predicted index: " + str(pred_id))

            # zip() would silently truncate on a length mismatch and attach the
            # wrong class names to the scores, so report it as an error instead.
            if len(self.class_names) != len(probs):
                raise ValueError(
                    f"model returned {len(probs)} class scores but "
                    f"{len(self.class_names)} class names are configured"
                )

            predictions = {name: float(prob) for name, prob in zip(self.class_names, probs)}

            # The slider is a percentage; probabilities are fractions.
            min_prob = confidence_threshold / 100.0
            # Keep classes at/above the threshold, highest probability first.
            # sorted() is stable, so ties keep class order (same pick as max()).
            filtered_predictions = dict(
                sorted(
                    ((name, prob) for name, prob in predictions.items() if prob >= min_prob),
                    key=lambda item: item[1],
                    reverse=True,
                )
            )

            if filtered_predictions:
                top_class, top_prob = next(iter(filtered_predictions.items()))
                disease_info = get_disease_info(top_class)

                result_text = f"""
                    **Top Prediction:** {disease_info['formatted_name']}
                    **Confidence:** {top_prob*100:.2f}%
                    **Plant:** {disease_info['plant']}
                    **Status:** {'Healthy' if disease_info['is_healthy'] else 'Disease Detected'}
                    """
            else:
                result_text = "No predictions above confidence threshold"

            # Human-readable labels for the gr.Label component.
            display_predictions = {format_class_name(k): v for k, v in filtered_predictions.items()}

            raw_predictions = json.dumps(filtered_predictions, indent=2)

            return display_predictions, result_text, raw_predictions

        except Exception as e:
            # UI boundary: show the failure in the results panel rather than
            # failing the Gradio request.
            return None, f"Error during prediction: {str(e)}", ""

    def flag_prediction(self, image, result_info, feedback_text):
        """
        Record user feedback about the current prediction.

        Args:
            image: the uploaded image (only checked for presence; not stored)
            result_info: the Markdown result text currently shown to the user
            feedback_text: free-text feedback; may be None or blank

        Returns:
            str: status message for the UI.
        """
        if image is None:
            return "No image uploaded."

        # Gradio may deliver None for an untouched textbox; treat it as empty
        # instead of crashing on .strip().
        if not feedback_text or not feedback_text.strip():
            return "Please enter feedback before submitting."

        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            entry = {
                "timestamp": timestamp,
                "feedback": feedback_text,
                # Model used for the most recent prediction.
                "model": self.current_modelName,
                "result_info": result_info
            }

            self.flagged_predictions.append(entry)

            return "Thanks! Your feedback has been recorded."

        except Exception as e:
            return f"Error saving feedback: {str(e)}"

def create_interface():
    """
    Build the Gradio Blocks UI for plant disease detection.

    Constructs one PlantDiseaseApp (which loads a model on construction) and
    wires its predict / flag_prediction methods to the widgets.

    Returns:
        gr.Blocks: the assembled demo, not yet launched.
    """
    app = PlantDiseaseApp()

    model_choices = list(config.MODEL_CONFIGS.keys())
    # "intermediate model" is the preferred initial selection. If that key is
    # not configured, start on the model the app already loaded, so the
    # dropdown's value is always one of its choices and predict() is never
    # asked to load an unknown model name.
    default_model = (
        "intermediate model"
        if "intermediate model" in model_choices
        else app.current_modelName
    )

    custom_css = """
    .main-header {
        text-align: center;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
        border-radius: 10px;
        color: white;
        margin-bottom: 2rem;
    }
    .prediction-box {
        border: 2px solid #667eea;
        border-radius: 10px;
        padding: 1rem;
        background: #f8f9fa;
    }
    """

    with gr.Blocks(css=custom_css, title="Plant Disease Detection") as demo:

        # Header
        gr.Markdown(
            """
            <div class="main-header">
                <h1>Plant Disease Detection System</h1>
                <p>Upload a plant leaf image to detect diseases using AI</p>
            </div>
            """
        )

        # Model selection (available to all tabs)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=model_choices,
                value=default_model,
                label="Select Model",
                info="Choose which model to use for predictions"
            )
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=1,
                step=1,
                label="Confidence Threshold (%)",
                info="Only show predictions above this confidence"
            )

        # Tabs for different features
        with gr.Tabs():

            # Tab 1: Single Image Prediction
            with gr.Tab("Single Image"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(
                            label="Upload Plant Leaf Image",
                            type="pil"
                        )

                        predict_btn = gr.Button("Predict Disease", variant="primary", size="lg")

                        with gr.Accordion("Flag Incorrect Prediction", open=False):
                            feedback_text = gr.Textbox(
                                label="Your Feedback",
                                placeholder="What should the correct classification be?",
                                lines=2
                            )
                            flag_btn = gr.Button("Submit Flag")
                            flag_output = gr.Textbox(label="Status", interactive=False)

                    with gr.Column(scale=1):
                        prediction_output = gr.Label(
                            label="Top Predictions",
                            num_top_classes=10
                        )
                        result_info = gr.Markdown(label="Detailed Results")

                with gr.Accordion("Advanced: View Raw Predictions", open=False):
                    raw_predictions = gr.Textbox(
                        label="Raw JSON Output",
                        lines=10,
                        interactive=False
                    )

                # Connect buttons
                predict_btn.click(
                    fn=app.predict,
                    inputs=[image_input, model_selector, confidence_slider],
                    outputs=[prediction_output, result_info, raw_predictions]
                )

                # The displayed result Markdown is passed back in so the
                # feedback entry records what the user was looking at.
                flag_btn.click(
                    fn=app.flag_prediction,
                    inputs=[image_input, result_info, feedback_text],
                    outputs=flag_output
                )

            with gr.Tab("About"):
                gr.Markdown(
                    """
                    ## About This Application

                    This Plant Disease Detection system was developed as part of the
                    5CCSAGAP Artificial Intelligence Group Project at King's College London.

                    ### Features
                    - **Single Image Prediction**: Upload and classify individual plant images
                    - **Multiple Models**: Switch between different trained models
                    - **Batch Processing**: Classify multiple images at once
                    - **Example Gallery**: Try pre-loaded example images
                    - **Flagging System**: Report incorrect predictions to help improve the model
                    - **Confidence Threshold**: Filter predictions by confidence level

                    ### Dataset
                    The model is trained on the PlantVillage dataset, which contains 55,400 images
                    across 39 different plant disease categories.

                    ### Model Architecture
                    - **Basic CNN**: Custom convolutional neural network
                    - **Transfer Learning**: Fine-tuned ResNet18 (if available)

                    ### Technology Stack
                    - **PyTorch**: Model training and inference
                    - **Gradio**: User interface
                    - **ClearML**: Experiment tracking
                    - **Hugging Face Spaces**: Deployment platform

                    ### Team
                    [Add your team members' names here]

                    ### Links
                    - [GitHub Repository](https://github.kcl.ac.uk/K23064919/smallGroupProject)
                    - [ClearML Dashboard](https://5ccsagap.er.kcl.ac.uk/)
                    """
                )

        gr.Markdown(
            """
            ---
            **Note:** This is an AI-powered system and predictions should be verified by experts.
            Built with love by KCL AI Students
            """
        )

    return demo


if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all network interfaces
    # (0.0.0.0) at port 7860, without creating a public Gradio share link.
    print("Starting Plant Disease Detection App...")

    launch_kwargs = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo = create_interface()
    demo.launch(**launch_kwargs)