Mirko Trasciatti committed on
Commit
aa18f74
·
1 Parent(s): 255f277

Filter detections to sports ball by default

Browse files
Files changed (1) hide show
  1. app.py +42 -8
app.py CHANGED
@@ -23,8 +23,36 @@ def download_model(model_filename):
23
  """
24
  return hf_hub_download(repo_id="atalaydenknalbant/Yolov13", filename=model_filename)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  @spaces.GPU
27
- def yolo_inference(input_type, image, video, model_id, conf_threshold, iou_threshold, max_detection):
28
  """
29
  Performs object detection inference using a YOLOv13 model on either an image or a video.
30
 
@@ -71,12 +99,14 @@ def yolo_inference(input_type, image, video, model_id, conf_threshold, iou_thres
71
  return blank_image, None
72
 
73
  model = YOLO(model_path)
 
74
  results = model.predict(
75
  source=image,
76
  conf=conf_threshold,
77
  iou=iou_threshold,
78
  imgsz=640,
79
  max_det=max_detection,
 
80
  show_labels=True,
81
  show_conf=True,
82
  )
@@ -107,6 +137,7 @@ def yolo_inference(input_type, image, video, model_id, conf_threshold, iou_thres
107
  return None, temp_video_file
108
 
109
  model = YOLO(model_path)
 
110
  cap = cv2.VideoCapture(video)
111
  fps = cap.get(cv2.CAP_PROP_FPS) if cap.get(cv2.CAP_PROP_FPS) > 0 else 25
112
  frames = []
@@ -121,6 +152,7 @@ def yolo_inference(input_type, image, video, model_id, conf_threshold, iou_thres
121
  iou=iou_threshold,
122
  imgsz=640,
123
  max_det=max_detection,
 
124
  show_labels=True,
125
  show_conf=True,
126
  )
@@ -163,7 +195,7 @@ def update_visibility(input_type):
163
  else:
164
  return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
165
 
166
- def yolo_inference_for_examples(image, model_id, conf_threshold, iou_threshold, max_detection):
167
  """
168
  Wrapper function for `yolo_inference` specifically for Gradio examples that use images.
169
 
@@ -187,7 +219,8 @@ def yolo_inference_for_examples(image, model_id, conf_threshold, iou_threshold,
187
  model_id=model_id,
188
  conf_threshold=conf_threshold,
189
  iou_threshold=iou_threshold,
190
- max_detection=max_detection
 
191
  )
192
  return annotated_image
193
 
@@ -234,6 +267,7 @@ with gr.Blocks(theme=theme) as app:
234
  conf_threshold = gr.Slider(minimum=0, maximum=1, value=0.35, label="Confidence Threshold")
235
  iou_threshold = gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold")
236
  max_detection = gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection")
 
237
  infer_button = gr.Button("Detect Objects", variant="primary")
238
  with gr.Column():
239
  output_image = gr.Image(type="pil", show_label=False, show_share_button=False, visible=True)
@@ -248,18 +282,18 @@ with gr.Blocks(theme=theme) as app:
248
 
249
  infer_button.click(
250
  fn=yolo_inference,
251
- inputs=[input_type, image, video, model_id, conf_threshold, iou_threshold, max_detection],
252
  outputs=[output_image, output_video],
253
  )
254
 
255
  gr.Examples(
256
  examples=[
257
- ["zidane.jpg", "yolov13s.pt", 0.35, 0.45, 300],
258
- ["bus.jpg", "yolov13l.pt", 0.35, 0.45, 300],
259
- ["yolo_vision.jpg", "yolov13x.pt", 0.35, 0.45, 300],
260
  ],
261
  fn=yolo_inference_for_examples,
262
- inputs=[image, model_id, conf_threshold, iou_threshold, max_detection],
263
  outputs=[output_image],
264
  label="Examples (Images)",
265
  )
 
23
  """
24
  return hf_hub_download(repo_id="atalaydenknalbant/Yolov13", filename=model_filename)
25
 
26
# Canonical spelling for user-typed labels; every alias maps to the COCO
# class name "sports ball" that the YOLO models expose.
TARGET_ALIASES = {
    "ball": "sports ball",
    "soccer ball": "sports ball",
    "football": "sports ball",
    "sports ball": "sports ball",
}


def resolve_target_class(model, target_label: str):
    """
    Resolve a human-provided class name to YOLO class indices.

    Args:
        model (YOLO): Loaded YOLO model instance.
        target_label (str): Label entered by the user.

    Returns:
        list[int] | None: List of class indices to filter on, or None to keep all classes.
    """
    # Empty / None label means "do not filter" — predict() accepts classes=None.
    if not target_label:
        return None

    normalized = target_label.strip().lower()
    # Fall back to the user's own text when it is not a known alias, so any
    # class name the model knows (e.g. "person") still works.
    wanted = TARGET_ALIASES.get(normalized, normalized)

    ids = []
    for class_id, class_name in model.names.items():
        if class_name.lower() == wanted:
            ids.append(class_id)
    # An unknown label yields no ids; returning None keeps all detections
    # rather than silently filtering everything out.
    return ids if ids else None
52
+
53
+
54
  @spaces.GPU
55
+ def yolo_inference(input_type, image, video, model_id, conf_threshold, iou_threshold, max_detection, target_class):
56
  """
57
  Performs object detection inference using a YOLOv13 model on either an image or a video.
58
 
 
99
  return blank_image, None
100
 
101
  model = YOLO(model_path)
102
+ class_ids = resolve_target_class(model, target_class)
103
  results = model.predict(
104
  source=image,
105
  conf=conf_threshold,
106
  iou=iou_threshold,
107
  imgsz=640,
108
  max_det=max_detection,
109
+ classes=class_ids,
110
  show_labels=True,
111
  show_conf=True,
112
  )
 
137
  return None, temp_video_file
138
 
139
  model = YOLO(model_path)
140
+ class_ids = resolve_target_class(model, target_class)
141
  cap = cv2.VideoCapture(video)
142
  fps = cap.get(cv2.CAP_PROP_FPS) if cap.get(cv2.CAP_PROP_FPS) > 0 else 25
143
  frames = []
 
152
  iou=iou_threshold,
153
  imgsz=640,
154
  max_det=max_detection,
155
+ classes=class_ids,
156
  show_labels=True,
157
  show_conf=True,
158
  )
 
195
  else:
196
  return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
197
 
198
+ def yolo_inference_for_examples(image, model_id, conf_threshold, iou_threshold, max_detection, target_class):
199
  """
200
  Wrapper function for `yolo_inference` specifically for Gradio examples that use images.
201
 
 
219
  model_id=model_id,
220
  conf_threshold=conf_threshold,
221
  iou_threshold=iou_threshold,
222
+ max_detection=max_detection,
223
+ target_class=target_class
224
  )
225
  return annotated_image
226
 
 
267
  conf_threshold = gr.Slider(minimum=0, maximum=1, value=0.35, label="Confidence Threshold")
268
  iou_threshold = gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold")
269
  max_detection = gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection")
270
+ target_class = gr.Textbox(value="sports ball", label="Target class (default: sports ball)")
271
  infer_button = gr.Button("Detect Objects", variant="primary")
272
  with gr.Column():
273
  output_image = gr.Image(type="pil", show_label=False, show_share_button=False, visible=True)
 
282
 
283
  infer_button.click(
284
  fn=yolo_inference,
285
+ inputs=[input_type, image, video, model_id, conf_threshold, iou_threshold, max_detection, target_class],
286
  outputs=[output_image, output_video],
287
  )
288
 
289
  gr.Examples(
290
  examples=[
291
+ ["zidane.jpg", "yolov13s.pt", 0.35, 0.45, 300, "sports ball"],
292
+ ["bus.jpg", "yolov13l.pt", 0.35, 0.45, 300, "sports ball"],
293
+ ["yolo_vision.jpg", "yolov13x.pt", 0.35, 0.45, 300, "sports ball"],
294
  ],
295
  fn=yolo_inference_for_examples,
296
+ inputs=[image, model_id, conf_threshold, iou_threshold, max_detection, target_class],
297
  outputs=[output_image],
298
  label="Examples (Images)",
299
  )