# app.py — "Trackers" Hugging Face Space (author: SkalskiP).
# Change 04ddb03: pre-load RF-DETR models at startup and add a progress bar.
"""Gradio app for the trackers library — run object tracking on uploaded videos."""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from tqdm import tqdm
from inference_models import AutoModel
from trackers import ByteTrackTracker, SORTTracker, frames_from_source
# Hard cap on input clip length; longer uploads are rejected in track().
MAX_DURATION_SECONDS = 30

# RF-DETR checkpoints selectable in the UI. The "-seg-" variants also
# produce segmentation masks, used by the "Show Masks" option.
MODELS = [
    "rfdetr-nano",
    "rfdetr-small",
    "rfdetr-medium",
    "rfdetr-large",
    "rfdetr-seg-nano",
    "rfdetr-seg-small",
    "rfdetr-seg-medium",
    "rfdetr-seg-large",
]

# Tracker backends offered in the UI; see track() for how each is built.
TRACKERS = ["bytetrack", "sort"]

# Class names exposed in the "Filter Classes" checkbox group. These are only
# UI choices: selected names are resolved to integer IDs against the model's
# own class_names list at request time (see _resolve_class_filter).
COCO_CLASSES = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "truck",
    "cat",
    "dog",
    "sports ball",
]

# Device selection and eager model pre-loading: every model is loaded once at
# startup so individual requests do not pay the checkpoint-loading cost.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading {len(MODELS)} models on {DEVICE}...")
LOADED_MODELS = {}
for model_id in MODELS:
    print(f" Loading {model_id}...")
    LOADED_MODELS[model_id] = AutoModel.from_pretrained(model_id, device=DEVICE)
print("All models loaded.")

# Visualization: a 12-color palette cycled by track ID (ColorLookup.TRACK),
# so each tracked object keeps a stable color across frames.
COLOR_PALETTE = sv.ColorPalette.from_hex(
    [
        "#ffff00",
        "#ff9b00",
        "#ff8080",
        "#ff66b2",
        "#ff66ff",
        "#b266ff",
        "#9999ff",
        "#3399ff",
        "#66ffff",
        "#33ff99",
        "#66ff66",
        "#99ff00",
    ]
)

# NOTE(review): RESULTS_DIR is created but never referenced elsewhere in this
# file — possibly used implicitly by example caching, or leftover. Confirm.
RESULTS_DIR = "results"
os.makedirs(RESULTS_DIR, exist_ok=True)
def _init_annotators(
    show_boxes: bool = False,
    show_masks: bool = False,
    show_labels: bool = False,
    show_ids: bool = False,
    show_confidence: bool = False,
) -> tuple[list, sv.LabelAnnotator | None]:
    """Build the supervision annotators requested by the display toggles.

    Returns a list of frame annotators (masks first, so boxes are drawn on
    top of them) and an optional label annotator that is created whenever
    any text overlay (labels, IDs, or confidence) is enabled.
    """
    frame_annotators: list = []
    if show_masks:
        frame_annotators.append(
            sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)
        )
    if show_boxes:
        frame_annotators.append(
            sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)
        )
    wants_text = show_labels or show_ids or show_confidence
    text_annotator = (
        sv.LabelAnnotator(
            color=COLOR_PALETTE,
            text_color=sv.Color.BLACK,
            text_position=sv.Position.TOP_LEFT,
            color_lookup=sv.ColorLookup.TRACK,
        )
        if wants_text
        else None
    )
    return frame_annotators, text_annotator
