| from __future__ import annotations |
|
|
| import os |
| import sys |
| import tempfile |
| from pathlib import Path |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import supervision as sv |
| import torch |
| from tqdm import tqdm |
| from inference_models import AutoModel |
|
|
| from trackers import ByteTrackTracker, OCSORTTracker, SORTTracker, frames_from_source |
|
|
# Longest clip, in seconds, that `track` accepts before raising an error.
MAX_DURATION_SECONDS = 30

# Checkpoint identifiers offered in the model dropdown: the detection family
# first, then the segmentation family, each ordered from nano to large.
MODELS = [
    f"{family}-{size}"
    for family in ("rfdetr", "rfdetr-seg")
    for size in ("nano", "small", "medium", "large")
]

# Keys offered in the tracker dropdown and compared against in `track`.
TRACKERS = ["bytetrack", "sort", "ocsort"]

# Class names offered by the "Filter Classes" checkbox group.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "truck", "cat", "dog", "sports ball",
]

# Device string handed to the model loader: GPU when CUDA is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
def _load_checkpoint(checkpoint):
    """Announce and load one pretrained detection model onto `DEVICE`."""
    print(f" Loading {checkpoint}...")
    return AutoModel.from_pretrained(checkpoint, device=DEVICE)


# Every checkpoint is loaded once at import time, in `MODELS` order, so that
# `track` only has to look the model up by its identifier.
print(f"Loading {len(MODELS)} models on {DEVICE}...")
LOADED_MODELS = {checkpoint: _load_checkpoint(checkpoint) for checkpoint in MODELS}
print("All models loaded.")
|
|
# Hex colours, in palette order, shared by every annotator in this file.
_PALETTE_HEX_CODES = [
    "#ffff00",
    "#ff9b00",
    "#ff8080",
    "#ff66b2",
    "#ff66ff",
    "#b266ff",
    "#9999ff",
    "#3399ff",
    "#66ffff",
    "#33ff99",
    "#66ff66",
    "#99ff00",
]
COLOR_PALETTE = sv.ColorPalette.from_hex(_PALETTE_HEX_CODES)

# Directory created at import time (no error if it already exists).
RESULTS_DIR = "results"
os.makedirs(RESULTS_DIR, exist_ok=True)
|
|
|
|
def _init_annotators(
    show_boxes: bool = False,
    show_masks: bool = False,
    show_labels: bool = False,
    show_ids: bool = False,
    show_confidence: bool = False,
) -> tuple[list, sv.LabelAnnotator | None]:
    """Build the supervision annotators requested by the display toggles.

    Args:
        show_boxes: Include a `sv.BoxAnnotator`.
        show_masks: Include a `sv.MaskAnnotator`.
        show_labels: Request text labels (class names).
        show_ids: Request text labels (track IDs).
        show_confidence: Request text labels (confidence scores).

    Returns:
        A pair `(annotators, label_annotator)`. `annotators` holds the mask
        annotator (if requested) followed by the box annotator (if
        requested). `label_annotator` is a `sv.LabelAnnotator` when any of
        the three text toggles is set, otherwise `None`.
    """
    # Every annotator colours by track, using the shared palette.
    track_coloring = {
        "color": COLOR_PALETTE,
        "color_lookup": sv.ColorLookup.TRACK,
    }

    # Masks come before boxes in the returned list.
    shape_annotators: list = [
        annotator_cls(**track_coloring)
        for wanted, annotator_cls in (
            (show_masks, sv.MaskAnnotator),
            (show_boxes, sv.BoxAnnotator),
        )
        if wanted
    ]

    if not (show_labels or show_ids or show_confidence):
        return shape_annotators, None

    text_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_position=sv.Position.TOP_LEFT,
        **track_coloring,
    )
    return shape_annotators, text_annotator
|
|
|
|
def _format_labels(
    detections: sv.Detections,
    class_names: list[str],
    *,
    show_ids: bool = False,
    show_labels: bool = False,
    show_confidence: bool = False,
) -> list[str]:
    """Build one label string per detection.

    Each label joins, with single spaces and in this order, the parts that
    are both requested and available: `#<track id>`, the class name (or the
    numeric class ID when it has no entry in `class_names`), and the
    confidence formatted with two decimals.

    Args:
        detections: Detections to describe.
        class_names: Class names indexed by class ID; may be empty.
        show_ids: Include the track ID when `tracker_id` is not `None`.
        show_labels: Include the class when `class_id` is not `None`.
        show_confidence: Include the score when `confidence` is not `None`.

    Returns:
        A list with `len(detections)` strings; a string is empty when no
        part applies.
    """

    def describe(index: int) -> str:
        pieces = []

        if show_ids and detections.tracker_id is not None:
            pieces.append(f"#{int(detections.tracker_id[index])}")

        if show_labels and detections.class_id is not None:
            raw_class = int(detections.class_id[index])
            # Fall back to the number when the ID has no known name.
            pieces.append(
                class_names[raw_class]
                if class_names and 0 <= raw_class < len(class_names)
                else str(raw_class)
            )

        if show_confidence and detections.confidence is not None:
            pieces.append(f"{detections.confidence[index]:.2f}")

        return " ".join(pieces)

    return [describe(index) for index in range(len(detections))]
