Spaces: Running on T4
| """Gradio app for the trackers library — run object tracking on uploaded videos.""" | |
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import supervision as sv | |
| import torch | |
| from tqdm import tqdm | |
| from inference_models import AutoModel | |
| from trackers import ByteTrackTracker, SORTTracker, frames_from_source | |
# Uploads longer than this many seconds are rejected before any processing.
MAX_DURATION_SECONDS = 30

# RF-DETR checkpoints offered in the UI: detection variants first, then the
# segmentation ("seg") variants, each ordered from smallest to largest.
_MODEL_SIZES = ("nano", "small", "medium", "large")
MODELS = [f"rfdetr-{size}" for size in _MODEL_SIZES] + [
    f"rfdetr-seg-{size}" for size in _MODEL_SIZES
]

# Tracker identifiers accepted by `track`.
TRACKERS = ["bytetrack", "sort"]

# The subset of COCO class names exposed in the "Filter Classes" checkbox group.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "truck", "cat", "dog", "sports ball",
]
# Pick the device once and eagerly instantiate every detector at import time,
# so each request only has to look its model up in LOADED_MODELS.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _load_all_models(model_ids: list[str], device: str) -> dict:
    """Instantiate each model in *model_ids* on *device*, printing progress.

    Returns a dict mapping every model id to its loaded model object.
    """
    print(f"Loading {len(model_ids)} models on {device}...")
    registry = {}
    for checkpoint in model_ids:
        print(f" Loading {checkpoint}...")
        registry[checkpoint] = AutoModel.from_pretrained(checkpoint, device=device)
    print("All models loaded.")
    return registry


LOADED_MODELS = _load_all_models(MODELS, DEVICE)
# Visualization: one palette shared by every annotator in this app. The
# annotators pick colours by track ID (sv.ColorLookup.TRACK).
_TRACK_COLOR_HEX = [
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
]
COLOR_PALETTE = sv.ColorPalette.from_hex(_TRACK_COLOR_HEX)

# NOTE(review): nothing visible in this file writes to this directory —
# track() renders into a tempfile.mkdtemp() directory — so it may be
# vestigial; confirm before removing.
RESULTS_DIR = "results"
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)
def _init_annotators(
    show_boxes: bool = False,
    show_masks: bool = False,
    show_labels: bool = False,
    show_ids: bool = False,
    show_confidence: bool = False,
) -> tuple[list, sv.LabelAnnotator | None]:
    """Build the supervision annotators requested by the display toggles.

    Returns:
        ``(annotators, label_annotator)``. ``annotators`` lists the shape
        annotators in the order callers should apply them (mask before box).
        ``label_annotator`` is ``None`` unless at least one text field
        (class label, track ID or confidence) is enabled.
    """
    # Every annotator shares the app palette and colours by track ID.
    shared = {"color": COLOR_PALETTE, "color_lookup": sv.ColorLookup.TRACK}

    requested = ((show_masks, sv.MaskAnnotator), (show_boxes, sv.BoxAnnotator))
    shape_annotators: list = [factory(**shared) for enabled, factory in requested if enabled]

    wants_text = any((show_labels, show_ids, show_confidence))
    text_annotator: sv.LabelAnnotator | None = (
        sv.LabelAnnotator(
            text_color=sv.Color.BLACK,
            text_position=sv.Position.TOP_LEFT,
            **shared,
        )
        if wants_text
        else None
    )
    return shape_annotators, text_annotator
def _format_labels(
    detections: sv.Detections,
    class_names: list[str],
    *,
    show_ids: bool = False,
    show_labels: bool = False,
    show_confidence: bool = False,
) -> list[str]:
    """Return one space-separated label string per detection.

    Each label may contain, in order, ``#<tracker id>``, the class name
    (or the raw class id when *class_names* has no entry for it) and the
    confidence to two decimals. A field is emitted only when its flag is set
    and the matching attribute on *detections* is not ``None``.
    """
    include_ids = show_ids and detections.tracker_id is not None
    include_classes = show_labels and detections.class_id is not None
    include_scores = show_confidence and detections.confidence is not None

    def _describe(index: int) -> str:
        tokens: list[str] = []
        if include_ids:
            tokens.append(f"#{int(detections.tracker_id[index])}")
        if include_classes:
            cid = int(detections.class_id[index])
            known = bool(class_names) and 0 <= cid < len(class_names)
            tokens.append(class_names[cid] if known else str(cid))
        if include_scores:
            tokens.append(f"{detections.confidence[index]:.2f}")
        return " ".join(tokens)

    return [_describe(index) for index in range(len(detections))]
