"""Multi-stage player tracker for basketball video.

Detections at frame t are matched to tracks from frame t-1 in three layers:
mask IoU, court-position distance, then an appearance matcher on player crops.
"""
import os
import supervision as sv
import torch
import numpy as np
from collections import defaultdict
from rfdetr import RFDETRSeg2XLarge
from PIL import Image
import cv2
from scipy.optimize import linear_sum_assignment
from .utils import (
    mask_nms,
    toRGB,
    matcher_probs_custom_argmax,
    get_distance_cost_matrix,
    mask_iou,
    get_crops_from_masks,
)
from .view_transformer import (
    get_players_court_xy,
)
from tqdm import tqdm
from code import interact

np.set_printoptions(suppress=True, precision=4)
torch.set_printoptions(sci_mode=False)


def indices_to_matches(
    cost_matrix,
    indices,
    thresh: float
):
    """Split an assignment into accepted matches and unmatched leftovers.

    Args:
        cost_matrix: 2-D array of pairwise costs (rows vs. columns).
        indices: (K, 2) integer array of assigned (row, col) pairs.
        thresh: pairs whose cost is greater than this value are rejected.

    Returns:
        A tuple ``(matches, unmatched_a, unmatched_b)``:
        ``matches`` is an (M, 2) array of accepted (row, col) pairs.
        ``unmatched_a`` and ``unmatched_b`` are sorted lists of the row and
        column indices that took part in no accepted pair.
    """
    cost_matrix = np.asarray(cost_matrix)
    indices = np.asarray(indices, dtype=np.int64).reshape(-1, 2)
    # Explicit row/column indexing. The former `cost_matrix[tuple(zip(*indices))]`
    # collapsed to `cost_matrix[()]` (the whole matrix) when no pair was assigned.
    matched_cost = cost_matrix[indices[:, 0], indices[:, 1]]
    matches = indices[matched_cost <= thresh]
    unmatched_a = sorted(set(range(cost_matrix.shape[0])) - set(matches[:, 0].tolist()))
    unmatched_b = sorted(set(range(cost_matrix.shape[1])) - set(matches[:, 1].tolist()))
    return matches, unmatched_a, unmatched_b


def linear_assignment(
    cost_matrix,
    thresh
):
    """Solve the minimum-cost assignment, then reject pairs costing more than `thresh`.

    Returns the same ``(matches, unmatched_rows, unmatched_cols)`` tuple as
    :func:`indices_to_matches`. Empty cost matrices yield empty matches.
    """
    cost_matrix = np.asarray(cost_matrix)
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    indices = np.column_stack((row_ind, col_ind)).astype(np.int64)
    return indices_to_matches(cost_matrix, indices, thresh)


