Conor Brennan (k23064919) commited on
Commit
d2037f7
·
unverified ·
1 Parent(s): 994727d

Update app.py

Browse files

remove mock implementation, local and hf model loading

Files changed (1) hide show
  1. ui/app.py +16 -152
ui/app.py CHANGED
@@ -9,61 +9,28 @@ import sys
9
  from pathlib import Path
10
  import json
11
  from datetime import datetime
12
-
13
  # Add current directory to path
14
  sys.path.append(str(Path(__file__).parent))
15
  sys.path.append(str(Path(__file__).parent.parent))
16
 
17
- import config
18
  from model_loader import ModelLoader
19
- from utils import (
20
- preprocess_image,
21
- postprocess_predictions,
22
- format_class_name,
23
- get_disease_info,
24
- batch_preprocess_images
25
- )
26
- from models.mock_model import create_mock_predictions
27
 
28
 
29
  class PlantDiseaseApp:
30
- """
31
- Main application class for Plant Disease Detection
32
- """
33
-
34
- def __init__(self, use_mock=True):
35
- """
36
- Initialize the application
37
-
38
- Args:
39
- use_mock: Whether to use mock model for development
40
- """
41
- self.use_mock = use_mock
42
- self.model_loader = ModelLoader(use_mock=use_mock)
43
- self.current_model_name = "CNN from Scratch"
44
- self.model = self.model_loader.load_model(self.current_model_name)
45
  self.flagged_predictions = []
46
 
47
- def predict(self, image, model_name, confidence_threshold):
48
- """
49
- Make prediction on a single image
50
-
51
- Args:
52
- image: Input image
53
- model_name: Name of model to use
54
- confidence_threshold: Minimum confidence to display
55
-
56
- Returns:
57
- Predictions, formatted info, and detailed results
58
- """
59
  if image is None:
60
  return None, "Please upload an image", ""
61
 
62
  try:
63
- # Switch model if needed
64
- if model_name != self.current_model_name:
65
- self.model = self.model_loader.load_model(model_name)
66
- self.current_model_name = model_name
67
 
68
  # Preprocess image
69
  tensor = preprocess_image(image)
@@ -71,12 +38,7 @@ class PlantDiseaseApp:
71
 
72
  # Get prediction
73
  with torch.no_grad():
74
- if self.use_mock:
75
- # Use mock predictions for development
76
- predictions = create_mock_predictions(config.CLASS_NAMES)
77
- logits = torch.tensor([list(predictions.values())])
78
- else:
79
- logits = self.model(tensor)
80
 
81
  # Postprocess
82
  top_predictions, all_predictions = postprocess_predictions(
@@ -113,101 +75,11 @@ class PlantDiseaseApp:
113
  except Exception as e:
114
  return None, f"Error during prediction: {str(e)}", ""
115
 
116
- def predict_batch(self, files, model_name, confidence_threshold):
117
- """
118
- Make predictions on multiple images
119
-
120
- Args:
121
- files: List of uploaded files
122
- model_name: Name of model to use
123
- confidence_threshold: Minimum confidence to display
124
-
125
- Returns:
126
- Results for each image
127
- """
128
- if not files:
129
- return "Please upload at least one image"
130
-
131
- results = []
132
- for i, file in enumerate(files):
133
- try:
134
- # Get predictions for this image
135
- preds, info, _ = self.predict(file, model_name, confidence_threshold)
136
-
137
- if preds:
138
- top_class = max(preds.items(), key=lambda x: x[1])[0]
139
- top_prob = preds[top_class]
140
- results.append(f"**Image {i+1}:** {top_class} ({top_prob*100:.2f}%)")
141
- else:
142
- results.append(f"**Image {i+1}:** No prediction")
143
-
144
- except Exception as e:
145
- results.append(f"**Image {i+1}:** Error - {str(e)}")
146
-
147
- return "\n\n".join(results)
148
-
149
- def flag_prediction(self, image, prediction, user_feedback):
150
- """
151
- Flag a prediction as incorrect
152
-
153
- Args:
154
- image: The input image
155
- prediction: The model's prediction
156
- user_feedback: User's feedback text
157
-
158
- Returns:
159
- Confirmation message
160
- """
161
- if image is None:
162
- return "No image to flag"
163
-
164
- flag_entry = {
165
- "timestamp": datetime.now().isoformat(),
166
- "prediction": prediction,
167
- "feedback": user_feedback
168
- }
169
-
170
- self.flagged_predictions.append(flag_entry)
171
-
172
- # In a real deployment, you would save this to a file or database
173
- # For now, we'll just keep it in memory
174
- return f"Thank you! Flagged prediction #{len(self.flagged_predictions)}"
175
-
176
- def get_example_images(self):
177
- """
178
- Get list of example images from examples directory
179
-
180
- Returns:
181
- List of example image paths
182
- """
183
- examples_dir = Path(__file__).parent / "examples"
184
-
185
- if not examples_dir.exists():
186
- return []
187
-
188
- # Get all image files
189
- image_extensions = ['.jpg', '.jpeg', '.png']
190
- examples = []
191
 
192
- for ext in image_extensions:
193
- examples.extend(list(examples_dir.glob(f"*{ext}")))
194
 
195
- return [str(path) for path in examples[:10]] # Return max 10 examples
 
196
 
197
-
198
- def create_interface(use_mock=True):
199
- """
200
- Create the Gradio interface
201
-
202
- Args:
203
- use_mock: Whether to use mock model
204
-
205
- Returns:
206
- Gradio Blocks interface
207
- """
208
- app = PlantDiseaseApp(use_mock=use_mock)
209
-
210
- # Custom CSS for better styling
211
  custom_css = """
212
  .main-header {
213
  text-align: center;
@@ -231,7 +103,7 @@ def create_interface(use_mock=True):
231
  gr.Markdown(
232
  """
233
  <div class="main-header">
234
- <h1>🌱 Plant Disease Detection System</h1>
235
  <p>Upload a plant leaf image to detect diseases using AI</p>
236
  </div>
237
  """
@@ -304,7 +176,6 @@ def create_interface(use_mock=True):
304
  outputs=flag_output
305
  )
306
 
307
- # Tab 2: Example Gallery
308
  with gr.Tab("Example Images"):
309
  gr.Markdown("### Try these example plant images")
310
  gr.Markdown("Click on an example below to load it into the predictor")
@@ -329,7 +200,6 @@ def create_interface(use_mock=True):
329
  """
330
  )