_EXAMPLE_VIDEO_ROOT = (
    "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples"
)


def _example_row(
    video_name: str,
    *,
    model: str = "rfdetr-small",
    tracker: str = "bytetrack",
    min_iou: float = 0.1,
    classes: list[str] | None = None,
    show_labels: bool = False,
    show_confidence: bool = False,
    show_trajectories: bool = True,
    show_masks: bool = False,
) -> list:
    """Build one ``gr.Examples`` row.

    The column order must match the positional parameters of ``track``
    (video, model, tracker, thresholds, class filter, display toggles).
    Settings that are identical across every example are fixed inline.
    """
    return [
        f"{_EXAMPLE_VIDEO_ROOT}/{video_name}",
        model,
        tracker,
        0.2,  # detection confidence
        30,  # lost track buffer
        0.3,  # track activation threshold
        3,  # minimum consecutive frames
        min_iou,
        0.6,  # high-confidence detection threshold
        list(classes) if classes else [],
        True,  # show boxes
        True,  # show ids
        show_labels,
        show_confidence,
        show_trajectories,
        show_masks,
    ]


VIDEO_EXAMPLES = [
    _example_row("bikes-1280x720-1.mp4"),
    _example_row(
        "bikes-1280x720-2.mp4",
        model="rfdetr-seg-small",
        tracker="sort",
        min_iou=0.3,
        show_masks=True,
    ),
    _example_row(
        "cars-1280x720-1.mp4",
        classes=["car"],
        show_confidence=True,
        show_trajectories=False,
    ),
    _example_row("jets-1280x720-1.mp4", show_trajectories=False),
    _example_row("jets-1280x720-2.mp4", model="rfdetr-seg-small"),
    _example_row("vehicles-1280x720.mp4", show_labels=True),
]
def _get_video_info(path: str) -> tuple[float, int]:
    """Return ``(duration_seconds, frame_count)`` for the video at *path*.

    Only container metadata is read via OpenCV; no frames are decoded.

    Raises:
        gr.Error: if the file cannot be opened, or if its frame rate or frame
            count is missing or invalid. Rejecting an unknown length keeps the
            ``MAX_DURATION_SECONDS`` limit enforceable — a reported frame
            count of 0 would otherwise yield a duration of 0 s and bypass it.
    """
    import math

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        fps = cap.get(cv2.CAP_PROP_FPS)
        raw_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        # Release on every path, including the "could not open" error.
        cap.release()

    # isfinite() also rejects NaN/inf, which a plain `fps <= 0` lets through.
    if not (math.isfinite(fps) and fps > 0):
        raise gr.Error("Could not determine video frame rate.")
    frame_count = int(raw_count) if math.isfinite(raw_count) and raw_count > 0 else 0
    if frame_count <= 0:
        raise gr.Error("Could not determine video length.")
    return frame_count / fps, frame_count
def _resolve_class_filter(
    classes: list[str] | None,
    class_names: list[str],
) -> list[int] | None:
    """Map the selected class names onto the model's integer class IDs.

    Names the model does not know are skipped, and selection order is kept.
    Returns ``None`` (meaning "do not filter") when nothing was selected or
    when none of the selected names could be resolved.
    """
    if not classes:
        return None
    index_of = {label: idx for idx, label in enumerate(class_names)}
    resolved = [index_of[label] for label in classes if label in index_of]
    return resolved or None
def _build_tracker(
    tracker_type: str,
    *,
    lost_track_buffer: int,
    track_activation_threshold: float,
    minimum_consecutive_frames: int,
    minimum_iou_threshold: float,
    high_conf_det_threshold: float,
) -> ByteTrackTracker | SORTTracker:
    """Create a freshly reset tracker of the requested type for one video.

    Raises:
        gr.Error: if *tracker_type* is not one of ``TRACKERS``.
    """
    # Frame-count settings are coerced to int in case the UI delivers slider
    # values as floats (e.g. 30.0) — TODO confirm for the Gradio version in use.
    shared = dict(
        lost_track_buffer=int(lost_track_buffer),
        track_activation_threshold=track_activation_threshold,
        minimum_consecutive_frames=int(minimum_consecutive_frames),
        minimum_iou_threshold=minimum_iou_threshold,
    )
    if tracker_type == "bytetrack":
        tracker = ByteTrackTracker(**shared, high_conf_det_threshold=high_conf_det_threshold)
    elif tracker_type == "sort":
        # SORT has no separate high-confidence matching stage.
        tracker = SORTTracker(**shared)
    else:
        raise gr.Error(
            f"Unknown tracker '{tracker_type}'. Choose one of: {', '.join(TRACKERS)}."
        )
    tracker.reset()
    return tracker


