SkalskiP committed on
Commit
04ddb03
·
1 Parent(s): d944b7f

Pre-load RF-DETR models at startup and add progress bar

Browse files

- Replace CLI subprocess with direct inference using pre-loaded models
- Add tqdm + gr.Progress for video processing feedback
- Switch to full inference-models package (CUDA support)

Files changed (2) hide show
  1. app.py +217 -55
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,12 +2,19 @@
2
 
3
  from __future__ import annotations
4
 
5
- import subprocess
6
  import tempfile
7
  from pathlib import Path
8
 
9
  import cv2
10
  import gradio as gr
 
 
 
 
 
 
 
11
 
12
  MAX_DURATION_SECONDS = 30
13
 
@@ -37,6 +44,108 @@ COCO_CLASSES = [
37
  "sports ball",
38
  ]
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  VIDEO_EXAMPLES = [
41
  [
42
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
@@ -149,23 +258,39 @@ VIDEO_EXAMPLES = [
149
  ]
150
 
151
 
152
def _get_video_duration(path: str) -> float:
    """Compute the duration of the video at *path*, in seconds.

    Raises gr.Error when the file cannot be opened or when OpenCV reports a
    non-positive frame rate (duration would be undefined).
    """
    capture = cv2.VideoCapture(path)
    if not capture.isOpened():
        raise gr.Error("Could not open the uploaded video.")
    frames_per_second = capture.get(cv2.CAP_PROP_FPS)
    total_frames = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    if frames_per_second <= 0:
        raise gr.Error("Could not determine video frame rate.")
    # Duration follows directly from frame count and frame rate.
    return total_frames / frames_per_second
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
  def track(
166
  video_path: str,
167
- model: str,
168
- tracker: str,
169
  confidence: float,
170
  lost_track_buffer: int,
171
  track_activation_threshold: float,
@@ -179,72 +304,109 @@ def track(
179
  show_confidence: bool = False,
180
  show_trajectories: bool = False,
181
  show_masks: bool = False,
 
182
  ) -> str:
183
  """Run tracking on the uploaded video and return the output path."""
184
  if video_path is None:
185
  raise gr.Error("Please upload a video.")
186
 
187
- duration = _get_video_duration(video_path)
188
  if duration > MAX_DURATION_SECONDS:
189
  raise gr.Error(
190
  f"Video is {duration:.1f}s long. "
191
  f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
192
  )
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  tmp_dir = tempfile.mkdtemp()
195
  output_path = str(Path(tmp_dir) / "output.mp4")
196
 
197
- cmd = [
198
- "trackers",
199
- "track",
200
- "--source",
201
- video_path,
202
- "--output",
203
- output_path,
204
- "--overwrite",
205
- "--model",
206
- model,
207
- "--model.device",
208
- "cuda",
209
- "--tracker",
210
- tracker,
211
- "--model.confidence",
212
- str(confidence),
213
- "--tracker.lost_track_buffer",
214
- str(lost_track_buffer),
215
- "--tracker.track_activation_threshold",
216
- str(track_activation_threshold),
217
- "--tracker.minimum_consecutive_frames",
218
- str(minimum_consecutive_frames),
219
- "--tracker.minimum_iou_threshold",
220
- str(minimum_iou_threshold),
221
- ]
222
 
223
- # ByteTrack extra param
224
- if tracker == "bytetrack":
225
- cmd += ["--tracker.high_conf_det_threshold", str(high_conf_det_threshold)]
226
 
227
- if classes:
228
- cmd += ["--classes", ",".join(classes)]
 
 
 
 
229
 
230
- if show_boxes:
231
- cmd += ["--show-boxes"]
232
- else:
233
- cmd += ["--no-boxes"]
234
- if show_ids:
235
- cmd += ["--show-ids"]
236
- if show_labels:
237
- cmd += ["--show-labels"]
238
- if show_confidence:
239
- cmd += ["--show-confidence"]
240
- if show_trajectories:
241
- cmd += ["--show-trajectories"]
242
- if show_masks:
243
- cmd += ["--show-masks"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
- result = subprocess.run(cmd, capture_output=True, text=True) # noqa: S603
246
- if result.returncode != 0:
247
- raise gr.Error(f"Tracking failed:\n{result.stderr[-500:]}")
248
 
249
  return output_path
250
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import os
6
  import tempfile
7
  from pathlib import Path
8
 
9
  import cv2
10
  import gradio as gr
11
+ import numpy as np
12
+ import supervision as sv
13
+ import torch
14
+ from tqdm import tqdm
15
+ from inference_models import AutoModel
16
+
17
+ from trackers import ByteTrackTracker, SORTTracker, frames_from_source
18
 
19
  MAX_DURATION_SECONDS = 30
20
 
 
44
  "sports ball",
45
  ]
46
 
47
# Device and model pre-loading.
# Select CUDA when available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): MODELS is defined elsewhere in this file (hunk not shown here);
# assumed to be an iterable of model identifiers accepted by
# AutoModel.from_pretrained — confirm against the full file.
print(f"Loading {len(MODELS)} models on {DEVICE}...")
LOADED_MODELS = {}
for model_id in MODELS:
    print(f" Loading {model_id}...")
    # Loading every model at import time trades slower startup for
    # request-time inference with no cold start.
    LOADED_MODELS[model_id] = AutoModel.from_pretrained(model_id, device=DEVICE)
print("All models loaded.")

# Visualization: fixed 12-color hex palette handed to the supervision
# annotators (which look colors up per track elsewhere in this file).
COLOR_PALETTE = sv.ColorPalette.from_hex(
    [
        "#ffff00",
        "#ff9b00",
        "#ff8080",
        "#ff66b2",
        "#ff66ff",
        "#b266ff",
        "#9999ff",
        "#3399ff",
        "#66ffff",
        "#33ff99",
        "#66ff66",
        "#99ff00",
    ]
)

# Output directory, created eagerly so later writes cannot fail on a missing
# path. NOTE(review): no use of RESULTS_DIR is visible in this chunk — verify
# it is referenced elsewhere in the file.
RESULTS_DIR = "results"
os.makedirs(RESULTS_DIR, exist_ok=True)
77
+
78
+
79
def _init_annotators(
    show_boxes: bool = False,
    show_masks: bool = False,
    show_labels: bool = False,
    show_ids: bool = False,
    show_confidence: bool = False,
) -> tuple[list, sv.LabelAnnotator | None]:
    """Build the supervision annotators requested by the display flags.

    Returns the list of frame annotators (masks first, then boxes — the
    order they are drawn in) and an optional label annotator that is only
    created when any text overlay (labels, IDs, or confidence) is enabled.
    """
    frame_annotators: list = []

    # Masks go first so boxes are drawn on top of them.
    if show_masks:
        frame_annotators.append(
            sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)
        )
    if show_boxes:
        frame_annotators.append(
            sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)
        )

    # One label annotator serves all three text options.
    needs_text = show_labels or show_ids or show_confidence
    label_annotator = (
        sv.LabelAnnotator(
            color=COLOR_PALETTE,
            text_color=sv.Color.BLACK,
            text_position=sv.Position.TOP_LEFT,
            color_lookup=sv.ColorLookup.TRACK,
        )
        if needs_text
        else None
    )

    return frame_annotators, label_annotator
115
+
116
+
117
+ def _format_labels(
118
+ detections: sv.Detections,
119
+ class_names: list[str],
120
+ *,
121
+ show_ids: bool = False,
122
+ show_labels: bool = False,
123
+ show_confidence: bool = False,
124
+ ) -> list[str]:
125
+ """Generate label strings for each detection."""
126
+ labels = []
127
+
128
+ for i in range(len(detections)):
129
+ parts = []
130
+
131
+ if show_ids and detections.tracker_id is not None:
132
+ parts.append(f"#{int(detections.tracker_id[i])}")
133
+
134
+ if show_labels and detections.class_id is not None:
135
+ class_id = int(detections.class_id[i])
136
+ if class_names and 0 <= class_id < len(class_names):
137
+ parts.append(class_names[class_id])
138
+ else:
139
+ parts.append(str(class_id))
140
+
141
+ if show_confidence and detections.confidence is not None:
142
+ parts.append(f"{detections.confidence[i]:.2f}")
143
+
144
+ labels.append(" ".join(parts))
145
+
146
+ return labels
147
+
148
+
149
  VIDEO_EXAMPLES = [
150
  [
151
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
 
258
  ]
259
 
260
 
261
def _get_video_info(path: str) -> tuple[float, int]:
    """Return (duration_seconds, frame_count) for the video at *path*.

    Raises gr.Error when the file cannot be opened or when OpenCV reports a
    non-positive frame rate (duration would be undefined).
    """
    capture = cv2.VideoCapture(path)
    if not capture.isOpened():
        raise gr.Error("Could not open the uploaded video.")
    frames_per_second = capture.get(cv2.CAP_PROP_FPS)
    total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    if frames_per_second <= 0:
        raise gr.Error("Could not determine video frame rate.")
    return total_frames / frames_per_second, total_frames
272
+
273
+
274
+ def _resolve_class_filter(
275
+ classes: list[str] | None,
276
+ class_names: list[str],
277
+ ) -> list[int] | None:
278
+ """Resolve class names to integer IDs."""
279
+ if not classes:
280
+ return None
281
+
282
+ name_to_id = {name: i for i, name in enumerate(class_names)}
283
+ class_filter: list[int] = []
284
+ for name in classes:
285
+ if name in name_to_id:
286
+ class_filter.append(name_to_id[name])
287
+ return class_filter if class_filter else None
288
 
289
 
290
  def track(
291
  video_path: str,
292
+ model_id: str,
293
+ tracker_type: str,
294
  confidence: float,
295
  lost_track_buffer: int,
296
  track_activation_threshold: float,
 
304
  show_confidence: bool = False,
305
  show_trajectories: bool = False,
306
  show_masks: bool = False,
307
+ progress=gr.Progress(track_tqdm=True),
308
  ) -> str:
309
  """Run tracking on the uploaded video and return the output path."""
310
  if video_path is None:
311
  raise gr.Error("Please upload a video.")
312
 
313
+ duration, total_frames = _get_video_info(video_path)
314
  if duration > MAX_DURATION_SECONDS:
315
  raise gr.Error(
316
  f"Video is {duration:.1f}s long. "
317
  f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
318
  )
319
 
320
+ # Get pre-loaded model
321
+ detection_model = LOADED_MODELS[model_id]
322
+ class_names = getattr(detection_model, "class_names", [])
323
+
324
+ # Resolve class filter
325
+ class_filter = _resolve_class_filter(classes, class_names)
326
+
327
+ # Create tracker instance and reset ID counter
328
+ if tracker_type == "bytetrack":
329
+ tracker = ByteTrackTracker(
330
+ lost_track_buffer=lost_track_buffer,
331
+ track_activation_threshold=track_activation_threshold,
332
+ minimum_consecutive_frames=minimum_consecutive_frames,
333
+ minimum_iou_threshold=minimum_iou_threshold,
334
+ high_conf_det_threshold=high_conf_det_threshold,
335
+ )
336
+ else:
337
+ tracker = SORTTracker(
338
+ lost_track_buffer=lost_track_buffer,
339
+ track_activation_threshold=track_activation_threshold,
340
+ minimum_consecutive_frames=minimum_consecutive_frames,
341
+ minimum_iou_threshold=minimum_iou_threshold,
342
+ )
343
+ tracker.reset()
344
+
345
+ # Setup annotators
346
+ annotators, label_annotator = _init_annotators(
347
+ show_boxes=show_boxes,
348
+ show_masks=show_masks,
349
+ show_labels=show_labels,
350
+ show_ids=show_ids,
351
+ show_confidence=show_confidence,
352
+ )
353
+ trace_annotator = None
354
+ if show_trajectories:
355
+ trace_annotator = sv.TraceAnnotator(
356
+ color=COLOR_PALETTE,
357
+ color_lookup=sv.ColorLookup.TRACK,
358
+ )
359
+
360
+ # Setup output
361
  tmp_dir = tempfile.mkdtemp()
362
  output_path = str(Path(tmp_dir) / "output.mp4")
363
 
364
+ # Get video info for output
365
+ video_info = sv.VideoInfo.from_video_path(video_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
+ # Process video with progress bar
368
+ frame_gen = frames_from_source(video_path)
 
369
 
370
+ with sv.VideoSink(output_path, video_info=video_info) as sink:
371
+ for frame_idx, frame in tqdm(frame_gen, total=total_frames, desc="Processing video..."):
372
+ # Run detection
373
+ predictions = detection_model(frame)
374
+ if predictions:
375
+ detections = predictions[0].to_supervision()
376
 
377
+ # Filter by confidence
378
+ if len(detections) > 0 and detections.confidence is not None:
379
+ mask = detections.confidence >= confidence
380
+ detections = detections[mask]
381
+
382
+ # Filter by class
383
+ if class_filter is not None and len(detections) > 0:
384
+ mask = np.isin(detections.class_id, class_filter)
385
+ detections = detections[mask]
386
+ else:
387
+ detections = sv.Detections.empty()
388
+
389
+ # Run tracker
390
+ tracked = tracker.update(detections)
391
+
392
+ # Annotate frame
393
+ annotated = frame.copy()
394
+ if trace_annotator is not None:
395
+ annotated = trace_annotator.annotate(annotated, tracked)
396
+ for annotator in annotators:
397
+ annotated = annotator.annotate(annotated, tracked)
398
+ if label_annotator is not None:
399
+ labeled = tracked[tracked.tracker_id != -1]
400
+ labels = _format_labels(
401
+ labeled,
402
+ class_names,
403
+ show_ids=show_ids,
404
+ show_labels=show_labels,
405
+ show_confidence=show_confidence,
406
+ )
407
+ annotated = label_annotator.annotate(annotated, labeled, labels=labels)
408
 
409
+ sink.write_frame(annotated)
 
 
410
 
411
  return output_path
412
 
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  gradio>=6.3.0,<6.4.0
2
- trackers[detection]@git+https://github.com/roboflow/trackers.git
 
 
1
  gradio>=6.3.0,<6.4.0
2
+ inference-models==0.18.6rc14
3
+ trackers==2.2.0rc1