331
 
332
- # Tab 3: Batch Processing
333
  with gr.Tab("Batch Processing"):
334
  gr.Markdown("### Upload multiple images for batch processing")
335
 
@@ -348,8 +218,6 @@ def create_interface(use_mock=True):
348
  inputs=[batch_input, model_selector, confidence_slider],
349
  outputs=batch_output
350
  )
351
-
352
- # Tab 4: About
353
  with gr.Tab("About"):
354
  gr.Markdown(
355
  """
@@ -389,12 +257,11 @@ def create_interface(use_mock=True):
389
  """
390
  )
391
 
392
- # Footer
393
  gr.Markdown(
394
  """
395
  ---
396
  **Note:** This is an AI-powered system and predictions should be verified by experts.
397
- Built with ❤️ by KCL AI Students
398
  """
399
  )
400
 
@@ -402,15 +269,12 @@ def create_interface(use_mock=True):
402
 
403
 
404
  if __name__ == "__main__":
405
- # Create and launch the app
406
  print("Starting Plant Disease Detection App...")
407
 
408
- # Use mock=True for development, mock=False when you have real models
409
- demo = create_interface(use_mock=True)
410
 
411
- # Launch the app
412
  demo.launch(
413
- share=False, # Set to True to create a public link
414
- server_name="0.0.0.0", # Makes it accessible on your network
415
  server_port=7860
416
  )
 
9
  from pathlib import Path
10
  import json
11
  from datetime import datetime
 
12
  # Add current directory to path
13
  sys.path.append(str(Path(__file__).parent))
14
  sys.path.append(str(Path(__file__).parent.parent))
15
 
 
16
  from model_loader import ModelLoader
 
 
 
 
 
 
 
 
17
 
18
 
19
  class PlantDiseaseApp:
20
+ def __init__(self):
21
+ self.model_loader = ModelLoader()
22
+ self.current_modelName = "CNN from Scratch"
23
+ self.model = self.model_loader.loadModel(self.current_modelName)
 
 
 
 
 
 
 
 
 
 
 
24
  self.flagged_predictions = []
25
 
26
+ def predict(self, image, modelName, confidence_threshold):
 
 
 
 
 
 
 
 
 
 
 
27
  if image is None:
28
  return None, "Please upload an image", ""
29
 
30
  try:
31
+ if modelName != self.current_modelName:
32
+ self.model = self.model_loader.loadModel(modelName)
33
+ self.current_modelName = modelName
 
34
 
35
  # Preprocess image
36
  tensor = preprocess_image(image)
 
38
 
39
  # Get prediction
40
  with torch.no_grad():
41
+ logits = self.model(tensor)
 
 
 
 
 
42
 
43
  # Postprocess
44
  top_predictions, all_predictions = postprocess_predictions(
 
75
  except Exception as e:
76
  return None, f"Error during prediction: {str(e)}", ""
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
 
 
79
 
80
+ def create_interface():
81
+ app = PlantDiseaseApp()
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  custom_css = """
84
  .main-header {
85
  text-align: center;
 
103
  gr.Markdown(
104
  """
105
  <div class="main-header">
106
+ <h1>Plant Disease Detection System</h1>
107
  <p>Upload a plant leaf image to detect diseases using AI</p>
108
  </div>
109
  """
 
176
  outputs=flag_output
177
  )
178
 
 
179
  with gr.Tab("Example Images"):
180
  gr.Markdown("### Try these example plant images")
181
  gr.Markdown("Click on an example below to load it into the predictor")
 
200
  """
201
  )
202
 
 
203
  with gr.Tab("Batch Processing"):
204
  gr.Markdown("### Upload multiple images for batch processing")
205
 
 
218
  inputs=[batch_input, model_selector, confidence_slider],
219
  outputs=batch_output
220
  )
 
 
221
  with gr.Tab("About"):
222
  gr.Markdown(
223
  """
 
257
  """
258
  )
259
 
 
260
  gr.Markdown(
261
  """
262
  ---
263
  **Note:** This is an AI-powered system and predictions should be verified by experts.
264
+ Built with love by KCL AI Students
265
  """
266
  )
267
 
 
269
 
270
 
271
  if __name__ == "__main__":
 
272
  print("Starting Plant Disease Detection App...")
273
 
274
+ demo = create_interface()
 
275
 
 
276
  demo.launch(
277
+ share=False,
278
+ server_name="0.0.0.0",
279
  server_port=7860
280
  )