def _format_labels(
detections: sv.Detections,
class_names: list[str],
*,
show_ids: bool = False,
show_labels: bool = False,
show_confidence: bool = False,
) -> list[str]:
"""Generate label strings for each detection."""
labels = []
for i in range(len(detections)):
parts = []
if show_ids and detections.tracker_id is not None:
parts.append(f"#{int(detections.tracker_id[i])}")
if show_labels and detections.class_id is not None:
class_id = int(detections.class_id[i])
if class_names and 0 <= class_id < len(class_names):
parts.append(class_names[class_id])
else:
parts.append(str(class_id))
if show_confidence and detections.confidence is not None:
parts.append(f"{detections.confidence[i]:.2f}")
labels.append(" ".join(parts))
return labels
# Pre-filled rows for gr.Examples. Column order must match the `inputs` list
# wired to gr.Examples below: [video_url, model, tracker, confidence,
# lost_track_buffer, track_activation_threshold, minimum_consecutive_frames,
# minimum_iou_threshold, high_conf_det_threshold, class_filter, show_boxes,
# show_ids, show_labels, show_confidence, show_trajectories, show_masks].
VIDEO_EXAMPLES = [
    [
        "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
        "rfdetr-small",
        "bytetrack",
        0.2,
        30,
        0.3,
        3,
        0.1,
        0.6,
        [],
        True,
        True,
        False,
        False,
        True,
        False,
    ],
    [
        "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-2.mp4",
        "rfdetr-seg-small",
        "sort",
        0.2,
        30,
        0.3,
        3,
        0.3,
        0.6,
        [],
        True,
        True,
        False,
        False,
        True,
        True,
    ],
    [
        "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/cars-1280x720-1.mp4",
        "rfdetr-small",
        "bytetrack",
        0.2,
        30,
        0.3,
        3,
        0.1,
        0.6,
        ["car"],
        True,
        True,
        False,
        True,
        False,
        False,
    ],
    [
        "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/jets-1280x720-1.mp4",
        "rfdetr-small",
        "bytetrack",
        0.2,
        30,
        0.3,
        3,
        0.1,
        0.6,
        [],
        True,
        True,
        False,
        False,
        False,
        False,
    ],
    [
        "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/jets-1280x720-2.mp4",
        "rfdetr-seg-small",
        "bytetrack",
        0.2,
        30,
        0.3,
        3,
        0.1,
        0.6,
        [],
        True,
        True,
        False,
        False,
        True,
        False,
    ],
    [
        "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/vehicles-1280x720.mp4",
        "rfdetr-small",
        "bytetrack",
        0.2,
        30,
        0.3,
        3,
        0.1,
        0.6,
        [],
        True,
        True,
        True,
        False,
        True,
        False,
    ],
]
def _get_video_info(path: str) -> tuple[float, int]:
    """Return video duration in seconds and frame count using OpenCV.

    Args:
        path: Filesystem path of the video to inspect.

    Returns:
        Tuple of (duration in seconds, total frame count).

    Raises:
        gr.Error: If the file cannot be opened or its frame rate
            cannot be determined.
    """
    cap = cv2.VideoCapture(path)
    # try/finally so the capture handle is released on every path —
    # the original leaked it when isOpened() failed and raised.
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    if fps <= 0:
        raise gr.Error("Could not determine video frame rate.")
    return frame_count / fps, frame_count