|
|
|
|
# Common URL prefix of every example clip.
_EXAMPLE_VIDEO_ROOT = (
    "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/"
)


def _example(
    video,
    model,
    tracker,
    *,
    minimum_iou=0.1,
    classes=(),
    track_ids="",
    show_labels=False,
    show_trajectories=True,
    show_masks=False,
):
    """Build one example row, ordered like the positional arguments of `track`.

    Only the values that differ between examples are parameters; the rest
    are fixed here. A fresh `classes` list is created for every row.
    """
    return [
        _EXAMPLE_VIDEO_ROOT + video,
        model,
        tracker,
        0.2,  # confidence
        30,  # lost_track_buffer
        0.3,  # track_activation_threshold
        3,  # minimum_consecutive_frames
        minimum_iou,
        0.6,  # high_conf_det_threshold
        0.2,  # direction_consistency_weight
        3,  # delta_t
        list(classes),
        track_ids,
        True,  # show_boxes
        True,  # show_ids
        show_labels,
        False,  # show_confidence
        show_trajectories,
        show_masks,
    ]


# Rows passed to `gr.Examples`; each one is a full set of `track` inputs.
VIDEO_EXAMPLES = [
    _example("bikes-1280x720-1.mp4", "rfdetr-small", "ocsort"),
    _example("bikes-1280x720-1.mp4", "rfdetr-small", "ocsort", classes=["person"]),
    _example(
        "bikes-1280x720-2.mp4",
        "rfdetr-seg-small",
        "sort",
        minimum_iou=0.3,
        show_masks=True,
    ),
    _example("apples-1280x720-2.mp4", "rfdetr-nano", "sort", show_labels=True),
    _example(
        "jets-1280x720-1.mp4",
        "rfdetr-small",
        "bytetrack",
        show_trajectories=False,
    ),
    _example("jets-1280x720-2.mp4", "rfdetr-seg-small", "bytetrack", show_masks=True),
    _example(
        "jets-1280x720-2.mp4",
        "rfdetr-seg-small",
        "bytetrack",
        track_ids="1",
        show_masks=True,
    ),
    _example("suitcases-1280x720-4.mp4", "rfdetr-small", "sort", show_labels=True),
    _example("vehicles-1280x720.mp4", "rfdetr-medium", "ocsort", show_labels=True),
]
|
|
|
|
def _get_video_info(path: str) -> tuple[float, int]:
    """Return the video duration in seconds and its frame count using OpenCV.

    Args:
        path: Location of the video file to inspect.

    Returns:
        A pair `(duration_seconds, frame_count)`, where the duration is the
        reported frame count divided by the reported frame rate.

    Raises:
        gr.Error: If the video cannot be opened, or if the reported frame
            rate is zero or negative.
    """
    video_capture = cv2.VideoCapture(path)
    try:
        if not video_capture.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        frames_per_second = video_capture.get(cv2.CAP_PROP_FPS)
        frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the handle on every path, including the error ones.
        video_capture.release()

    if frames_per_second <= 0:
        raise gr.Error("Could not determine video frame rate.")
    return frame_count / frames_per_second, frame_count
|
|
|
|
def _resolve_class_filter(
    classes: list[str] | None,
    class_names: list[str],
) -> list[int] | None:
    """Translate selected class names into their integer class IDs.

    Args:
        classes: Class names chosen by the user; `None` or empty means no
            filter.
        class_names: Known class names, indexed by class ID.

    Returns:
        The IDs of the selected names found in `class_names`, in selection
        order, or `None` when nothing was selected or nothing matched.
    """
    if not classes:
        return None

    index_by_name = {label: position for position, label in enumerate(class_names)}
    # Names missing from `class_names` are dropped silently.
    matched = [index_by_name[label] for label in classes if label in index_by_name]
    return matched or None
