from __future__ import annotations import torch import numpy as np import supervision as sv from pycocotools import mask as mask_utils import cv2 import ffmpeg from PIL import Image import numpy as np from typing import List, Iterable from matplotlib import pyplot as plt class SAM2Tracker: def __init__(self, predictor) -> None: self.predictor = predictor self._prompted = False def prompt_first_frame(self, frame: np.ndarray, detections: sv.Detections) -> None: if len(detections) == 0: raise ValueError("detections must contain at least one box") if detections.tracker_id is None: detections.tracker_id = list(range(1, len(detections) + 1)) with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): self.predictor.load_first_frame(frame) for xyxy, obj_id in zip(detections.xyxy, detections.tracker_id): bbox = np.asarray([xyxy], dtype=np.float32) self.predictor.add_new_prompt( frame_idx=0, obj_id=int(obj_id), bbox=bbox, ) self._prompted = True def propagate(self, frame: np.ndarray) -> sv.Detections: if not self._prompted: raise RuntimeError("Call prompt_first_frame before propagate") with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): tracker_ids, mask_logits = self.predictor.track(frame) tracker_ids = np.asarray(tracker_ids, dtype=np.int32) masks = (mask_logits > 0.0).cpu().numpy() masks = np.squeeze(masks).astype(bool) if masks.ndim == 2: masks = masks[None, ...] masks = np.array([ sv.filter_segments_by_distance(mask, relative_distance=0.03, mode="edge") for mask in masks ]) xyxy = sv.mask_to_xyxy(masks=masks) detections = sv.Detections(xyxy=xyxy, mask=masks, tracker_id=tracker_ids) return detections def reset(self) -> None: self._prompted = False def get_crops_from_masks(frame: np.ndarray, masks: np.ndarray) -> list[np.ndarray]: """ Args:mask_index frame: (H, W, 3) image masks: (N, H, W) binary masks Returns: List of cropped images, one per mask. Each crop is a rectangular bounding box around the mask, with black pixels outside the mask. """ crops = [] for mask in masks: # Find bounding box of the mask ys, xs = np.where(mask) if len(xs) == 0 or len(ys) == 0: # Empty mask → skip or return empty crop crops.append(np.zeros((0, 0, 3), dtype=frame.dtype)) continue y_min, y_max = ys.min(), ys.max() + 1 x_min, x_max = xs.min(), xs.max() + 1 # Crop the frame and mask frame_crop = frame[y_min:y_max, x_min:x_max] mask_crop = mask[y_min:y_max, x_min:x_max] # Apply mask: keep pixels where mask is True, else black crop = np.zeros_like(frame_crop) crop[mask_crop] = frame_crop[mask_crop] crops.append(crop) return crops def f(detections: sv.Detections, track_history: dict, frame_index): for i in range(len(detections)): mask = detections.mask[i] rle = mask_utils.encode(np.asfortranarray(mask)) track_history[int(detections.tracker_id[i])].append((frame_index, rle['counts'])) def toRGB(img: np.ndarray): return cv2.cvtColor(img, code=cv2.COLOR_BGR2RGB) def read_frame_from_video(in_filename, frame_num): raw_bytes, err = ( ffmpeg .input(in_filename) .filter('select', 'gte(n,{})'.format(frame_num)) .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') .global_args('-loglevel', 'error') .run(capture_stdout=True) ) assert len(raw_bytes) == 1080 * 1920 * 3 return np.frombuffer(raw_bytes, np.uint8).reshape(1, 1080, 1920, 3).copy() def read_consecutive_frames_from_video(in_filename, start_frame, num_frames) -> np.ndarray: out, err = ffmpeg.input(in_filename)\ .output( 'pipe:1', vf=f'select=between(n\\,{start_frame}\\,{start_frame + num_frames - 1})', vsync=0, vframes=num_frames, format='rawvideo', pix_fmt='rgb24' ).global_args('-loglevel', 'error')\ .run(capture_stdout=True, capture_stderr=True) W, H = 1920, 1080 frame_size = W * H * 3 frames = np.frombuffer(out, np.uint8) if frames.size != num_frames * frame_size: raise RuntimeError( f'Expected {num_frames * frame_size} bytes, got {frames.size}\n' f'ffmpeg stderr:\n{err.decode()}' ) # frames.setflags(write=True) return frames.reshape(num_frames, H, W, 3).copy() def xywhn_to_xywh(xywhn:list, height:int, width:int): x,y,w,h = xywhn return [int(x * width), int(y * height), int(w * width), int(h * height)] def crop_frame_at_mask_from_bbox(frame: np.ndarray, mask: np.ndarray, bbox: list) -> np.array: x,y,w,h = bbox crop = frame[y: y+h, x: x+w] cropped_mask = mask[y: y+h, x: x+w] # from code import interact; interact(local=locals()) crop[~cropped_mask] = np.array([0,0,0], dtype=np.uint8) return crop def find_consecutive_streaks(nums: list|Iterable): if isinstance(nums, Iterable): nums = list(nums) if not nums: return [] streaks = [] start = nums[0] for i in range(1, len(nums)): if nums[i] != nums[i-1] + 1: stop = nums[i-1] streaks.append(range(start, stop + 1)) start = nums[i] streaks.append(range(start, nums[-1] + 1)) return streaks def save_loss_history(fpath, loss:float): with open(fpath, "a+") as f: f.write(f"{loss:.6f}\n") def save_loss_history_plot(loss_history: list[float], fpath): plt.plot(loss_history) plt.savefig(fpath) def save_checkpoint( path, model, optimizer, epoch, step, ): ckpt = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch, "step": step, } torch.save(ckpt, path) def load_checkpoint( path, model, optimizer, device="cuda" ): ckpt = torch.load(path, map_location=device) model.load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) epoch = ckpt.get("epoch", 0) step = ckpt.get("step", 0) return epoch, step def mask_iou_pair(m1, m2): inter = np.logical_and(m1, m2).sum() if inter == 0: return 0.0 union = m1.sum() + m2.sum() - inter return inter / (union + 1e-6) def mask_nms(masks, scores, iou_thresh=0.6): order = np.argsort(-scores) keep = [] suppressed = np.zeros(len(masks), dtype=bool) for i in order: if suppressed[i]: continue keep.append(i) for j in order: if j <= i or suppressed[j]: continue iou = mask_iou_pair(masks[i], masks[j]) if iou > iou_thresh: suppressed[j] = True return keep def mask_iou(masks_t: np.ndarray, masks_t1): # Flatten N, H, W = masks_t.shape M = masks_t1.shape[0] masks_t = masks_t.reshape(N, -1).astype(float) # (N, HW) masks_t1 = masks_t1.reshape(M, -1).astype(float) # (M, HW) # Intersection: (N, M) intersection = masks_t @ masks_t1.T # Areas area_t = masks_t.sum(1, keepdims=True) # (N, 1) area_t1 = masks_t1.sum(1, keepdims=True) # (M, 1) # Union union = area_t + area_t1.T - intersection iou = intersection / (union + 1e-6) return iou # (N, M) COURT_KEYPOINT_COORDINATES = np.array([ (0.0, 0.0), (0.0, 2.99), (0.0, 17.0), (0.0, 33.01), (0.0, 47.02), (0.0, 50.0), (5.25, 25.0), (13.92, 2.99), (13.92, 47.02), (19.0, 17.0), (19.0, 25.0), (19.0, 33.01), (27.4, 0.0), (29.01, 25.0), (27.4, 50.0), (46.99, 0.0), (46.99, 25.0), (46.99, 50.0), (66.61, 0.0), (65.0, 25.0), (66.61, 50.0), (75.0, 17.0), (75.0, 25.0), (75.0, 33.01), (80.09, 2.99), (80.09, 47.02), (88.75, 25.0), (94.0, 0.0), (94.0, 2.99), (94.0, 17.0), (94.0, 33.01), (94.0, 47.02), (94.0, 50.0) ]) def get_distance_cost_matrix(arr1:np.ndarray, arr2:np.ndarray, ord=1) : cost_matrix = np.empty(shape=(len(arr1), len(arr2)), dtype=np.float64) for i in range(len(arr1)): cost_matrix[i] = np.linalg.norm(arr1[i] - arr2, ord=ord, axis=-1) return torch.tensor(cost_matrix) def matcher_probs_custom_argmax(probs:np.ndarray, confidence_threshold=0.7): probs = probs.squeeze(0) pred = probs.argmax() # if matcher predicts the null prediction, but it is not confident if pred == len(probs) - 1 and probs[pred] < confidence_threshold: # predict the second most confident prediction if it has high weight second_best = probs[:-1].argmax() if probs[second_best] > 1.0 - confidence_threshold - 0.05: pred = second_best return pred def show_annotations(frame_, detections_): annotated_frame = frame_.copy() annotated_frame = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK).annotate(annotated_frame, detections_) annotated_frame = sv.LabelAnnotator(smart_position=True).annotate(annotated_frame, detections_, labels=list(str(i) for i in detections_.tracker_id)) return Image.fromarray(annotated_frame) def annotate_frame(frame_, detections_): annotated_frame = frame_.copy() annotated_frame = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK).annotate(annotated_frame, detections_) annotated_frame = sv.LabelAnnotator(smart_position=True).annotate(annotated_frame, detections_, labels=list(str(i) for i in detections_.tracker_id)) return annotated_frame if __name__ == "__main__": from code import interact frames = read_consecutive_frames_from_video("nba_sample_videos/batch2/SAC_LAL_1.mp4", 199, 1) # crop_frame_at_mask_from_bbox(np.zeros((1080, 1920, 3)), ) interact(local=locals())