mgumowsk commited on
Commit
53ae913
·
1 Parent(s): dd5778f
Files changed (4) hide show
  1. app.py +86 -59
  2. models/serialized.bin +3 -0
  3. models/serialized.xml +0 -0
  4. requirements.txt +2 -1
app.py CHANGED
@@ -8,13 +8,12 @@ import numpy as np
8
  from pathlib import Path
9
  from PIL import Image
10
  import time
11
- from typing import Tuple, Optional
12
- import glob
 
13
 
14
  from model_api.models import Model
15
  from model_api.visualizer import Visualizer
16
- import asyncio
17
- import warnings
18
 
19
  warnings.filterwarnings("ignore", message=".*Invalid file descriptor.*")
20
 
@@ -37,10 +36,7 @@ def get_available_models():
37
  Returns:
38
  list: List of model names (without .xml extension)
39
  """
40
- models_dir = Path("models")
41
- if not models_dir.exists():
42
- return []
43
-
44
  xml_files = list(models_dir.glob("*.xml"))
45
  model_names = [f.stem for f in xml_files]
46
  return sorted(model_names)
@@ -58,8 +54,6 @@ def load_model(model_name: str, device: str = "CPU"):
58
  Model instance from model_api
59
  """
60
  global current_model, current_model_name
61
-
62
- # Check if model is already loaded
63
  if current_model is not None and current_model_name == model_name:
64
  return current_model
65
 
@@ -70,14 +64,6 @@ def load_model(model_name: str, device: str = "CPU"):
70
 
71
  print(f"Loading model: {model_name}")
72
  model = Model.create_model(str(model_path), device=device)
73
-
74
- # Warm-up inference
75
- print("Warming up model...")
76
- dummy_image = np.ones((224, 224, 3), dtype=np.uint8)
77
- for _ in range(3):
78
- _ = model(dummy_image)
79
-
80
- # Reset metrics after warm-up
81
  model.get_performance_metrics().reset()
82
 
83
  current_model = model
@@ -87,13 +73,13 @@ def load_model(model_name: str, device: str = "CPU"):
87
  return model
88
 
89
 
90
- def classify_image(
91
  image: np.ndarray,
92
  model_name: str,
93
  confidence_threshold: float
94
  ) -> Tuple[Image.Image, str, str]:
95
  """
96
- Perform image classification and return visualized result with metrics.
97
 
98
  Args:
99
  image: Input image as numpy array
@@ -101,23 +87,20 @@ def classify_image(
101
  confidence_threshold: Confidence threshold for filtering predictions
102
 
103
  Returns:
104
- Tuple of (visualized_image, detections_text, metrics_text)
105
  """
 
 
 
 
 
 
 
106
  try:
107
- # Load model
108
  model = load_model(model_name)
109
 
110
- # Convert numpy array to PIL Image if needed
111
- if isinstance(image, np.ndarray):
112
- pil_image = Image.fromarray(image)
113
- else:
114
- pil_image = image
115
-
116
- # Convert PIL to numpy for model_api
117
- image_np = np.array(pil_image)
118
-
119
  # Run inference
120
- result = model(image_np)
121
 
122
  # Get performance metrics
123
  metrics = model.get_performance_metrics()
