SkalskiP commited on
Commit
5cd39e6
·
1 Parent(s): af526d5

Add Filter IDs feature, new video examples, and improve code readability

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +147 -51
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .idea/
2
+ .gradio/
app.py CHANGED
@@ -1,8 +1,7 @@
1
- """Gradio app for the trackers library — run object tracking on uploaded videos."""
2
-
3
  from __future__ import annotations
4
 
5
  import os
 
6
  import tempfile
7
  from pathlib import Path
8
 
@@ -44,7 +43,6 @@ COCO_CLASSES = [
44
  "sports ball",
45
  ]
46
 
47
- # Device and model pre-loading
48
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
49
 
50
  print(f"Loading {len(MODELS)} models on {DEVICE}...")
@@ -54,7 +52,6 @@ for model_id in MODELS:
54
  LOADED_MODELS[model_id] = AutoModel.from_pretrained(model_id, device=DEVICE)
55
  print("All models loaded.")
56
 
57
- # Visualization
58
  COLOR_PALETTE = sv.ColorPalette.from_hex(
59
  [
60
  "#ffff00",
@@ -158,6 +155,26 @@ VIDEO_EXAMPLES = [
158
  0.1,
159
  0.6,
160
  [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  True,
162
  True,
163
  False,
@@ -176,6 +193,7 @@ VIDEO_EXAMPLES = [
176
  0.3,
177
  0.6,
178
  [],
 
179
  True,
180
  True,
181
  False,
@@ -184,21 +202,22 @@ VIDEO_EXAMPLES = [
184
  True,
185
  ],
186
  [
187
- "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/cars-1280x720-1.mp4",
188
- "rfdetr-small",
189
- "bytetrack",
190
  0.2,
191
  30,
192
  0.3,
193
  3,
194
  0.1,
195
  0.6,
196
- ["car"],
 
197
  True,
198
  True,
199
- False,
200
  True,
201
  False,
 
202
  False,
203
  ],
204
  [
@@ -212,6 +231,7 @@ VIDEO_EXAMPLES = [
212
  0.1,
213
  0.6,
214
  [],
 
215
  True,
216
  True,
217
  False,
@@ -230,16 +250,55 @@ VIDEO_EXAMPLES = [
230
  0.1,
231
  0.6,
232
  [],
 
233
  True,
234
  True,
235
  False,
236
  False,
237
  True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  False,
 
 
239
  ],
240
  [
241
- "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/vehicles-1280x720.mp4",
242
  "rfdetr-small",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  "bytetrack",
244
  0.2,
245
  30,
@@ -248,6 +307,7 @@ VIDEO_EXAMPLES = [
248
  0.1,
249
  0.6,
250
  [],
 
251
  True,
252
  True,
253
  True,
@@ -260,15 +320,15 @@ VIDEO_EXAMPLES = [
260
 
261
  def _get_video_info(path: str) -> tuple[float, int]:
262
  """Return video duration in seconds and frame count using OpenCV."""
263
- cap = cv2.VideoCapture(path)
264
- if not cap.isOpened():
265
  raise gr.Error("Could not open the uploaded video.")
266
- fps = cap.get(cv2.CAP_PROP_FPS)
267
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
268
- cap.release()
269
- if fps <= 0:
270
  raise gr.Error("Could not determine video frame rate.")
271
- return frame_count / fps, frame_count
272
 
273
 
274
  def _resolve_class_filter(
@@ -287,6 +347,32 @@ def _resolve_class_filter(
287
  return class_filter if class_filter else None
288
 
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  def track(
291
  video_path: str,
292
  model_id: str,
@@ -298,6 +384,7 @@ def track(
298
  minimum_iou_threshold: float,
299
  high_conf_det_threshold: float,
300
  classes: list[str] | None = None,
 
301
  show_boxes: bool = True,
302
  show_ids: bool = True,
303
  show_labels: bool = False,
@@ -314,17 +401,16 @@ def track(
314
  if duration > MAX_DURATION_SECONDS:
315
  raise gr.Error(
316
  f"Video is {duration:.1f}s long. "
317
- f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
 
318
  )
319
 
320
- # Get pre-loaded model
321
  detection_model = LOADED_MODELS[model_id]
322
  class_names = getattr(detection_model, "class_names", [])
323
 
324
- # Resolve class filter
325
- class_filter = _resolve_class_filter(classes, class_names)
326
 
327
- # Create tracker instance and reset ID counter
328
  if tracker_type == "bytetrack":
329
  tracker = ByteTrackTracker(
330
  lost_track_buffer=lost_track_buffer,
@@ -342,7 +428,6 @@ def track(
342
  )
343
  tracker.reset()
344
 
345
- # Setup annotators
346
  annotators, label_annotator = _init_annotators(
347
  show_boxes=show_boxes,
348
  show_masks=show_masks,
@@ -357,39 +442,38 @@ def track(
357
  color_lookup=sv.ColorLookup.TRACK,
358
  )
359
 
360
- # Setup output
361
- tmp_dir = tempfile.mkdtemp()
362
- output_path = str(Path(tmp_dir) / "output.mp4")
363
 
364
- # Get video info for output
365
  video_info = sv.VideoInfo.from_video_path(video_path)
366
 
367
- # Process video with progress bar
368
- frame_gen = frames_from_source(video_path)
369
 
370
  with sv.VideoSink(output_path, video_info=video_info) as sink:
371
- for frame_idx, frame in tqdm(frame_gen, total=total_frames, desc="Processing video..."):
372
- # Run detection
 
373
  predictions = detection_model(frame)
374
  if predictions:
375
  detections = predictions[0].to_supervision()
376
 
377
- # Filter by confidence
378
  if len(detections) > 0 and detections.confidence is not None:
379
- mask = detections.confidence >= confidence
380
- detections = detections[mask]
381
 
382
- # Filter by class
383
- if class_filter is not None and len(detections) > 0:
384
- mask = np.isin(detections.class_id, class_filter)
385
- detections = detections[mask]
386
  else:
387
  detections = sv.Detections.empty()
388
 
389
- # Run tracker
390
  tracked = tracker.update(detections)
391
 
392
- # Annotate frame
 
 
 
 
393
  annotated = frame.copy()
394
  if trace_annotator is not None:
395
  annotated = trace_annotator.annotate(annotated, tracked)
@@ -423,7 +507,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
423
  input_video = gr.Video(label="Input Video")
424
  output_video = gr.Video(label="Tracked Video")
425
 
426
- track_btn = gr.Button(value="Track", variant="primary")
427
 
428
  with gr.Row():
429
  model_dropdown = gr.Dropdown(
@@ -455,6 +539,16 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
455
  label="Filter Classes",
456
  info="Only track selected classes. None selected means all.",
457
  )
 
 
 
 
 
 
 
 
 
 
458
 
459
  with gr.Column():
460
  gr.Markdown("### Tracker")
@@ -474,7 +568,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
474
  label="Track Activation Threshold",
475
  info="Minimum score for a track to be activated.",
476
  )
477
- min_consecutive_slider = gr.Slider(
478
  minimum=1,
479
  maximum=10,
480
  value=2,
@@ -482,7 +576,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
482
  label="Minimum Consecutive Frames",
483
  info="Detections needed before a track is confirmed.",
484
  )
485
- min_iou_slider = gr.Slider(
486
  minimum=0.0,
487
  maximum=1.0,
488
  value=0.1,
@@ -490,7 +584,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
490
  label="Minimum IoU Threshold",
491
  info="Overlap required to match a detection to a track.",
492
  )
493
- high_conf_slider = gr.Slider(
494
  minimum=0.0,
495
  maximum=1.0,
496
  value=0.6,
@@ -543,10 +637,11 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
543
  confidence_slider,
544
  lost_track_buffer_slider,
545
  track_activation_slider,
546
- min_consecutive_slider,
547
- min_iou_slider,
548
- high_conf_slider,
549
  class_filter,
 
550
  show_boxes_checkbox,
551
  show_ids_checkbox,
552
  show_labels_checkbox,
@@ -557,7 +652,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
557
  outputs=output_video,
558
  )
559
 
560
- track_btn.click(
561
  fn=track,
562
  inputs=[
563
  input_video,
@@ -566,10 +661,11 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
566
  confidence_slider,
567
  lost_track_buffer_slider,
568
  track_activation_slider,
569
- min_consecutive_slider,
570
- min_iou_slider,
571
- high_conf_slider,
572
  class_filter,
 
573
  show_boxes_checkbox,
574
  show_ids_checkbox,
575
  show_labels_checkbox,
 
 
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import sys
5
  import tempfile
6
  from pathlib import Path
7
 
 
43
  "sports ball",
44
  ]
45
 
 
46
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
47
 
48
  print(f"Loading {len(MODELS)} models on {DEVICE}...")
 
52
  LOADED_MODELS[model_id] = AutoModel.from_pretrained(model_id, device=DEVICE)
53
  print("All models loaded.")
54
 
 
55
  COLOR_PALETTE = sv.ColorPalette.from_hex(
56
  [
57
  "#ffff00",
 
155
  0.1,
156
  0.6,
157
  [],
158
+ "",
159
+ True,
160
+ True,
161
+ False,
162
+ False,
163
+ True,
164
+ False,
165
+ ],
166
+ [
167
+ "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
168
+ "rfdetr-small",
169
+ "bytetrack",
170
+ 0.2,
171
+ 30,
172
+ 0.3,
173
+ 3,
174
+ 0.1,
175
+ 0.6,
176
+ ["person"],
177
+ "",
178
  True,
179
  True,
180
  False,
 
193
  0.3,
194
  0.6,
195
  [],
196
+ "",
197
  True,
198
  True,
199
  False,
 
202
  True,
203
  ],
204
  [
205
+ "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/apples-1280x720-2.mp4",
206
+ "rfdetr-nano",
207
+ "sort",
208
  0.2,
209
  30,
210
  0.3,
211
  3,
212
  0.1,
213
  0.6,
214
+ [],
215
+ "",
216
  True,
217
  True,
 
218
  True,
219
  False,
220
+ True,
221
  False,
222
  ],
223
  [
 
231
  0.1,
232
  0.6,
233
  [],
234
+ "",
235
  True,
236
  True,
237
  False,
 
250
  0.1,
251
  0.6,
252
  [],
253
+ "",
254
  True,
255
  True,
256
  False,
257
  False,
258
  True,
259
+ True,
260
+ ],
261
+ [
262
+ "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/jets-1280x720-2.mp4",
263
+ "rfdetr-seg-small",
264
+ "bytetrack",
265
+ 0.2,
266
+ 30,
267
+ 0.3,
268
+ 3,
269
+ 0.1,
270
+ 0.6,
271
+ [],
272
+ "1",
273
+ True,
274
+ True,
275
+ False,
276
  False,
277
+ True,
278
+ True,
279
  ],
280
  [
281
+ "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/suitcases-1280x720-4.mp4",
282
  "rfdetr-small",
283
+ "sort",
284
+ 0.2,
285
+ 30,
286
+ 0.3,
287
+ 3,
288
+ 0.1,
289
+ 0.6,
290
+ [],
291
+ "",
292
+ True,
293
+ True,
294
+ True,
295
+ False,
296
+ True,
297
+ False,
298
+ ],
299
+ [
300
+ "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/vehicles-1280x720.mp4",
301
+ "rfdetr-medium",
302
  "bytetrack",
303
  0.2,
304
  30,
 
307
  0.1,
308
  0.6,
309
  [],
310
+ "",
311
  True,
312
  True,
313
  True,
 
320
 
321
  def _get_video_info(path: str) -> tuple[float, int]:
322
  """Return video duration in seconds and frame count using OpenCV."""
323
+ video_capture = cv2.VideoCapture(path)
324
+ if not video_capture.isOpened():
325
  raise gr.Error("Could not open the uploaded video.")
326
+ frames_per_second = video_capture.get(cv2.CAP_PROP_FPS)
327
+ frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
328
+ video_capture.release()
329
+ if frames_per_second <= 0:
330
  raise gr.Error("Could not determine video frame rate.")
331
+ return frame_count / frames_per_second, frame_count
332
 
333
 
334
  def _resolve_class_filter(
 
347
  return class_filter if class_filter else None
348
 
349
 
350
+ def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None:
351
+ """Resolve a comma-separated string of track IDs to a list of integers.
352
+
353
+ Args:
354
+ track_ids_arg: Comma-separated string (e.g. `"1,3,5"`). `None` or
355
+ empty string means no filter.
356
+
357
+ Returns:
358
+ List of integer track IDs, or `None` when no valid filter remains.
359
+ """
360
+ if not track_ids_arg:
361
+ return None
362
+
363
+ track_ids: list[int] = []
364
+ for token in track_ids_arg.split(","):
365
+ token = token.strip()
366
+ try:
367
+ track_ids.append(int(token))
368
+ except ValueError:
369
+ print(
370
+ f"Warning: '{token}' is not a valid track ID, skipping.",
371
+ file=sys.stderr,
372
+ )
373
+ return track_ids if track_ids else None
374
+
375
+
376
  def track(
377
  video_path: str,
378
  model_id: str,
 
384
  minimum_iou_threshold: float,
385
  high_conf_det_threshold: float,
386
  classes: list[str] | None = None,
387
+ track_ids: str = "",
388
  show_boxes: bool = True,
389
  show_ids: bool = True,
390
  show_labels: bool = False,
 
401
  if duration > MAX_DURATION_SECONDS:
402
  raise gr.Error(
403
  f"Video is {duration:.1f}s long. "
404
+ f"Maximum allowed duration is {MAX_DURATION_SECONDS}s. "
405
+ f"Please use the trim tool in the Input Video player to shorten it."
406
  )
407
 
 
408
  detection_model = LOADED_MODELS[model_id]
409
  class_names = getattr(detection_model, "class_names", [])
410
 
411
+ selected_class_ids = _resolve_class_filter(classes, class_names)
412
+ selected_track_ids = _resolve_track_id_filter(track_ids)
413
 
 
414
  if tracker_type == "bytetrack":
415
  tracker = ByteTrackTracker(
416
  lost_track_buffer=lost_track_buffer,
 
428
  )
429
  tracker.reset()
430
 
 
431
  annotators, label_annotator = _init_annotators(
432
  show_boxes=show_boxes,
433
  show_masks=show_masks,
 
442
  color_lookup=sv.ColorLookup.TRACK,
443
  )
444
 
445
+ temporary_directory = tempfile.mkdtemp()
446
+ output_path = str(Path(temporary_directory) / "output.mp4")
 
447
 
 
448
  video_info = sv.VideoInfo.from_video_path(video_path)
449
 
450
+ frame_generator = frames_from_source(video_path)
 
451
 
452
  with sv.VideoSink(output_path, video_info=video_info) as sink:
453
+ for frame_idx, frame in tqdm(
454
+ frame_generator, total=total_frames, desc="Processing video..."
455
+ ):
456
  predictions = detection_model(frame)
457
  if predictions:
458
  detections = predictions[0].to_supervision()
459
 
 
460
  if len(detections) > 0 and detections.confidence is not None:
461
+ confidence_mask = detections.confidence >= confidence
462
+ detections = detections[confidence_mask]
463
 
464
+ if selected_class_ids is not None and len(detections) > 0:
465
+ class_mask = np.isin(detections.class_id, selected_class_ids)
466
+ detections = detections[class_mask]
 
467
  else:
468
  detections = sv.Detections.empty()
469
 
 
470
  tracked = tracker.update(detections)
471
 
472
+ if selected_track_ids is not None and len(tracked) > 0:
473
+ if tracked.tracker_id is not None:
474
+ track_id_mask = np.isin(tracked.tracker_id, selected_track_ids)
475
+ tracked = tracked[track_id_mask]
476
+
477
  annotated = frame.copy()
478
  if trace_annotator is not None:
479
  annotated = trace_annotator.annotate(annotated, tracked)
 
507
  input_video = gr.Video(label="Input Video")
508
  output_video = gr.Video(label="Tracked Video")
509
 
510
+ track_button = gr.Button(value="Track", variant="primary")
511
 
512
  with gr.Row():
513
  model_dropdown = gr.Dropdown(
 
539
  label="Filter Classes",
540
  info="Only track selected classes. None selected means all.",
541
  )
542
+ track_id_filter = gr.Textbox(
543
+ value="",
544
+ label="Filter IDs",
545
+ info=(
546
+ "Only display tracks with specific track IDs "
547
+ "(comma-separated, e.g. 1,3,5). "
548
+ "Leave empty for all."
549
+ ),
550
+ placeholder="e.g. 1,3,5",
551
+ )
552
 
553
  with gr.Column():
554
  gr.Markdown("### Tracker")
 
568
  label="Track Activation Threshold",
569
  info="Minimum score for a track to be activated.",
570
  )
571
+ minimum_consecutive_slider = gr.Slider(
572
  minimum=1,
573
  maximum=10,
574
  value=2,
 
576
  label="Minimum Consecutive Frames",
577
  info="Detections needed before a track is confirmed.",
578
  )
579
+ minimum_iou_slider = gr.Slider(
580
  minimum=0.0,
581
  maximum=1.0,
582
  value=0.1,
 
584
  label="Minimum IoU Threshold",
585
  info="Overlap required to match a detection to a track.",
586
  )
587
+ high_confidence_slider = gr.Slider(
588
  minimum=0.0,
589
  maximum=1.0,
590
  value=0.6,
 
637
  confidence_slider,
638
  lost_track_buffer_slider,
639
  track_activation_slider,
640
+ minimum_consecutive_slider,
641
+ minimum_iou_slider,
642
+ high_confidence_slider,
643
  class_filter,
644
+ track_id_filter,
645
  show_boxes_checkbox,
646
  show_ids_checkbox,
647
  show_labels_checkbox,
 
652
  outputs=output_video,
653
  )
654
 
655
+ track_button.click(
656
  fn=track,
657
  inputs=[
658
  input_video,
 
661
  confidence_slider,
662
  lost_track_buffer_slider,
663
  track_activation_slider,
664
+ minimum_consecutive_slider,
665
+ minimum_iou_slider,
666
+ high_confidence_slider,
667
  class_filter,
668
+ track_id_filter,
669
  show_boxes_checkbox,
670
  show_ids_checkbox,
671
  show_labels_checkbox,