def _detect(detection_model, frame, confidence: float, class_filter: list[int] | None) -> sv.Detections:
    """Run *detection_model* on one frame, then apply confidence and class filters.

    Returns ``sv.Detections.empty()`` when the model yields no predictions.
    """
    predictions = detection_model(frame)
    if not predictions:
        return sv.Detections.empty()
    detections = predictions[0].to_supervision()
    if len(detections) > 0 and detections.confidence is not None:
        detections = detections[detections.confidence >= confidence]
    if class_filter is not None and len(detections) > 0:
        detections = detections[np.isin(detections.class_id, class_filter)]
    return detections


def _annotate_frame(
    frame,
    tracked: sv.Detections,
    class_names: list[str],
    *,
    annotators: list,
    label_annotator: sv.LabelAnnotator | None,
    trace_annotator: sv.TraceAnnotator | None,
    show_ids: bool,
    show_labels: bool,
    show_confidence: bool,
):
    """Return a copy of *frame* with traces, shapes and labels drawn for confirmed tracks."""
    # Entries with tracker_id == -1 have no confirmed track yet (presumably
    # gated by "Minimum Consecutive Frames"). Labels already skipped them;
    # skip them for every annotator so traces/masks/boxes agree with labels.
    # NOTE(review): a trace keyed on -1 would otherwise link unrelated detections.
    if tracked.tracker_id is not None:
        tracked = tracked[tracked.tracker_id != -1]

    annotated = frame.copy()
    if trace_annotator is not None:
        annotated = trace_annotator.annotate(annotated, tracked)
    for annotator in annotators:
        annotated = annotator.annotate(annotated, tracked)
    if label_annotator is not None:
        labels = _format_labels(
            tracked,
            class_names,
            show_ids=show_ids,
            show_labels=show_labels,
            show_confidence=show_confidence,
        )
        annotated = label_annotator.annotate(annotated, tracked, labels=labels)
    return annotated


