from __future__ import annotations

from collections.abc import Iterable

import torch
import numpy as np
import supervision as sv
from pycocotools import mask as mask_utils
import cv2
import ffmpeg
from PIL import Image
from matplotlib import pyplot as plt
|
|
class SAM2Tracker:
    """Thin wrapper around a SAM2 streaming/camera predictor for multi-object tracking."""

    def __init__(self, predictor) -> None:
        self.predictor = predictor
        self._prompted = False

    def prompt_first_frame(self, frame: np.ndarray, detections: sv.Detections) -> None:
        """Seed the predictor with one box prompt per detection on the first frame."""
        if len(detections) == 0:
            raise ValueError("detections must contain at least one box")

        # Assign sequential tracker ids if the detector did not provide any.
        if detections.tracker_id is None:
            detections.tracker_id = np.arange(1, len(detections) + 1)

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            self.predictor.load_first_frame(frame)
            for xyxy, obj_id in zip(detections.xyxy, detections.tracker_id):
                bbox = np.asarray([xyxy], dtype=np.float32)
                self.predictor.add_new_prompt(
                    frame_idx=0,
                    obj_id=int(obj_id),
                    bbox=bbox,
                )

        self._prompted = True
|
    def propagate(self, frame: np.ndarray) -> sv.Detections:
        """Track the prompted objects into the given frame and return mask detections."""
        if not self._prompted:
            raise RuntimeError("Call prompt_first_frame before propagate")

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            tracker_ids, mask_logits = self.predictor.track(frame)

        tracker_ids = np.asarray(tracker_ids, dtype=np.int32)
        masks = (mask_logits > 0.0).cpu().numpy()
        masks = np.squeeze(masks).astype(bool)

        # A single tracked object squeezes down to (H, W); restore the leading axis.
        if masks.ndim == 2:
            masks = masks[None, ...]

        # Clean up each mask with supervision's distance-based segment filter.
        masks = np.array([
            sv.filter_segments_by_distance(mask, relative_distance=0.03, mode="edge")
            for mask in masks
        ])

        xyxy = sv.mask_to_xyxy(masks=masks)
        detections = sv.Detections(xyxy=xyxy, mask=masks, tracker_id=tracker_ids)
        return detections
|
|
    def reset(self) -> None:
        self._prompted = False
|
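# Illustrative sketch (not part of the original pipeline) of how SAM2Tracker is
# meant to be driven frame by frame. `predictor` is assumed to be a SAM2
# streaming/camera predictor exposing load_first_frame / add_new_prompt / track,
# and `first_frame_detections` an sv.Detections with one box per object to follow.
def example_track_clip(predictor, frames: np.ndarray, first_frame_detections: sv.Detections) -> list[sv.Detections]:
    tracker = SAM2Tracker(predictor)
    tracker.prompt_first_frame(frames[0], first_frame_detections)

    per_frame = [first_frame_detections]
    for frame in frames[1:]:
        per_frame.append(tracker.propagate(frame))
    return per_frame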
|
def get_crops_from_masks(frame: np.ndarray, masks: np.ndarray) -> list[np.ndarray]:
    """
    Args:
        frame: (H, W, 3) image
        masks: (N, H, W) binary masks

    Returns:
        List of cropped images, one per mask. Each crop is a rectangular
        bounding box around the mask, with black pixels outside the mask.
    """
    crops = []

    for mask in masks:
        ys, xs = np.where(mask)
        if len(xs) == 0 or len(ys) == 0:
            # Empty mask: keep list alignment with an empty crop.
            crops.append(np.zeros((0, 0, 3), dtype=frame.dtype))
            continue

        y_min, y_max = ys.min(), ys.max() + 1
        x_min, x_max = xs.min(), xs.max() + 1

        frame_crop = frame[y_min:y_max, x_min:x_max]
        mask_crop = mask[y_min:y_max, x_min:x_max]

        # Copy only the masked pixels; everything outside the mask stays black.
        crop = np.zeros_like(frame_crop)
        crop[mask_crop] = frame_crop[mask_crop]

        crops.append(crop)

    return crops
|
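# Minimal sanity-check sketch for get_crops_from_masks on synthetic data;
# the shapes and values here are illustrative only.
def example_crops_demo() -> None:
    frame = np.full((64, 64, 3), 255, dtype=np.uint8)       # white frame
    mask = np.zeros((64, 64), dtype=bool)
    mask[10:20, 30:50] = True                                # one rectangular object
    crops = get_crops_from_masks(frame, mask[None, ...])
    assert crops[0].shape == (10, 20, 3)                     # tight box around the mask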
|
def append_masks_to_track_history(detections: sv.Detections, track_history: dict, frame_index: int) -> None:
    """Store each detection's mask as a COCO RLE payload keyed by tracker id.

    Only the RLE 'counts' payload is kept, so the mask height/width must be
    known elsewhere to decode it later.
    """
    for i in range(len(detections)):
        mask = detections.mask[i]
        # pycocotools expects a Fortran-ordered uint8 array.
        rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
        tracker_id = int(detections.tracker_id[i])
        track_history.setdefault(tracker_id, []).append((frame_index, rle['counts']))
|
|
|
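# Sketch of the reverse direction: rebuilding a binary mask from a stored
# 'counts' payload. The frame height/width are assumptions here; they are not
# stored in track_history and must come from the video itself.
def example_decode_track_entry(counts: bytes, height: int, width: int) -> np.ndarray:
    rle = {"size": [height, width], "counts": counts}
    return mask_utils.decode(rle).astype(bool)               # (H, W) boolean mask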
|
def toRGB(img: np.ndarray):
    return cv2.cvtColor(img, code=cv2.COLOR_BGR2RGB)
|
|
def read_frame_from_video(in_filename, frame_num):
    """Decode a single RGB frame (returned with a leading batch axis). Assumes a 1920x1080 source."""
    raw_bytes, err = (
        ffmpeg
        .input(in_filename)
        .filter('select', 'gte(n,{})'.format(frame_num))
        .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
        .global_args('-loglevel', 'error')
        .run(capture_stdout=True)
    )
    assert len(raw_bytes) == 1080 * 1920 * 3, "expected a single 1080p RGB frame"
    return np.frombuffer(raw_bytes, np.uint8).reshape(1, 1080, 1920, 3).copy()
|
|
def read_consecutive_frames_from_video(in_filename, start_frame, num_frames) -> np.ndarray:
    """Decode num_frames consecutive RGB frames starting at start_frame. Assumes a 1920x1080 source."""
    out, err = (
        ffmpeg
        .input(in_filename)
        .output(
            'pipe:1',
            vf=f'select=between(n\\,{start_frame}\\,{start_frame + num_frames - 1})',
            vsync=0,
            vframes=num_frames,
            format='rawvideo',
            pix_fmt='rgb24'
        )
        .global_args('-loglevel', 'error')
        .run(capture_stdout=True, capture_stderr=True)
    )

    W, H = 1920, 1080
    frame_size = W * H * 3
    frames = np.frombuffer(out, np.uint8)

    if frames.size != num_frames * frame_size:
        raise RuntimeError(
            f'Expected {num_frames * frame_size} bytes, got {frames.size}\n'
            f'ffmpeg stderr:\n{err.decode()}'
        )

    return frames.reshape(num_frames, H, W, 3).copy()
|
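# Both readers above hard-code a 1920x1080 source. A sketch of how the size
# could be probed instead (ffmpeg.probe is part of ffmpeg-python; the stream
# layout shown is the usual one, but treat it as an assumption, not a contract):
def probe_video_size(in_filename) -> tuple[int, int]:
    info = ffmpeg.probe(in_filename)
    video = next(s for s in info['streams'] if s['codec_type'] == 'video')
    return int(video['width']), int(video['height'])         # (W, H)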
|
def xywhn_to_xywh(xywhn: list, height: int, width: int):
    """Scale normalised [x, y, w, h] coordinates to pixel units."""
    x, y, w, h = xywhn
    return [int(x * width), int(y * height), int(w * width), int(h * height)]
|
|
def crop_frame_at_mask_from_bbox(frame: np.ndarray, mask: np.ndarray, bbox: list) -> np.ndarray:
    """Crop frame to bbox (x, y, w, h) and black out pixels outside the mask."""
    x, y, w, h = bbox
    # Copy so that blacking out pixels does not mutate the caller's frame.
    crop = frame[y: y + h, x: x + w].copy()
    cropped_mask = mask[y: y + h, x: x + w]
    crop[~cropped_mask] = np.array([0, 0, 0], dtype=np.uint8)

    return crop
|
|
def find_consecutive_streaks(nums: Iterable[int]) -> list[range]:
    """Split a sorted iterable of ints into ranges of consecutive values."""
    nums = list(nums)
    if not nums:
        return []

    streaks = []
    start = nums[0]
    for i in range(1, len(nums)):
        if nums[i] != nums[i - 1] + 1:
            stop = nums[i - 1]
            streaks.append(range(start, stop + 1))
            start = nums[i]

    streaks.append(range(start, nums[-1] + 1))
    return streaks
|
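# Small illustrative check of the streak splitting above; the input values are arbitrary.
def example_streaks_demo() -> None:
    streaks = find_consecutive_streaks([3, 4, 5, 9, 10, 14])
    assert streaks == [range(3, 6), range(9, 11), range(14, 15)]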
|
def save_loss_history(fpath, loss: float) -> None:
    """Append a single loss value to a plain-text history file."""
    with open(fpath, "a") as f:
        f.write(f"{loss:.6f}\n")
|
|
def save_loss_history_plot(loss_history: list[float], fpath) -> None:
    """Plot the loss curve and write it to fpath."""
    fig = plt.figure()
    plt.plot(loss_history)
    plt.savefig(fpath)
    plt.close(fig)
|
|
def save_checkpoint(path, model, optimizer, epoch, step):
    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
    }
    torch.save(ckpt, path)
|
|
def load_checkpoint(path, model, optimizer, device="cuda"):
    ckpt = torch.load(path, map_location=device)

    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])

    epoch = ckpt.get("epoch", 0)
    step = ckpt.get("step", 0)

    return epoch, step
|
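# Sketch of how the two checkpoint helpers pair up when resuming a run;
# `model`, `optimizer` and `ckpt_path` are placeholders for the real training objects.
def example_resume(model, optimizer, ckpt_path: str, device: str = "cpu") -> tuple[int, int]:
    start_epoch, start_step = load_checkpoint(ckpt_path, model, optimizer, device=device)
    # ...continue the training loop from (start_epoch, start_step) and call
    # save_checkpoint(ckpt_path, model, optimizer, epoch, step) periodically.
    return start_epoch, start_step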
|
def mask_iou_pair(m1, m2):
    """IoU between two binary masks of the same shape."""
    inter = np.logical_and(m1, m2).sum()
    if inter == 0:
        return 0.0
    union = m1.sum() + m2.sum() - inter
    return inter / (union + 1e-6)
|
|
|
|
def mask_nms(masks, scores, iou_thresh=0.6):
    """Greedy non-maximum suppression over binary masks; returns the kept indices."""
    order = np.argsort(-scores)
    keep = []
    suppressed = np.zeros(len(masks), dtype=bool)

    for rank, i in enumerate(order):
        if suppressed[i]:
            continue

        keep.append(i)

        # Only lower-scored masks (later in `order`) can be suppressed by mask i.
        for j in order[rank + 1:]:
            if suppressed[j]:
                continue

            iou = mask_iou_pair(masks[i], masks[j])
            if iou > iou_thresh:
                suppressed[j] = True

    return keep
|
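# Tiny illustrative check of mask_nms: two heavily overlapping masks plus one
# disjoint mask should collapse to two survivors. Values are arbitrary.
def example_mask_nms_demo() -> None:
    a = np.zeros((32, 32), dtype=bool); a[0:10, 0:10] = True
    b = np.zeros((32, 32), dtype=bool); b[0:10, 1:11] = True    # ~0.82 IoU with a
    c = np.zeros((32, 32), dtype=bool); c[20:30, 20:30] = True  # no overlap
    keep = mask_nms(np.stack([a, b, c]), np.array([0.9, 0.8, 0.7]), iou_thresh=0.6)
    assert sorted(int(k) for k in keep) == [0, 2]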
|
def mask_iou(masks_t: np.ndarray, masks_t1: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two stacks of binary masks, returned as an (N, M) matrix."""
    N = masks_t.shape[0]
    M = masks_t1.shape[0]

    masks_t = masks_t.reshape(N, -1).astype(float)
    masks_t1 = masks_t1.reshape(M, -1).astype(float)

    # Intersection areas via a single matrix product, shape (N, M).
    intersection = masks_t @ masks_t1.T

    area_t = masks_t.sum(1, keepdims=True)      # (N, 1)
    area_t1 = masks_t1.sum(1, keepdims=True)    # (M, 1)

    union = area_t + area_t1.T - intersection

    iou = intersection / (union + 1e-6)
    return iou
|
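# Sketch of how the pairwise IoU matrix can be turned into track assignments
# between frame t and frame t+1. scipy is an extra dependency assumed here;
# a greedy argmax over the IoU matrix would also work for easy scenes.
def example_match_masks(masks_t: np.ndarray, masks_t1: np.ndarray, iou_thresh: float = 0.3):
    from scipy.optimize import linear_sum_assignment

    iou = mask_iou(masks_t, masks_t1)
    rows, cols = linear_sum_assignment(-iou)                  # maximise total IoU
    return [(int(r), int(c)) for r, c in zip(rows, cols) if iou[r, c] >= iou_thresh]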
|
# Court landmark coordinates in feet on a 94 ft x 50 ft court
# (x runs baseline to baseline, y runs sideline to sideline).
COURT_KEYPOINT_COORDINATES = np.array([
    (0.0, 0.0),
    (0.0, 2.99),
    (0.0, 17.0),
    (0.0, 33.01),
    (0.0, 47.02),
    (0.0, 50.0),
    (5.25, 25.0),
    (13.92, 2.99),
    (13.92, 47.02),
    (19.0, 17.0),
    (19.0, 25.0),
    (19.0, 33.01),
    (27.4, 0.0),
    (29.01, 25.0),
    (27.4, 50.0),
    (46.99, 0.0),
    (46.99, 25.0),
    (46.99, 50.0),
    (66.61, 0.0),
    (65.0, 25.0),
    (66.61, 50.0),
    (75.0, 17.0),
    (75.0, 25.0),
    (75.0, 33.01),
    (80.09, 2.99),
    (80.09, 47.02),
    (88.75, 25.0),
    (94.0, 0.0),
    (94.0, 2.99),
    (94.0, 17.0),
    (94.0, 33.01),
    (94.0, 47.02),
    (94.0, 50.0)
])
|
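# Sketch of the intended use of the court keypoints: given image-space detections
# of (a subset of) these landmarks, estimate a homography from image pixels to
# court coordinates. `image_points` / `keypoint_indices` are placeholders for the
# output of whatever court keypoint detector runs upstream.
def example_court_homography(image_points: np.ndarray, keypoint_indices: np.ndarray) -> np.ndarray:
    court_points = COURT_KEYPOINT_COORDINATES[keypoint_indices]
    H, _ = cv2.findHomography(
        image_points.astype(np.float32),
        court_points.astype(np.float32),
        method=cv2.RANSAC,
    )
    return H    # 3x3 matrix mapping image pixels to court feet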
|
def get_distance_cost_matrix(arr1: np.ndarray, arr2: np.ndarray, ord=1) -> torch.Tensor:
    """Pairwise distance matrix between two point sets, returned as a torch tensor."""
    cost_matrix = np.empty(shape=(len(arr1), len(arr2)), dtype=np.float64)

    for i in range(len(arr1)):
        cost_matrix[i] = np.linalg.norm(arr1[i] - arr2, ord=ord, axis=-1)

    return torch.tensor(cost_matrix)
|
|
def matcher_probs_custom_argmax(probs: np.ndarray, confidence_threshold=0.7):
    """Argmax with a fallback: if the last class (treated as "no match") wins with
    low confidence, fall back to the best real class when it is plausible enough."""
    probs = probs.squeeze(0)
    pred = probs.argmax()

    if pred == len(probs) - 1 and probs[pred] < confidence_threshold:
        second_best = probs[:-1].argmax()
        if probs[second_best] > 1.0 - confidence_threshold - 0.05:
            pred = second_best

    return pred
|
|
def annotate_frame(frame_, detections_):
    annotated_frame = frame_.copy()
    annotated_frame = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK).annotate(annotated_frame, detections_)
    annotated_frame = sv.LabelAnnotator(smart_position=True).annotate(
        annotated_frame, detections_, labels=[str(i) for i in detections_.tracker_id]
    )
    return annotated_frame


def show_annotations(frame_, detections_):
    # Same annotation as annotate_frame, wrapped in a PIL image for display.
    return Image.fromarray(annotate_frame(frame_, detections_))
|
|
if __name__ == "__main__":
    from code import interact

    frames = read_consecutive_frames_from_video("nba_sample_videos/batch2/SAC_LAL_1.mp4", 199, 1)

    interact(local=locals())