|
|
|
|
def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None:
    """Parse a comma-separated list of track IDs into integers.

    Tokens are stripped of surrounding whitespace. A token that is not a
    valid integer is skipped, and a warning naming it is printed to stderr.

    Args:
        track_ids_arg: Comma-separated string (e.g. `"1,3,5"`). `None` or
            empty string means no filter.

    Returns:
        List of integer track IDs, or `None` when no valid filter remains.
    """
    if not track_ids_arg:
        return None

    parsed: list[int] = []
    for raw_token in track_ids_arg.split(","):
        candidate = raw_token.strip()
        try:
            value = int(candidate)
        except ValueError:
            print(
                f"Warning: '{candidate}' is not a valid track ID, skipping.",
                file=sys.stderr,
            )
            continue
        parsed.append(value)
    return parsed or None
|
|
|
|
def _build_tracker(
    tracker_type: str,
    *,
    lost_track_buffer: int,
    track_activation_threshold: float,
    minimum_consecutive_frames: int,
    minimum_iou_threshold: float,
    high_conf_det_threshold: float,
    direction_consistency_weight: float,
    delta_t: int,
):
    """Instantiate the selected tracker, passing only the options it takes."""
    if tracker_type == "bytetrack":
        return ByteTrackTracker(
            lost_track_buffer=lost_track_buffer,
            track_activation_threshold=track_activation_threshold,
            minimum_consecutive_frames=minimum_consecutive_frames,
            minimum_iou_threshold=minimum_iou_threshold,
            high_conf_det_threshold=high_conf_det_threshold,
        )
    if tracker_type == "ocsort":
        return OCSORTTracker(
            lost_track_buffer=lost_track_buffer,
            minimum_consecutive_frames=minimum_consecutive_frames,
            minimum_iou_threshold=minimum_iou_threshold,
            high_conf_det_threshold=high_conf_det_threshold,
            direction_consistency_weight=direction_consistency_weight,
            delta_t=delta_t,
        )
    # Any other value selects SORT.
    return SORTTracker(
        lost_track_buffer=lost_track_buffer,
        track_activation_threshold=track_activation_threshold,
        minimum_consecutive_frames=minimum_consecutive_frames,
        minimum_iou_threshold=minimum_iou_threshold,
    )


def _detect_frame(detection_model, frame, confidence, selected_class_ids):
    """Run the detector on one frame, then apply confidence and class filters."""
    predictions = detection_model(frame)
    if not predictions:
        return sv.Detections.empty()

    detections = predictions[0].to_supervision()

    if len(detections) > 0 and detections.confidence is not None:
        detections = detections[detections.confidence >= confidence]

    if selected_class_ids is not None and len(detections) > 0:
        detections = detections[np.isin(detections.class_id, selected_class_ids)]

    return detections