@@ -138,33 +121,76 @@ def classify_image(
138
  📈 Total Frames: {inference_time.count}
139
  """
140
 
141
- # Filter predictions by confidence threshold
142
- detections_text = "🔍 Detected Objects:\n"
143
- detections_text += "━" * 50 + "\n"
144
-
145
- if result.top_labels and len(result.top_labels) > 0:
146
- filtered_labels = [
147
- label for label in result.top_labels
148
- if label.confidence >= confidence_threshold
149
- ]
150
-
151
- if filtered_labels:
152
- for i, label in enumerate(filtered_labels, 1):
153
- detections_text += f"{i}. {label.name}: {label.confidence:.3f}\n"
154
- else:
155
- detections_text += f"No detections above confidence threshold {confidence_threshold:.2f}\n"
156
- else:
157
- detections_text += "No detections found\n"
158
 
159
  # Visualize results using model_api's visualizer
160
- visualized_image = visualizer.render(pil_image, result)
161
 
162
- return visualized_image, detections_text, metrics_text
163
 
164
  except Exception as e:
165
  error_msg = f"Error during inference: {str(e)}"
166
- print(error_msg)
167
- return image, error_msg, "Error: Could not compute metrics"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
 
170
  def create_gradio_interface():
@@ -214,6 +240,7 @@ def create_gradio_interface():
214
  output_image = gr.Image(
215
  label="Detection Result",
216
  type="pil",
 
217
  height=400
218
  )
219
 
@@ -233,20 +260,20 @@ def create_gradio_interface():
233
  gr.Markdown("## 📸 Examples")
234
  gr.Examples(
235
  examples=[
236
- ["examples/image1.jpg", available_models[0] if available_models else "resnet18", 0.3],
237
- ["examples/people-walking.png", available_models[0] if available_models else "resnet50", 0.4],
238
- ["examples/vehicles.png", available_models[0] if available_models else "resnet18", 0.5],
239
- ["examples/zidane.jpg", available_models[0] if available_models else "resnet18", 0.5],
240
  ],
241
  inputs=[input_image, model_dropdown, confidence_slider],
242
  outputs=[output_image, detections_output, metrics_output],
243
- fn=classify_image,
244
  cache_examples=False
245
  )
246
 
247
  # Connect the button to the inference function
248
  classify_btn.click(
249
- fn=classify_image,
250
  inputs=[input_image, model_dropdown, confidence_slider],
251
  outputs=[output_image, detections_output, metrics_output]
252
  )
 
8
  from pathlib import Path
9
  from PIL import Image
10
  import time
11
+ from typing import Tuple, List
12
+ import asyncio
13
+ import warnings
14
 
15
  from model_api.models import Model
16
  from model_api.visualizer import Visualizer
 
 
17
 
18
  warnings.filterwarnings("ignore", message=".*Invalid file descriptor.*")
19
 
 
36
  Returns:
37
  list: List of model names (without .xml extension)
38
  """
39
+ models_dir = Path("models")
 
 
 
40
  xml_files = list(models_dir.glob("*.xml"))
41
  model_names = [f.stem for f in xml_files]
42
  return sorted(model_names)
 
54
  Model instance from model_api
55
  """
56
  global current_model, current_model_name
 
 
57
  if current_model is not None and current_model_name == model_name:
58
  return current_model
59
 
 
64
 
65
  print(f"Loading model: {model_name}")
66
  model = Model.create_model(str(model_path), device=device)
 
 
 
 
 
 
 
 
67
  model.get_performance_metrics().reset()
68
 
69
  current_model = model
 
73
  return model
74
 
75
 
76
+ def run_inference(
77
  image: np.ndarray,
78
  model_name: str,
79
  confidence_threshold: float
80
  ) -> Tuple[Image.Image, str, str]:
81
  """
82
+ Perform inference and return visualized result with metrics.
83
 
84
  Args:
85
  image: Input image as numpy array
 
87
  confidence_threshold: Confidence threshold for filtering predictions
88
 
89
  Returns:
90
+ Tuple of (visualized_image, results_text, metrics_text)
91
  """
92
+ # Input validation
93
+ if image is None:
94
+ return None, "⚠️ Please upload an image first.", ""
95
+
96
+ if model_name is None or model_name == "No models available":
97
+ return None, "⚠️ No model selected or available.", ""
98
+
99
  try:
 
100
  model = load_model(model_name)
101
 
 
 
 
 
 
 
 
 
 
102
  # Run inference
103
+ result = model(image)
104
 
105
  # Get performance metrics
106
  metrics = model.get_performance_metrics()
 
121
  📈 Total Frames: {inference_time.count}
122
  """
123
 
124
+ # Format results based on model type
125
+ results_text = format_results(result, confidence_threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # Visualize results using model_api's visualizer
128
+ visualized_image = visualizer.render(image, result)
129
 
130
+ return visualized_image, metrics_text, metrics_text
131
 
132
  except Exception as e:
133
  error_msg = f"Error during inference: {str(e)}"
134
+ return None, error_msg, "Error: Could not compute metrics"
135
+
136
+
137
+ def format_results(result, confidence_threshold: float) -> str:
138
+ """
139
+ Format model results (classification or detection) as text.
140
+
141
+ Args:
142
+ result: Result object from model_api
143
+ confidence_threshold: Confidence threshold for filtering
144
+
145
+ Returns:
146
+ Formatted results text
147
+ """
148
+ # Check if it's a classification result
149
+ if hasattr(result, 'top_labels') and result.top_labels:
150
+ results_text = "🔍 Classification Results:\n"
151
+ results_text += "━" * 50 + "\n"
152
+
153
+ filtered_labels = [
154
+ label for label in result.top_labels
155
+ if label.confidence >= confidence_threshold
156
+ ]
157
+
158
+ if filtered_labels:
159
+ for i, label in enumerate(filtered_labels, 1):
160
+ results_text += f"{i}. {label.name}: {label.confidence:.3f}\n"
161
+ else:
162
+ results_text += f"No predictions above confidence threshold {confidence_threshold:.2f}\n"
163
+
164
+ # Check if it's a detection result
165
+ elif hasattr(result, 'segmentedObjects') and result.segmentedObjects:
166
+ results_text = "🔍 Detected Objects:\n"
167
+ results_text += "━" * 50 + "\n"
168
+
169
+ # Filter by confidence
170
+ filtered_objects = [
171
+ obj for obj in result.segmentedObjects
172
+ if obj.score >= confidence_threshold
173
+ ]
174
+
175
+ if filtered_objects:
176
+ from collections import Counter
177
+ label_counts = Counter(obj.str_label for obj in filtered_objects)
178
+
179
+ for i, obj in enumerate(filtered_objects, 1):
180
+ x1, y1 = int(obj.xmin), int(obj.ymin)
181
+ x2, y2 = int(obj.xmax), int(obj.ymax)
182
+ results_text += f"{i}. {obj.str_label}: {obj.score:.3f} @ [{x1}, {y1}, {x2}, {y2}]\n"
183
+
184
+ results_text += "\n📊 Summary:\n"
185
+ for label, count in label_counts.most_common():
186
+ results_text += f" • {label}: {count}\n"
187
+ else:
188
+ results_text += f"No detections above confidence threshold {confidence_threshold:.2f}\n"
189
+
190
+ else:
191
+ results_text = "No results available\n"
192
+
193
+ return results_text
194
 
195
 
196
  def create_gradio_interface():
 
240
  output_image = gr.Image(
241
  label="Detection Result",
242
  type="pil",
243
+ show_label=False,
244
  height=400
245
  )
246
 
 
260
  gr.Markdown("## 📸 Examples")
261
  gr.Examples(
262
  examples=[
263
+ ["examples/image1.jpg", "maskrcnn_resnet50_fpn_v2" if "maskrcnn_resnet50_fpn_v2" in available_models else available_models[0], 0.5],
264
+ ["examples/people-walking.png", "maskrcnn_resnet50_fpn_v2" if "maskrcnn_resnet50_fpn_v2" in available_models else available_models[0], 0.5],
265
+ ["examples/vehicles.png", "maskrcnn_resnet50_fpn_v2" if "maskrcnn_resnet50_fpn_v2" in available_models else available_models[0], 0.5],
266
+ ["examples/zidane.jpg", "maskrcnn_resnet50_fpn_v2" if "maskrcnn_resnet50_fpn_v2" in available_models else available_models[0], 0.5],
267
  ],
268
  inputs=[input_image, model_dropdown, confidence_slider],
269
  outputs=[output_image, detections_output, metrics_output],
270
+ fn=run_inference,
271
  cache_examples=False
272
  )
273
 
274
  # Connect the button to the inference function
275
  classify_btn.click(
276
+ fn=run_inference,
277
  inputs=[input_image, model_dropdown, confidence_slider],
278
  outputs=[output_image, detections_output, metrics_output]
279
  )
models/serialized.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91eab2c806de2b61101cdeabbe97d091b8dda14fbb7ee8b8db3b9e54e4b8b72e
3
+ size 10566189
models/serialized.xml ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio>=4.0.0
2
  numpy>=1.21.0
3
  pillow>=9.0.0
4
  openvino>=2024.0.0
5
- git+https://github.com/open-edge-platform/model_api.git
 
 
2
  numpy>=1.21.0
3
  pillow>=9.0.0
4
  openvino>=2024.0.0
5
+ opencv-python-headless>=4.5.0
6
+ git+https://github.com/open-edge-platform/model_api.git@mgumowsk/add-models-to-tool-converter