def track(
    video_path: str,
    model_id: str,
    tracker_type: str,
    confidence: float,
    lost_track_buffer: int,
    track_activation_threshold: float,
    minimum_consecutive_frames: int,
    minimum_iou_threshold: float,
    high_conf_det_threshold: float,
    classes: list[str] | None = None,
    show_boxes: bool = True,
    show_ids: bool = True,
    show_labels: bool = False,
    show_confidence: bool = False,
    show_trajectories: bool = False,
    show_masks: bool = False,
    # Gradio inspects this default to mirror tqdm bars in the UI.
    progress=gr.Progress(track_tqdm=True),
) -> str:
    """Run detection + tracking on *video_path* and return the annotated video's path.

    Parameters mirror the UI controls: *model_id* selects a preloaded
    detector, *tracker_type* is ``"bytetrack"`` or ``"sort"``, *confidence*
    drops low-score detections, the four/five threshold parameters configure
    the tracker, *classes* optionally restricts tracking to named classes,
    and the ``show_*`` flags choose what is drawn.

    Raises:
        gr.Error: if no video was given, it exceeds ``MAX_DURATION_SECONDS``,
            its metadata is unreadable, or the model/tracker is unknown.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")

    duration, total_frames = _get_video_info(video_path)
    if duration > MAX_DURATION_SECONDS:
        raise gr.Error(
            f"Video is {duration:.1f}s long. "
            f"Maximum allowed duration is {MAX_DURATION_SECONDS}s."
        )

    try:
        detection_model = LOADED_MODELS[model_id]
    except KeyError:
        raise gr.Error(f"Unknown model '{model_id}'.") from None
    class_names = getattr(detection_model, "class_names", [])
    class_filter = _resolve_class_filter(classes, class_names)

    # A new tracker per call, so track IDs restart for every video.
    tracker = _build_tracker(
        tracker_type,
        lost_track_buffer=lost_track_buffer,
        track_activation_threshold=track_activation_threshold,
        minimum_consecutive_frames=minimum_consecutive_frames,
        minimum_iou_threshold=minimum_iou_threshold,
        high_conf_det_threshold=high_conf_det_threshold,
    )

    annotators, label_annotator = _init_annotators(
        show_boxes=show_boxes,
        show_masks=show_masks,
        show_labels=show_labels,
        show_ids=show_ids,
        show_confidence=show_confidence,
    )
    trace_annotator = (
        sv.TraceAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.TRACK)
        if show_trajectories
        else None
    )

    # mkdtemp gives each request its own directory, so concurrent jobs never
    # share an output.mp4. It is not removed here because the returned path
    # must stay readable after this function returns.
    output_path = str(Path(tempfile.mkdtemp()) / "output.mp4")
    video_info = sv.VideoInfo.from_video_path(video_path)

    with sv.VideoSink(output_path, video_info=video_info) as sink:
        frames = frames_from_source(video_path)
        for _, frame in tqdm(frames, total=total_frames, desc="Processing video..."):
            detections = _detect(detection_model, frame, confidence, class_filter)
            tracked = tracker.update(detections)
            sink.write_frame(
                _annotate_frame(
                    frame,
                    tracked,
                    class_names,
                    annotators=annotators,
                    label_annotator=label_annotator,
                    trace_annotator=trace_annotator,
                    show_ids=show_ids,
                    show_labels=show_labels,
                    show_confidence=show_confidence,
                )
            )
    return output_path
with gr.Blocks(title="Trackers Playground 🔥") as demo:
    gr.Markdown(
        "# Trackers Playground 🔥\n\n"
        "Upload a video, detect COCO objects with "
        "[RF-DETR](https://github.com/roboflow/rf-detr) and track them with "
        "[Trackers](https://github.com/roboflow/trackers)."
    )

    with gr.Row():
        input_video = gr.Video(label="Input Video")
        output_video = gr.Video(label="Tracked Video")

    track_btn = gr.Button(value="Track", variant="primary")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODELS,
            value="rfdetr-small",
            label="Detection Model",
        )
        tracker_dropdown = gr.Dropdown(
            choices=TRACKERS,
            value="bytetrack",
            label="Tracker",
        )

    with gr.Accordion("Configuration", open=False):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Model")
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.05,
                    label="Detection Confidence",
                    info="Minimum score for a detection to be kept.",
                )
                class_filter = gr.CheckboxGroup(
                    choices=COCO_CLASSES,
                    value=[],
                    label="Filter Classes",
                    info="Only track selected classes. None selected means all.",
                )

            with gr.Column():
                gr.Markdown("### Tracker")
                lost_track_buffer_slider = gr.Slider(
                    minimum=1,
                    maximum=120,
                    value=30,
                    step=1,
                    label="Lost Track Buffer",
                    info="Frames to keep a lost track before removing it.",
                )
                track_activation_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.05,
                    label="Track Activation Threshold",
                    info="Minimum score for a track to be activated.",
                )
                min_consecutive_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Minimum Consecutive Frames",
                    info="Detections needed before a track is confirmed.",
                )
                min_iou_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.05,
                    label="Minimum IoU Threshold",
                    info="Overlap required to match a detection to a track.",
                )
                high_conf_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.6,
                    step=0.05,
                    label="High Confidence Detection Threshold",
                    info="Detections above this are matched first (ByteTrack only).",
                )

            with gr.Column():
                gr.Markdown("### Visualization")
                show_boxes_checkbox = gr.Checkbox(
                    value=True,
                    label="Show Boxes",
                    info="Draw bounding boxes around detections.",
                )
                show_ids_checkbox = gr.Checkbox(
                    value=True,
                    label="Show IDs",
                    info="Display track ID for each object.",
                )
                show_labels_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Labels",
                    info="Display class name for each detection.",
                )
                show_confidence_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Confidence",
                    info="Display detection confidence score.",
                )
                show_trajectories_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Trajectories",
                    info="Draw motion path for each tracked object.",
                )
                show_masks_checkbox = gr.Checkbox(
                    value=False,
                    label="Show Masks",
                    info="Draw segmentation masks (seg models only).",
                )

    # Single source of truth for the positional argument order of `track`.
    # The cached examples and the button both use it, and each VIDEO_EXAMPLES
    # row must list its values in this same order.
    track_inputs = [
        input_video,
        model_dropdown,
        tracker_dropdown,
        confidence_slider,
        lost_track_buffer_slider,
        track_activation_slider,
        min_consecutive_slider,
        min_iou_slider,
        high_conf_slider,
        class_filter,
        show_boxes_checkbox,
        show_ids_checkbox,
        show_labels_checkbox,
        show_confidence_checkbox,
        show_trajectories_checkbox,
        show_masks_checkbox,
    ]

    gr.Examples(
        examples=VIDEO_EXAMPLES,
        fn=track,
        cache_examples=True,
        inputs=track_inputs,
        outputs=output_video,
    )

    track_btn.click(
        fn=track,
        inputs=track_inputs,
        outputs=output_video,
    )

if __name__ == "__main__":
    demo.launch()