Pre-load RF-DETR models at startup and add progress bar

PR #1 — opened by SkalskiP
Files changed (2):
  1. app.py (+55 −217)
  2. requirements.txt (+1 −1)
app.py CHANGED
@@ -2,19 +2,12 @@
2
 
3
  from __future__ import annotations
4
 
5
- import os
6
  import tempfile
7
  from pathlib import Path
8
 
9
  import cv2
10
  import gradio as gr
11
- import numpy as np
12
- import supervision as sv
13
- import torch
14
- from tqdm import tqdm
15
- from inference_models import AutoModel
16
-
17
- from trackers import ByteTrackTracker, SORTTracker, frames_from_source
18
 
19
  MAX_DURATION_SECONDS = 30
20
 
@@ -44,108 +37,6 @@ COCO_CLASSES = [
44
  "sports ball",
45
  ]
46
 
47
- # Device and model pre-loading
48
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
49
-
50
- print(f"Loading {len(MODELS)} models on {DEVICE}...")
51
- LOADED_MODELS = {}
52
- for model_id in MODELS:
53
- print(f" Loading {model_id}...")
54
- LOADED_MODELS[model_id] = AutoModel.from_pretrained(model_id, device=DEVICE)
55
- print("All models loaded.")
56
-
57
- # Visualization
58
- COLOR_PALETTE = sv.ColorPalette.from_hex(
59
- [
60
- "#ffff00",
61
- "#ff9b00",
62
- "#ff8080",
63
- "#ff66b2",
64
- "#ff66ff",
65
- "#b266ff",
66
- "#9999ff",
67
- "#3399ff",
68
- "#66ffff",
69
- "#33ff99",
70
- "#66ff66",
71
- "#99ff00",
72
- ]
73
- )
74
-
75
- RESULTS_DIR = "results"
76
- os.makedirs(RESULTS_DIR, exist_ok=True)
77
-
78
-
79
- def _init_annotators(
80
- show_boxes: bool = False,
81
- show_masks: bool = False,
82
- show_labels: bool = False,
83
- show_ids: bool = False,
84
- show_confidence: bool = False,
85
- ) -> tuple[list, sv.LabelAnnotator | None]:
86
- """Initialize supervision annotators based on display options."""
87
- annotators: list = []
88
- label_annotator: sv.LabelAnnotator | None = None
89
-
90
- if show_masks:
91
- annotators.append(
92
- sv.MaskAnnotator(
93
- color=COLOR_PALETTE,
94
- color_lookup=sv.ColorLookup.TRACK,
95
- )
96
- )
97
-
98
- if show_boxes:
99
- annotators.append(
100
- sv.BoxAnnotator(
101
- color=COLOR_PALETTE,
102
- color_lookup=sv.ColorLookup.TRACK,
103
- )
104
- )
105
-
106
- if show_labels or show_ids or show_confidence:
107
- label_annotator = sv.LabelAnnotator(
108
- color=COLOR_PALETTE,
109
- text_color=sv.Color.BLACK,
110
- text_position=sv.Position.TOP_LEFT,
111
- color_lookup=sv.ColorLookup.TRACK,
112
- )
113
-
114
- return annotators, label_annotator
115
-
116
-
117
- def _format_labels(
118
- detections: sv.Detections,
119
- class_names: list[str],
120
- *,
121
- show_ids: bool = False,
122
- show_labels: bool = False,
123
- show_confidence: bool = False,
124
- ) -> list[str]:
125
- """Generate label strings for each detection."""
126
- labels = []
127
-
128
- for i in range(len(detections)):
129
- parts = []
130
-
131
- if show_ids and detections.tracker_id is not None:
132
- parts.append(f"#{int(detections.tracker_id[i])}")
133
-
134
- if show_labels and detections.class_id is not None:
135
- class_id = int(detections.class_id[i])
136
- if class_names and 0 <= class_id < len(class_names):
137
- parts.append(class_names[class_id])
138
- else:
139
- parts.append(str(class_id))
140
-
141
- if show_confidence and detections.confidence is not None:
142
- parts.append(f"{detections.confidence[i]:.2f}")
143
-
144
- labels.append(" ".join(parts))
145
-
146
- return labels
147
-
148
-
149
  VIDEO_EXAMPLES = [
150
  [
151
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
@@ -258,39 +149,23 @@ VIDEO_EXAMPLES = [
258
  ]
259
 
260
 
261
- def _get_video_info(path: str) -> tuple[float, int]:
262
- """Return video duration in seconds and frame count using OpenCV."""
263
  cap = cv2.VideoCapture(path)
264
  if not cap.isOpened():
265
  raise gr.Error("Could not open the uploaded video.")
266
  fps = cap.get(cv2.CAP_PROP_FPS)
267
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
268
  cap.release()
269
  if fps <= 0:
270
  raise gr.Error("Could not determine video frame rate.")
271
- return frame_count / fps, frame_count
272
-
273
-
274
- def _resolve_class_filter(
275
- classes: list[str] | None,
276
- class_names: list[str],
277
- ) -> list[int] | None:
278
- """Resolve class names to integer IDs."""
279
- if not classes:
280
- return None
281
-
282
- name_to_id = {name: i for i, name in enumerate(class_names)}
283
- class_filter: list[int] = []
284
- for name in classes:
285
- if name in name_to_id:
286
- class_filter.append(name_to_id[name])
287
- return class_filter if class_filter else None
288
 
289
 
290
  def track(
291
  video_path: str,
292
- model_id: str,
293
- tracker_type: str,
294
  confidence: float,
295
  lost_track_buffer: int,
296
  track_activation_threshold: float,
@@ -304,109 +179,72 @@ def track(
304
  show_confidence: bool = False,
305
  show_trajectories: bool = False,
306
  show_masks: bool = False,
307
- progress=gr.Progress(track_tqdm=True),
308
  ) -> str:
309
  """Run tracking on the uploaded video and return the output path."""
310
  if video_path is None:
311
  raise gr.Error("Please upload a video.")
312
 
313
- duration, total_frames = _get_video_info(video_path)
314
  if duration > MAX_DURATION_SECONDS:
315
  raise gr.Error(
316
  f"Video is {duration:.1f}s long. "
317
  f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
318
  )
319
 
320
- # Get pre-loaded model
321
- detection_model = LOADED_MODELS[model_id]
322
- class_names = getattr(detection_model, "class_names", [])
323
-
324
- # Resolve class filter
325
- class_filter = _resolve_class_filter(classes, class_names)
326
-
327
- # Create tracker instance and reset ID counter
328
- if tracker_type == "bytetrack":
329
- tracker = ByteTrackTracker(
330
- lost_track_buffer=lost_track_buffer,
331
- track_activation_threshold=track_activation_threshold,
332
- minimum_consecutive_frames=minimum_consecutive_frames,
333
- minimum_iou_threshold=minimum_iou_threshold,
334
- high_conf_det_threshold=high_conf_det_threshold,
335
- )
336
- else:
337
- tracker = SORTTracker(
338
- lost_track_buffer=lost_track_buffer,
339
- track_activation_threshold=track_activation_threshold,
340
- minimum_consecutive_frames=minimum_consecutive_frames,
341
- minimum_iou_threshold=minimum_iou_threshold,
342
- )
343
- tracker.reset()
344
-
345
- # Setup annotators
346
- annotators, label_annotator = _init_annotators(
347
- show_boxes=show_boxes,
348
- show_masks=show_masks,
349
- show_labels=show_labels,
350
- show_ids=show_ids,
351
- show_confidence=show_confidence,
352
- )
353
- trace_annotator = None
354
- if show_trajectories:
355
- trace_annotator = sv.TraceAnnotator(
356
- color=COLOR_PALETTE,
357
- color_lookup=sv.ColorLookup.TRACK,
358
- )
359
-
360
- # Setup output
361
  tmp_dir = tempfile.mkdtemp()
362
  output_path = str(Path(tmp_dir) / "output.mp4")
363
 
364
- # Get video info for output
365
- video_info = sv.VideoInfo.from_video_path(video_path)
366
-
367
- # Process video with progress bar
368
- frame_gen = frames_from_source(video_path)
369
-
370
- with sv.VideoSink(output_path, video_info=video_info) as sink:
371
- for frame_idx, frame in tqdm(frame_gen, total=total_frames, desc="Processing video..."):
372
- # Run detection
373
- predictions = detection_model(frame)
374
- if predictions:
375
- detections = predictions[0].to_supervision()
376
-
377
- # Filter by confidence
378
- if len(detections) > 0 and detections.confidence is not None:
379
- mask = detections.confidence >= confidence
380
- detections = detections[mask]
 
 
 
 
 
 
 
 
381
 
382
- # Filter by class
383
- if class_filter is not None and len(detections) > 0:
384
- mask = np.isin(detections.class_id, class_filter)
385
- detections = detections[mask]
386
- else:
387
- detections = sv.Detections.empty()
388
 
389
- # Run tracker
390
- tracked = tracker.update(detections)
391
 
392
- # Annotate frame
393
- annotated = frame.copy()
394
- if trace_annotator is not None:
395
- annotated = trace_annotator.annotate(annotated, tracked)
396
- for annotator in annotators:
397
- annotated = annotator.annotate(annotated, tracked)
398
- if label_annotator is not None:
399
- labeled = tracked[tracked.tracker_id != -1]
400
- labels = _format_labels(
401
- labeled,
402
- class_names,
403
- show_ids=show_ids,
404
- show_labels=show_labels,
405
- show_confidence=show_confidence,
406
- )
407
- annotated = label_annotator.annotate(annotated, labeled, labels=labels)
408
 
409
- sink.write_frame(annotated)
 
 
410
 
411
  return output_path
412
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import subprocess
6
  import tempfile
7
  from pathlib import Path
8
 
9
  import cv2
10
  import gradio as gr
 
 
 
 
 
 
 
11
 
12
  MAX_DURATION_SECONDS = 30
13
 
 
37
  "sports ball",
38
  ]
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  VIDEO_EXAMPLES = [
41
  [
42
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
 
149
  ]
150
 
151
 
152
+ def _get_video_duration(path: str) -> float:
153
+ """Return video duration in seconds using OpenCV."""
154
  cap = cv2.VideoCapture(path)
155
  if not cap.isOpened():
156
  raise gr.Error("Could not open the uploaded video.")
157
  fps = cap.get(cv2.CAP_PROP_FPS)
158
+ frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
159
  cap.release()
160
  if fps <= 0:
161
  raise gr.Error("Could not determine video frame rate.")
162
+ return frame_count / fps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
  def track(
166
  video_path: str,
167
+ model: str,
168
+ tracker: str,
169
  confidence: float,
170
  lost_track_buffer: int,
171
  track_activation_threshold: float,
 
179
  show_confidence: bool = False,
180
  show_trajectories: bool = False,
181
  show_masks: bool = False,
 
182
  ) -> str:
183
  """Run tracking on the uploaded video and return the output path."""
184
  if video_path is None:
185
  raise gr.Error("Please upload a video.")
186
 
187
+ duration = _get_video_duration(video_path)
188
  if duration > MAX_DURATION_SECONDS:
189
  raise gr.Error(
190
  f"Video is {duration:.1f}s long. "
191
  f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
192
  )
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  tmp_dir = tempfile.mkdtemp()
195
  output_path = str(Path(tmp_dir) / "output.mp4")
196
 
197
+ cmd = [
198
+ "trackers",
199
+ "track",
200
+ "--source",
201
+ video_path,
202
+ "--output",
203
+ output_path,
204
+ "--overwrite",
205
+ "--model",
206
+ model,
207
+ "--model.device",
208
+ "cuda",
209
+ "--tracker",
210
+ tracker,
211
+ "--model.confidence",
212
+ str(confidence),
213
+ "--tracker.lost_track_buffer",
214
+ str(lost_track_buffer),
215
+ "--tracker.track_activation_threshold",
216
+ str(track_activation_threshold),
217
+ "--tracker.minimum_consecutive_frames",
218
+ str(minimum_consecutive_frames),
219
+ "--tracker.minimum_iou_threshold",
220
+ str(minimum_iou_threshold),
221
+ ]
222
 
223
+ # ByteTrack extra param
224
+ if tracker == "bytetrack":
225
+ cmd += ["--tracker.high_conf_det_threshold", str(high_conf_det_threshold)]
 
 
 
226
 
227
+ if classes:
228
+ cmd += ["--classes", ",".join(classes)]
229
 
230
+ if show_boxes:
231
+ cmd += ["--show-boxes"]
232
+ else:
233
+ cmd += ["--no-boxes"]
234
+ if show_ids:
235
+ cmd += ["--show-ids"]
236
+ if show_labels:
237
+ cmd += ["--show-labels"]
238
+ if show_confidence:
239
+ cmd += ["--show-confidence"]
240
+ if show_trajectories:
241
+ cmd += ["--show-trajectories"]
242
+ if show_masks:
243
+ cmd += ["--show-masks"]
 
 
244
 
245
+ result = subprocess.run(cmd, capture_output=True, text=True) # noqa: S603
246
+ if result.returncode != 0:
247
+ raise gr.Error(f"Tracking failed:\n{result.stderr[-500:]}")
248
 
249
  return output_path
250
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  gradio>=6.3.0,<6.4.0
2
- inference-models==0.18.6rc14
3
  trackers==2.2.0rc1
 
1
  gradio>=6.3.0,<6.4.0
2
+ inference-models[onnx-cpu]==0.18.6rc14
3
  trackers==2.2.0rc1