class Tracker:
    """Assign persistent track ids to per-frame player detections.

    Attributes:
        track_ids: every track id ever opened, in creation order.
        previous_detections / previous_xy: the frame t-1 reference that the
            next call to ``update_tracks_with_new_detections`` matches against.
        track_id_to_crop: track id -> ``[base_crop, latest_crop]``. The base
            crop is stored when the track is created. The latest crop is
            refreshed by ``save_statistics`` on every frame the track is matched.
        stats: frame id -> dict of per-frame matching counters.
    """

    def __init__(
        self,
        initial_detections: sv.Detections,
        initial_xy: np.ndarray,
        initial_frame: np.ndarray,
        matcher,
        hungarian_mask_threshold: float,
        hungarian_pos_threshold: float,
        num_players: int = 10,
    ):
        """Open one track per initial detection.

        Args:
            initial_detections: detections of the first frame. Their
                ``tracker_id`` is overwritten with ``0..N-1``.
            initial_xy: court positions row-aligned with ``initial_detections``.
            initial_frame: first image, used to take the base crops.
            matcher: appearance model exposing ``predict(query_crop, candidate_crops)``.
            hungarian_mask_threshold: maximum accepted ``1 - IoU`` cost.
            hungarian_pos_threshold: maximum accepted court distance.
            num_players: number of distinct players expected. Once this many
                tracks exist, no new track is opened.
        """
        self.frame_id = 0
        self.num_players = num_players
        self.track_ids = list(range(len(initial_detections)))
        self.previous_detections = initial_detections
        self.previous_xy = initial_xy
        self.hungarian_mask_threshold = hungarian_mask_threshold
        self.hungarian_pos_threshold = hungarian_pos_threshold
        self.matcher = matcher

        # Initialize one track id per initial detection.
        self.all_players_detected = len(initial_detections) >= num_players
        initial_detections.tracker_id = np.array(self.track_ids, dtype=np.int64)
        # NOTE(review): only the initial frame is recorded here; nothing in this
        # class adds later frames.
        self.frame_id_to_xy = {
            self.frame_id: dict(zip(initial_detections.tracker_id, initial_xy))
        }

        # Keep one "base selfie" (index 0) and one "latest selfie" (index 1) per player.
        self.track_id_to_crop = defaultdict(list)
        if len(initial_detections) > 0:
            initial_crops = get_crops_from_masks(initial_frame, initial_detections.mask)
            for track_id, crop in zip(initial_detections.tracker_id.tolist(), initial_crops):
                self.track_id_to_crop[int(track_id)] = [crop, crop]

        self.stats = {
            self.frame_id: {
                "detected_players": len(initial_detections),
                "new_detections": None,
                "all_players_detected": self.all_players_detected,
                "mask_based_matches": None,
                "position_based_matches": None,
                "appearance_based_matches": None,
                "unmatched": None,
            }
        }

    def update_tracks_with_new_detections(self, detections: sv.Detections, xy: np.ndarray, frame: np.ndarray):
        """Assign track ids to the detections of a new frame (in place).

        Each layer only handles detections that earlier layers left at -1:

        1. Mask-based: Hungarian matching on ``1 - IoU(mask_t, mask_t-1)``,
           accepted when the cost is <= ``hungarian_mask_threshold``.
        2. Court-position-based: Hungarian matching on the Euclidean distance
           between court (x, y) positions, restricted to track ids not claimed
           in layer 1. Accepted when the distance is <= ``hungarian_pos_threshold``.
        3. Appearance-based: the matcher compares each remaining crop with the
           stored crops of still-unclaimed track ids. If no candidate is picked,
           a new track is opened while fewer than ``num_players`` tracks exist.
           Otherwise the detection stays at -1.

        Args:
            detections: detections at frame t. Their ``tracker_id`` is overwritten.
            xy: court positions at frame t, row-aligned with ``detections``.
            frame: image at frame t, used to crop players.
        """
        detections.tracker_id = -np.ones(shape=(len(detections)), dtype=np.int64)
        if len(detections) == 0:
            self.save_statistics(detections, xy, 0, frame=frame)
            return

        # Only detections that carried a track id at t-1 can be matched against.
        tracked_prev = self.previous_detections.tracker_id != -1
        prev_detections = self.previous_detections[tracked_prev]
        prev_track_ids = prev_detections.tracker_id
        has_prev = len(prev_detections) > 0

        # First layer | Mask-based tracking. Safely track players by mask overlap;
        # when in doubt, leave detections untracked. Cost_ij = 1 - IoU(mask_i, mask_j).
        matches = np.empty((0, 2), dtype=np.int64)
        if has_prev:
            mask_cost_matrix = 1.0 - mask_iou(detections.mask, prev_detections.mask)
            matches, _, _ = linear_assignment(mask_cost_matrix, self.hungarian_mask_threshold)
            detections.tracker_id[matches[:, 0]] = prev_track_ids[matches[:, 1]]
        mask_based_matches = len(matches)

        unmatched_rows = np.flatnonzero(detections.tracker_id == -1)
        if len(unmatched_rows) == 0:
            self.save_statistics(detections, xy, mask_based_matches, frame=frame)
            return

        # Second layer | Court-position-based tracking. Solve the assignment
        # only between still-untracked detections and track ids not claimed above.
        pos_based_matches = 0
        if has_prev:
            dist_cost_matrix = np.asarray(
                get_distance_cost_matrix(
                    xy,
                    self.previous_xy[tracked_prev],
                    ord=2,  # Euclidean distance
                ),
                dtype=float,
            )
            claimed_cols = set(matches[:, 1].tolist())
            free_cols = np.array(
                [col for col in range(len(prev_detections)) if col not in claimed_cols],
                dtype=np.int64,
            )
            sub_matches, _, _ = linear_assignment(
                dist_cost_matrix[np.ix_(unmatched_rows, free_cols)],
                self.hungarian_pos_threshold,
            )
            for sub_row, sub_col in sub_matches:
                detections.tracker_id[unmatched_rows[sub_row]] = prev_track_ids[free_cols[sub_col]]
            pos_based_matches = len(sub_matches)

        unmatched_rows = np.flatnonzero(detections.tracker_id == -1)
        if len(unmatched_rows) == 0:
            self.save_statistics(detections, xy, mask_based_matches, pos_based_matches, frame=frame)
            return

        # Third layer | Appearance-based tracking. Compare remaining crops with
        # the stored crops of track ids not claimed at frame t.
        claimed_ids = set(detections.tracker_id[detections.tracker_id != -1].tolist())
        unmatched_track_ids_t_1 = [t_id for t_id in self.track_ids if t_id not in claimed_ids]
        # Python ints: supervision only wraps real `int` indices into a 1-row selection.
        pending_rows = [int(row) for row in unmatched_rows]

        unmatched = 0
        appearance_based_matches = 0
        new_detections = 0
        for position, row in enumerate(pending_rows):
            is_last_row = position == len(pending_rows) - 1
            # One untracked detection and one unclaimed id left: once every player
            # has been seen, no new player can appear, so they must correspond.
            # NOTE(review): this deduced match is not counted in any statistic.
            if self.all_players_detected and len(unmatched_track_ids_t_1) == 1 and is_last_row:
                detections.tracker_id[row] = unmatched_track_ids_t_1.pop(0)
                break

            if unmatched_track_ids_t_1:
                query_crop = get_crops_from_masks(frame, detections[row].mask)[0]
                base_candidate_crops = [self.track_id_to_crop[t_id][0] for t_id in unmatched_track_ids_t_1]
                latest_candidate_crops = [self.track_id_to_crop[t_id][1] for t_id in unmatched_track_ids_t_1]
                probs = self.matcher.predict(query_crop, base_candidate_crops)
                probs = (probs + self.matcher.predict(query_crop, latest_candidate_crops)) / 2
                prediction = int(matcher_probs_custom_argmax(probs))
            else:
                # No candidate to compare with: same outcome as "no match".
                prediction = 0

            # A prediction equal to the number of candidates means "no match".
            if prediction != len(unmatched_track_ids_t_1):
                detections.tracker_id[row] = unmatched_track_ids_t_1.pop(prediction)
                appearance_based_matches += 1
            elif not self.all_players_detected:
                # Still unmatched -> (likely) a new player.
                new_track_id = max(self.track_ids, default=-1) + 1
                detections.tracker_id[row] = new_track_id
                self.track_ids.append(new_track_id)
                self.all_players_detected = len(self.track_ids) >= self.num_players
                new_detections += 1
            else:
                unmatched += 1

        self.save_statistics(
            detections, xy, mask_based_matches, pos_based_matches,
            appearance_based_matches, new_detections, unmatched, frame=frame,
        )

    def save_statistics(self, detections, xy, mask_based_matches, pos_based_matches=0,
                        appearance_based_matches=0, new_detections=0, unmatched=0, frame=None):
        """Close the current frame.

        Records the per-frame counters under a new frame id. When ``frame`` is
        given, refreshes the "latest selfie" of every tracked detection; a track
        without stored crops gets its first crop as both base and latest selfie.
        Finally stores ``detections``/``xy`` as the reference for the next frame.
        """
        self.frame_id += 1
        self.stats[self.frame_id] = {
            "detected_players": len(detections),
            "all_players_detected": self.all_players_detected,
            "mask_based_matches": mask_based_matches,
            "position_based_matches": pos_based_matches,
            "appearance_based_matches": appearance_based_matches,
            "new_detections": new_detections,
            "unmatched": unmatched,
        }
        if frame is not None:
            for i, track_id in enumerate(detections.tracker_id.tolist()):
                if track_id == -1:
                    continue
                crop = get_crops_from_masks(frame, detections[i].mask)[0]
                selfies = self.track_id_to_crop[track_id]
                if selfies:
                    selfies[1] = crop
                else:
                    # Newly opened track: its first crop is also its base selfie.
                    selfies.extend([crop, crop])
        self.previous_detections = detections
        self.previous_xy = xy