def track(
    video_path: str,
    model_id: str,
    tracker_type: str,
    confidence: float,
    lost_track_buffer: int,
    track_activation_threshold: float,
    minimum_consecutive_frames: int,
    minimum_iou_threshold: float,
    high_conf_det_threshold: float,
    direction_consistency_weight: float,
    delta_t: int,
    classes: list[str] | None = None,
    track_ids: str = "",
    show_boxes: bool = True,
    show_ids: bool = True,
    show_labels: bool = False,
    show_confidence: bool = False,
    show_trajectories: bool = False,
    show_masks: bool = False,
    progress=gr.Progress(track_tqdm=True),
) -> str:
    """Run detection and tracking on a video and return the annotated copy.

    Every frame is passed through the selected detector, filtered by
    confidence and class, fed to the tracker, optionally filtered by track
    ID, annotated, and written to a new video file.

    Args:
        video_path: Path of the input video; `None` is rejected.
        model_id: Key of the detector in `LOADED_MODELS`.
        tracker_type: `"bytetrack"` or `"ocsort"`; any other value selects
            SORT.
        confidence: Detections scoring below this value are dropped before
            tracking.
        lost_track_buffer: Forwarded to all three trackers.
        track_activation_threshold: Forwarded to ByteTrack and SORT.
        minimum_consecutive_frames: Forwarded to all three trackers.
        minimum_iou_threshold: Forwarded to all three trackers.
        high_conf_det_threshold: Forwarded to ByteTrack and OC-SORT.
        direction_consistency_weight: Forwarded to OC-SORT only.
        delta_t: Forwarded to OC-SORT only.
        classes: Class names to keep, resolved against the model's
            `class_names`; `None` or empty keeps every class.
        track_ids: Comma-separated track IDs to display; empty shows all.
        show_boxes: Draw bounding boxes.
        show_ids: Put the track ID in the label.
        show_labels: Put the class name in the label.
        show_confidence: Put the confidence score in the label.
        show_trajectories: Draw a trace for each track.
        show_masks: Draw segmentation masks.
        progress: Gradio progress tracker. It is not referenced in the
            body; it is declared with `track_tqdm=True` so Gradio can follow
            the `tqdm` loop.

    Returns:
        Path of `output.mp4` inside a newly created temporary directory.

    Raises:
        gr.Error: If no video was supplied, if the video is longer than
            `MAX_DURATION_SECONDS`, or if `_get_video_info` rejects it.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")

    duration, total_frames = _get_video_info(video_path)
    if duration > MAX_DURATION_SECONDS:
        raise gr.Error(
            f"Video is {duration:.1f}s long. "
            f"Maximum allowed duration is {MAX_DURATION_SECONDS}s. "
            f"Please use the trim tool in the Input Video player to shorten it."
        )

    detection_model = LOADED_MODELS[model_id]
    class_names = getattr(detection_model, "class_names", [])

    selected_class_ids = _resolve_class_filter(classes, class_names)
    selected_track_ids = _resolve_track_id_filter(track_ids)

    tracker = _build_tracker(
        tracker_type,
        lost_track_buffer=lost_track_buffer,
        track_activation_threshold=track_activation_threshold,
        minimum_consecutive_frames=minimum_consecutive_frames,
        minimum_iou_threshold=minimum_iou_threshold,
        high_conf_det_threshold=high_conf_det_threshold,
        direction_consistency_weight=direction_consistency_weight,
        delta_t=delta_t,
    )
    tracker.reset()

    annotators, label_annotator = _init_annotators(
        show_boxes=show_boxes,
        show_masks=show_masks,
        show_labels=show_labels,
        show_ids=show_ids,
        show_confidence=show_confidence,
    )
    trace_annotator = None
    if show_trajectories:
        trace_annotator = sv.TraceAnnotator(
            color=COLOR_PALETTE,
            color_lookup=sv.ColorLookup.TRACK,
        )

    # The output lives in its own temporary directory. It is not removed
    # here because the returned path must stay readable for the caller.
    temporary_directory = tempfile.mkdtemp()
    output_path = str(Path(temporary_directory) / "output.mp4")

    video_info = sv.VideoInfo.from_video_path(video_path)
    frame_generator = frames_from_source(video_path)

    with sv.VideoSink(output_path, video_info=video_info) as sink:
        for _frame_idx, frame in tqdm(
            frame_generator, total=total_frames, desc="Processing video..."
        ):
            detections = _detect_frame(
                detection_model, frame, confidence, selected_class_ids
            )
            tracked = tracker.update(detections)

            # The ID filter runs after the tracker update, so the tracker
            # still sees every detection and only the display is narrowed.
            if (
                selected_track_ids is not None
                and len(tracked) > 0
                and tracked.tracker_id is not None
            ):
                tracked = tracked[np.isin(tracked.tracker_id, selected_track_ids)]

            annotated = frame.copy()
            if trace_annotator is not None:
                annotated = trace_annotator.annotate(annotated, tracked)
            for annotator in annotators:
                annotated = annotator.annotate(annotated, tracked)

            # Same None guard as the ID filter above: without it,
            # `None != -1` is a plain `True` and `tracked[True]` would be
            # used as an index instead of a per-detection mask.
            if label_annotator is not None and tracked.tracker_id is not None:
                # Entries with tracker_id == -1 get no label.
                labeled = tracked[tracked.tracker_id != -1]
                labels = _format_labels(
                    labeled,
                    class_names,
                    show_ids=show_ids,
                    show_labels=show_labels,
                    show_confidence=show_confidence,
                )
                annotated = label_annotator.annotate(annotated, labeled, labels=labels)

            sink.write_frame(annotated)

    return output_path
|
|
|
|
def _unit_slider(value, label, info):
    """Create a slider over 0.0-1.0 in steps of 0.05."""
    return gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=value,
        step=0.05,
        label=label,
        info=info,
    )


def _count_slider(maximum, value, label, info):
    """Create a whole-number slider that starts at 1."""
    return gr.Slider(
        minimum=1,
        maximum=maximum,
        value=value,
        step=1,
        label=label,
        info=info,
    )


def _toggle(value, label, info):
    """Create a checkbox with the given initial state."""
    return gr.Checkbox(value=value, label=label, info=info)


# Page layout: videos, the Track button, model/tracker pickers, a collapsed
# configuration panel, then the example gallery. Components are created in
# the same order as they appear on the page.
with gr.Blocks(title="Trackers Playground 🔥") as demo:
    gr.Markdown(
        "# Trackers Playground 🔥\n\n"
        "Upload a video, detect COCO objects with "
        "[RF-DETR](https://github.com/roboflow-ai/rf-detr) and track them with "
        "[Trackers](https://github.com/roboflow/trackers)."
    )

    with gr.Row():
        input_video = gr.Video(label="Input Video")
        output_video = gr.Video(label="Tracked Video")

    track_button = gr.Button(value="Track", variant="primary")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODELS,
            value="rfdetr-small",
            label="Detection Model",
        )
        tracker_dropdown = gr.Dropdown(
            choices=TRACKERS,
            value="bytetrack",
            label="Tracker",
        )

    with gr.Accordion("Configuration", open=False):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Model")
                confidence_slider = _unit_slider(
                    0.2,
                    "Detection Confidence",
                    "Minimum score for a detection to be kept.",
                )
                class_filter = gr.CheckboxGroup(
                    choices=COCO_CLASSES,
                    value=[],
                    label="Filter Classes",
                    info="Only track selected classes. None selected means all.",
                )
                track_id_filter = gr.Textbox(
                    value="",
                    label="Filter IDs",
                    info=(
                        "Only display tracks with specific track IDs "
                        "(comma-separated, e.g. 1,3,5). "
                        "Leave empty for all."
                    ),
                    placeholder="e.g. 1,3,5",
                )

            with gr.Column():
                gr.Markdown("### Tracker")
                lost_track_buffer_slider = _count_slider(
                    120,
                    30,
                    "Lost Track Buffer",
                    "Frames to keep a lost track before removing it (ByteTrack, SORT, OC-SORT).",
                )
                track_activation_slider = _unit_slider(
                    0.3,
                    "Track Activation Threshold",
                    "Minimum score for a track to be activated (ByteTrack, SORT).",
                )
                minimum_consecutive_slider = _count_slider(
                    10,
                    2,
                    "Minimum Consecutive Frames",
                    "Detections needed before a track is confirmed (ByteTrack, SORT, OC-SORT).",
                )
                minimum_iou_slider = _unit_slider(
                    0.1,
                    "Minimum IoU Threshold",
                    "Overlap required to match a detection to a track (ByteTrack, SORT, OC-SORT).",
                )
                high_confidence_slider = _unit_slider(
                    0.6,
                    "High Confidence Detection Threshold",
                    "Detections above this are matched first (ByteTrack / OC-SORT).",
                )
                direction_consistency_slider = _unit_slider(
                    0.2,
                    "Direction Consistency Weight",
                    "Weight for direction consistency in association cost (OC-SORT only).",
                )
                delta_t_slider = _count_slider(
                    10,
                    3,
                    "Delta T",
                    "Past frames for velocity estimation during occlusion (OC-SORT only).",
                )

            with gr.Column():
                gr.Markdown("### Visualization")
                show_boxes_checkbox = _toggle(
                    True,
                    "Show Boxes",
                    "Draw bounding boxes around detections.",
                )
                show_ids_checkbox = _toggle(
                    True,
                    "Show IDs",
                    "Display track ID for each object.",
                )
                show_labels_checkbox = _toggle(
                    False,
                    "Show Labels",
                    "Display class name for each detection.",
                )
                show_confidence_checkbox = _toggle(
                    False,
                    "Show Confidence",
                    "Display detection confidence score.",
                )
                show_trajectories_checkbox = _toggle(
                    False,
                    "Show Trajectories",
                    "Draw motion path for each tracked object.",
                )
                show_masks_checkbox = _toggle(
                    False,
                    "Show Masks",
                    "Draw segmentation masks (seg models only).",
                )

    # Components in the positional order of `track`'s arguments. The example
    # gallery and the Track button both feed `track` from this list.
    track_inputs = [
        input_video,
        model_dropdown,
        tracker_dropdown,
        confidence_slider,
        lost_track_buffer_slider,
        track_activation_slider,
        minimum_consecutive_slider,
        minimum_iou_slider,
        high_confidence_slider,
        direction_consistency_slider,
        delta_t_slider,
        class_filter,
        track_id_filter,
        show_boxes_checkbox,
        show_ids_checkbox,
        show_labels_checkbox,
        show_confidence_checkbox,
        show_trajectories_checkbox,
        show_masks_checkbox,
    ]

    gr.Examples(
        examples=VIDEO_EXAMPLES,
        fn=track,
        cache_examples=True,
        inputs=track_inputs,
        outputs=output_video,
    )

    track_button.click(
        fn=track,
        inputs=track_inputs,
        outputs=output_video,
    )


if __name__ == "__main__":
    demo.launch()
|
|