def _resolve_class_filter(
classes: list[str] | None,
class_names: list[str],
) -> list[int] | None:
"""Resolve class names to integer IDs."""
if not classes:
return None
name_to_id = {name: i for i, name in enumerate(class_names)}
class_filter: list[int] = []
for name in classes:
if name in name_to_id:
class_filter.append(name_to_id[name])
return class_filter if class_filter else None
def track(
    video_path: str,
    model_id: str,
    tracker_type: str,
    confidence: float,
    lost_track_buffer: int,
    track_activation_threshold: float,
    minimum_consecutive_frames: int,
    minimum_iou_threshold: float,
    high_conf_det_threshold: float,
    classes: list[str] | None = None,
    show_boxes: bool = True,
    show_ids: bool = True,
    show_labels: bool = False,
    show_confidence: bool = False,
    show_trajectories: bool = False,
    show_masks: bool = False,
    progress=gr.Progress(track_tqdm=True),
) -> str:
    """Run tracking on the uploaded video and return the output path.

    Detects objects on every frame with a pre-loaded RF-DETR model, feeds
    the (optionally filtered) detections through the selected tracker, draws
    the requested overlays, and writes the result to a temporary mp4 file.

    Args:
        video_path: Path of the uploaded video; ``None`` raises a user error.
        model_id: Key into ``LOADED_MODELS`` selecting the RF-DETR variant.
        tracker_type: ``"bytetrack"`` selects ByteTrack; any other value SORT.
        confidence: Minimum score for a detection to be kept.
        lost_track_buffer: Frames a lost track is kept before removal.
        track_activation_threshold: Minimum score to activate a track.
        minimum_consecutive_frames: Detections needed to confirm a track.
        minimum_iou_threshold: IoU required to match a detection to a track.
        high_conf_det_threshold: ByteTrack-only first-pass match threshold.
        classes: Class names to keep; empty/None keeps all classes.
        show_boxes, show_ids, show_labels, show_confidence,
        show_trajectories, show_masks: Visualization toggles.
        progress: Gradio progress hook; ``track_tqdm=True`` mirrors the
            tqdm bar below into the web UI.

    Returns:
        Filesystem path of the annotated output video.

    Raises:
        gr.Error: If no video was uploaded or it exceeds
            ``MAX_DURATION_SECONDS``.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")
    duration, total_frames = _get_video_info(video_path)
    if duration > MAX_DURATION_SECONDS:
        raise gr.Error(
            f"Video is {duration:.1f}s long. "
            f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
        )
    # Get pre-loaded model (loaded once at startup; no per-request load cost).
    detection_model = LOADED_MODELS[model_id]
    class_names = getattr(detection_model, "class_names", [])
    # Resolve selected class names to integer IDs (None disables filtering).
    class_filter = _resolve_class_filter(classes, class_names)
    # Create a fresh tracker per request and reset it so track IDs start
    # from scratch for every video.
    if tracker_type == "bytetrack":
        tracker = ByteTrackTracker(
            lost_track_buffer=lost_track_buffer,
            track_activation_threshold=track_activation_threshold,
            minimum_consecutive_frames=minimum_consecutive_frames,
            minimum_iou_threshold=minimum_iou_threshold,
            high_conf_det_threshold=high_conf_det_threshold,
        )
    else:
        # SORT does not accept the high-confidence threshold parameter.
        tracker = SORTTracker(
            lost_track_buffer=lost_track_buffer,
            track_activation_threshold=track_activation_threshold,
            minimum_consecutive_frames=minimum_consecutive_frames,
            minimum_iou_threshold=minimum_iou_threshold,
        )
    tracker.reset()
    # Setup annotators according to the visualization toggles.
    annotators, label_annotator = _init_annotators(
        show_boxes=show_boxes,
        show_masks=show_masks,
        show_labels=show_labels,
        show_ids=show_ids,
        show_confidence=show_confidence,
    )
    trace_annotator = None
    if show_trajectories:
        trace_annotator = sv.TraceAnnotator(
            color=COLOR_PALETTE,
            color_lookup=sv.ColorLookup.TRACK,
        )
    # Setup output in a fresh temp dir so concurrent requests don't collide.
    tmp_dir = tempfile.mkdtemp()
    output_path = str(Path(tmp_dir) / "output.mp4")
    # Get video info (resolution/fps) so the output matches the input.
    video_info = sv.VideoInfo.from_video_path(video_path)
    # Process video with a progress bar (mirrored into the Gradio UI).
    # NOTE(review): assumes frames_from_source yields (index, frame) pairs;
    # frame_idx is unused below — confirm against the trackers API.
    frame_gen = frames_from_source(video_path)
    with sv.VideoSink(output_path, video_info=video_info) as sink:
        for frame_idx, frame in tqdm(frame_gen, total=total_frames, desc="Processing video..."):
            # Run detection
            predictions = detection_model(frame)
            if predictions:
                detections = predictions[0].to_supervision()
                # Filter by confidence
                if len(detections) > 0 and detections.confidence is not None:
                    mask = detections.confidence >= confidence
                    detections = detections[mask]
                # Filter by class
                if class_filter is not None and len(detections) > 0:
                    mask = np.isin(detections.class_id, class_filter)
                    detections = detections[mask]
            else:
                # No predictions: still step the tracker with empty detections
                # so lost-track bookkeeping advances.
                detections = sv.Detections.empty()
            # Run tracker
            tracked = tracker.update(detections)
            # Annotate frame; traces first so boxes/masks draw on top of them.
            annotated = frame.copy()
            if trace_annotator is not None:
                annotated = trace_annotator.annotate(annotated, tracked)
            for annotator in annotators:
                annotated = annotator.annotate(annotated, tracked)
            if label_annotator is not None:
                # Only label assigned tracks (tracker_id == -1 means no track).
                labeled = tracked[tracked.tracker_id != -1]
                labels = _format_labels(
                    labeled,
                    class_names,
                    show_ids=show_ids,
                    show_labels=show_labels,
                    show_confidence=show_confidence,
                )
                annotated = label_annotator.annotate(annotated, labeled, labels=labels)
            sink.write_frame(annotated)
    return output_path
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Trackers Playground 🔥") as demo:
    gr.Markdown(
        "# Trackers Playground 🔥\n\n"
        "Upload a video, detect COCO objects with "
        "[RF-DETR](https://github.com/roboflow-ai/rf-detr) and track them with "
        "[Trackers](https://github.com/roboflow/trackers)."
    )
    # Input / output videos side by side.
    with gr.Row():
        input_video = gr.Video(label="Input Video")
        output_video = gr.Video(label="Tracked Video")
    track_btn = gr.Button(value="Track", variant="primary")
    # Model and tracker selection.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODELS,
            value="rfdetr-small",
            label="Detection Model",
        )
        tracker_dropdown = gr.Dropdown(
            choices=TRACKERS,
            value="bytetrack",
            label="Tracker",
        )
    # Advanced configuration, collapsed by default.
    with gr.Accordion("Configuration", open=False):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Model")
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.05,
                    label="Detection Confidence",
                    info="Minimum score for a detection to be kept.",
                )
                class_filter = gr.CheckboxGroup(
                    choices=COCO_CLASSES,
                    value=[],
                    label="Filter Classes",
                    info="Only track selected classes. None selected means all.",
                )
            with gr.Column():
                gr.Markdown("### Tracker")
                lost_track_buffer_slider = gr.Slider(
                    minimum=1,
                    maximum=120,
                    value=30,
                    step=1,
                    label="Lost Track Buffer",
                    info="Frames to keep a lost track before removing it.",
                )
                track_activation_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.05,
                    label="Track Activation Threshold",
                    info="Minimum score for a track to be activated.",
                )
                # NOTE(review): this slider defaults to 2 while every row in
                # VIDEO_EXAMPLES uses 3 — confirm which default is intended.
                min_consecutive_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Minimum Consecutive Frames",
                    info="Detections needed before a track is confirmed.",
                )
                min_iou_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.05,
                    label="Minimum IoU Threshold",
                    info="Overlap required to match a detection to a track.",
                )
                high_conf_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.6,
                    step=0.05,
                    label="High Confidence Detection Threshold",
                    info="Detections above this are matched first (ByteTrack only).",
                )
            with gr.Column():
                gr.Markdown("### Visualization")
                show_boxes_checkbox = gr.Checkbox(
                    value=True,
                    label="Show Boxes",
                    info="Draw bounding boxes around detections.",
                )
                show_ids_checkbox = gr.Checkbox(
                    value=True,
                    label="Show IDs",
                    info="Display track ID for each object.",
                )
                show_labels_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Labels",
                    info="Display class name for each detection.",
                )
                show_confidence_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Confidence",
                    info="Display detection confidence score.",
                )
                show_trajectories_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Trajectories",
                    info="Draw motion path for each tracked object.",
                )
                show_masks_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Masks",
                    info="Draw segmentation masks (seg models only).",
                )
    # Pre-computed examples; the inputs list order must match the columns
    # of each VIDEO_EXAMPLES row. cache_examples=True runs track() for every
    # row at startup and serves the cached outputs afterwards.
    gr.Examples(
        examples=VIDEO_EXAMPLES,
        fn=track,
        cache_examples=True,
        inputs=[
            input_video,
            model_dropdown,
            tracker_dropdown,
            confidence_slider,
            lost_track_buffer_slider,
            track_activation_slider,
            min_consecutive_slider,
            min_iou_slider,
            high_conf_slider,
            class_filter,
            show_boxes_checkbox,
            show_ids_checkbox,
            show_labels_checkbox,
            show_confidence_checkbox,
            show_trajectories_checkbox,
            show_masks_checkbox,
        ],
        outputs=output_video,
    )
    # Wire the Track button to the same function/inputs as the examples.
    track_btn.click(
        fn=track,
        inputs=[
            input_video,
            model_dropdown,
            tracker_dropdown,
            confidence_slider,
            lost_track_buffer_slider,
            track_activation_slider,
            min_consecutive_slider,
            min_iou_slider,
            high_conf_slider,
            class_filter,
            show_boxes_checkbox,
            show_ids_checkbox,
            show_labels_checkbox,
            show_confidence_checkbox,
            show_trajectories_checkbox,
            show_masks_checkbox,
        ],
        outputs=output_video,
    )

if __name__ == "__main__":
    demo.launch()