if __name__ == "__main__":
    from basketball_analysis import Matcher
    from utils import show_annotations, annotate_frame
    from inference import get_model

    VIDEO_PATH = "DEN_SAC_1_2025.mp4"
    HUNGARIAN_MASK_THRESHOLD = 0.6
    HUNGARIAN_POS_THRESHOLD = 2.0
    SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
    NUM_PLAYERS = 10

    # SECURITY: never hard-code API keys in source. The key that used to be
    # committed here is exposed in version history and should be revoked.
    ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
    if not ROBOFLOW_API_KEY:
        raise SystemExit("Set the ROBOFLOW_API_KEY environment variable.")

    SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
    SEG_MODEL.optimize_for_inference()

    KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
    KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
    KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')

    matcher = Matcher(10, 8, "DINOv2_small")
    sd = torch.load("matcher_tuned.pt")
    matcher.load_state_dict(sd)
    for p in matcher.parameters():
        p.requires_grad_(False)
    matcher.eval()

    def get_models_predictions(frame):
        """Segment players in `frame` and locate them on the court."""
        detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
        keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
        detections = detections[keep]
        if len(detections) > NUM_PLAYERS:
            # Keep the NUM_PLAYERS highest-confidence detections. Sort explicitly
            # rather than relying on the order returned by mask_nms.
            order = np.argsort(-np.asarray(detections.confidence), kind="stable")[:NUM_PLAYERS]
            detections = detections[order]
        # X,Y court coordinates retrieval
        court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)
        return detections, court_xy

    video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
    frame = toRGB(next(video_iterator))
    initial_detections, initial_xy = get_models_predictions(frame)
    history = []
    tracker = Tracker(
        initial_detections, initial_xy, frame, matcher,
        HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD, num_players=NUM_PLAYERS,
    )
    history.append(annotate_frame(frame, initial_detections))

    for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):
        frame = toRGB(frame)
        detections, xy = get_models_predictions(frame)
        tracker.update_tracks_with_new_detections(detections, xy, frame)
        history.append(annotate_frame(frame, detections))
        if frame_id == 150:
            # Debug checkpoint: dump the first and current annotated frames and
            # open an interactive shell on the local state.
            Image.fromarray(history[-1]).save("-1.png")
            Image.fromarray(history[0]).save("0.png")
            interact(local=locals())