| import supervision as sv |
| import torch |
| import numpy as np |
| from collections import defaultdict |
| from rfdetr import RFDETRSeg2XLarge |
| from PIL import Image |
| import cv2 |
| from scipy.optimize import linear_sum_assignment |
| from .utils import ( |
| mask_nms, |
| toRGB, |
| matcher_probs_custom_argmax, |
| get_distance_cost_matrix, |
| mask_iou, |
| get_crops_from_masks |
| ) |
| from .view_transformer import ( |
| get_players_court_xy |
| ) |
| from tqdm import tqdm |
| from code import interact |
|
|
# Console formatting for debugging output: tensors without scientific
# notation, NumPy arrays with 4 decimals and small values shown as plain
# fixed-point numbers.
torch.set_printoptions(sci_mode=False)
np.set_printoptions(precision=4, suppress=True)
|
|
def indices_to_matches(
    cost_matrix, indices, thresh: float
):
    """Split assignment pairs into accepted matches and leftover rows/columns.

    Args:
        cost_matrix: 2-D array of assignment costs (rows x columns).
        indices: ``(K, 2)`` array-like of ``(row, column)`` pairs, e.g. the
            stacked output of ``linear_sum_assignment``. ``K`` may be 0.
        thresh: Maximum cost (inclusive) for a pair to be kept as a match.

    Returns:
        A tuple ``(matches, unmatched_a, unmatched_b)``:
        ``matches`` is the ``(M, 2)`` integer array of pairs whose cost is
        ``<= thresh``; ``unmatched_a`` and ``unmatched_b`` are sorted lists of
        the row and column indices of ``cost_matrix`` that are not part of
        any kept pair.
    """
    # Normalise to an integer (K, 2) array. With K == 0 the previous
    # ``cost_matrix[tuple(zip(*indices))]`` degenerated to ``cost_matrix[()]``
    # (the whole matrix) and the boolean mask no longer lined up with
    # ``indices``; explicit row/column indexing stays well-defined.
    indices = np.asarray(indices, dtype=int).reshape(-1, 2)
    rows, cols = indices[:, 0], indices[:, 1]

    matched_mask = cost_matrix[rows, cols] <= thresh
    matches = indices[matched_mask]

    unmatched_a = sorted(set(range(cost_matrix.shape[0])) - set(matches[:, 0].tolist()))
    unmatched_b = sorted(set(range(cost_matrix.shape[1])) - set(matches[:, 1].tolist()))
    return matches, unmatched_a, unmatched_b
|
|
def linear_assignment(
    cost_matrix, thresh
):
    """Solve the assignment problem and keep only pairs within ``thresh``.

    Args:
        cost_matrix: 2-D array of costs (rows x columns).
        thresh: Maximum cost (inclusive) for an assigned pair to be kept.

    Returns:
        ``(matches, unmatched_rows, unmatched_cols)`` as produced by
        ``indices_to_matches``: a ``(M, 2)`` array of kept ``(row, column)``
        pairs plus the row and column indices left without a kept pair.
    """
    n_rows, n_cols = cost_matrix.shape
    if n_rows == 0 or n_cols == 0:
        # Nothing can be paired (e.g. a frame with no detections, or no
        # previous tracks). Return the trivial answer directly instead of
        # pushing an empty assignment through the index-based filtering.
        return (
            np.empty((0, 2), dtype=int),
            list(range(n_rows)),
            list(range(n_cols)),
        )

    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    indices = np.column_stack((row_ind, col_ind))
    return indices_to_matches(cost_matrix, indices, thresh)
|
|
class Tracker:
    """Frame-to-frame identity tracker for player detections.

    The detections of each new frame are linked to the track ids carried by
    the previous frame's detections in up to three passes: mask overlap,
    court position, and an appearance matcher. A detection that cannot be
    linked either starts a new track (while fewer than ``num_players`` tracks
    exist) or keeps the placeholder id ``-1``.

    State kept on the instance:
        frame_id: index of the last processed frame (0 for the initial one).
        track_ids: every track id issued so far.
        previous_detections / previous_xy: detections and court coordinates
            of the last processed frame.
        track_id_to_crop: per track, a two-slot list ``[base, latest]`` of
            image crops used by the appearance matcher.
        stats: per-frame dictionary of matching statistics.
        frame_id_to_xy: court coordinates per track id; this class only
            records the initial frame in it.
    """

    def __init__(
        self,
        initial_detections: "sv.Detections",
        initial_xy: np.ndarray,
        initial_frame: np.ndarray,
        matcher,
        hungarian_mask_threshold: float,
        hungarian_pos_threshold: float,
        num_players: int = 10,
    ):
        """Create one track per initial detection.

        Args:
            initial_detections: Detections of the first frame; their
                ``tracker_id`` is overwritten with ``0..len-1``.
            initial_xy: Court coordinates, one entry per initial detection
                (presumably shape ``(N, 2)`` — confirm against the caller).
            initial_frame: Image the initial detections were computed on.
            matcher: Object exposing ``predict(query_crop, candidate_crops)``.
            hungarian_mask_threshold: Maximum ``1 - IoU`` cost accepted by the
                mask-based pass.
            hungarian_pos_threshold: Maximum distance cost accepted by the
                position-based pass.
            num_players: Number of tracks after which no new track is created.
        """
        self.frame_id = 0
        self.num_players = num_players
        self.track_ids = list(range(len(initial_detections)))
        self.previous_detections = initial_detections
        self.previous_xy = initial_xy
        self.hungarian_mask_threshold = hungarian_mask_threshold
        self.hungarian_pos_threshold = hungarian_pos_threshold
        self.matcher = matcher

        # Every initial detection becomes its own track.
        self.all_players_detected = len(initial_detections) == num_players
        # Explicit integer dtype so an empty first frame does not yield a
        # float array of ids.
        initial_detections.tracker_id = np.array(self.track_ids, dtype=np.int64)
        self.frame_id_to_xy = {
            self.frame_id: dict(zip(initial_detections.tracker_id, initial_xy))
        }

        # Two crop slots per track: [0] is the first crop ever seen, [1] the
        # most recent one. Both start out as the initial crop.
        self.track_id_to_crop = defaultdict(list)
        initial_crops = get_crops_from_masks(initial_frame, initial_detections.mask)
        for track_id, crop in zip(initial_detections.tracker_id, initial_crops):
            self.track_id_to_crop[track_id].extend((crop, crop))

        self.stats = {
            self.frame_id: {
                "detected_players": len(initial_detections),
                "new_detections": None,
                "all_players_detected": self.all_players_detected,
                "mask_based_matches": None,
                "position_based_matches": None,
                "appearance_based_matches": None,
                "unmatched": None,
            }
        }

    def update_tracks_with_new_detections(self, detections: "sv.Detections", xy: np.ndarray, frame: np.ndarray):
        """Assign track ids to ``detections`` in place and record the frame.

        ``detections.tracker_id`` is reset to ``-1`` and then filled by the
        three passes described on the class. The method returns ``None``;
        results are read from ``detections.tracker_id`` and ``self.stats``.

        Args:
            detections: Detections of the new frame (mutated in place).
            xy: Court coordinates aligned with ``detections``.
            frame: Image the detections were computed on; used to cut crops.
        """
        detections.tracker_id = -np.ones(shape=(len(detections)), dtype=np.int64)

        # Previous detections that never received a track id (-1) cannot be
        # matched against, so they are filtered out once up front.
        null_track = self.previous_detections.tracker_id == -1
        prev_tracked = self.previous_detections[~null_track]
        prev_track_ids = prev_tracked.tracker_id

        # ---- First layer: mask-based matching --------------------------------
        # Cost is 1 - mask_iou; pairs above the mask threshold stay unmatched.
        mask_cost_matrix = 1.0 - mask_iou(detections.mask, prev_tracked.mask)
        matches, unmatched_rows_t, _ = linear_assignment(
            mask_cost_matrix, self.hungarian_mask_threshold
        )
        detections.tracker_id[matches[:, 0]] = prev_track_ids[matches[:, 1]]
        mask_based_matches = len(matches)

        if len(unmatched_rows_t) == 0:
            self.save_statistics(detections, xy, mask_based_matches, frame=frame)
            return

        # ---- Second layer: court-position-based matching ---------------------
        dist_cost_matrix = get_distance_cost_matrix(
            xy,
            self.previous_xy[~null_track],
            ord=2,
        )
        # Rows/columns already consumed by the mask pass get a large cost so
        # the assignment avoids them.
        dist_cost_matrix[matches[:, 0], :] = 1e3
        dist_cost_matrix[:, matches[:, 1]] = 1e3
        matches_, _, _ = linear_assignment(dist_cost_matrix, self.hungarian_pos_threshold)

        # Belt and braces: even with a very permissive position threshold a
        # mask-matched detection or track must never be reassigned.
        taken_rows = set(matches[:, 0].tolist())
        taken_cols = set(matches[:, 1].tolist())
        pos_based_matches = 0
        for row, col in matches_:
            if int(row) in taken_rows or int(col) in taken_cols:
                continue
            detections.tracker_id[row] = prev_track_ids[col]
            pos_based_matches += 1

        unmatched_rows_t = np.flatnonzero(detections.tracker_id == -1).tolist()
        assigned_ids = detections.tracker_id[detections.tracker_id != -1].tolist()
        unmatched_track_ids_t_1 = sorted(set(self.track_ids) - set(assigned_ids))

        if len(unmatched_rows_t) == 0:
            self.save_statistics(detections, xy, mask_based_matches, pos_based_matches, frame=frame)
            return

        # ---- Third layer: appearance-based matching --------------------------
        unmatched = 0
        appearance_based_matches = 0
        new_detections = 0

        last_position = len(unmatched_rows_t) - 1
        for position, row in enumerate(unmatched_rows_t):
            # One detection and one track left while all tracks exist: assign
            # by elimination. Note this assignment is not counted in any of
            # the statistics counters.
            if (
                self.all_players_detected
                and position == last_position
                and len(unmatched_track_ids_t_1) == 1
            ):
                detections.tracker_id[row] = unmatched_track_ids_t_1.pop()
                break

            # An index equal to the number of candidates means "none of them".
            no_match = len(unmatched_track_ids_t_1)
            if unmatched_track_ids_t_1:
                query_crop = get_crops_from_masks(frame, detections[row].mask)[0]
                base_candidate_crops = [
                    self.track_id_to_crop[t_id][0] for t_id in unmatched_track_ids_t_1
                ]
                latest_candidate_crops = [
                    self.track_id_to_crop[t_id][1] for t_id in unmatched_track_ids_t_1
                ]
                # Average the scores against the first and the latest crop of
                # every candidate track.
                probs = self.matcher.predict(query_crop, base_candidate_crops)
                probs = (probs + self.matcher.predict(query_crop, latest_candidate_crops)) / 2
                prediction = matcher_probs_custom_argmax(probs)
            else:
                # No candidate tracks: nothing to compare against, so skip the
                # matcher and treat the detection as "no match".
                prediction = no_match

            if prediction != no_match:
                detections.tracker_id[row] = unmatched_track_ids_t_1.pop(prediction)
                appearance_based_matches += 1
            elif not self.all_players_detected:
                # default=-1 so the very first track gets id 0 when the
                # initial frame had no detections.
                new_track_id = max(self.track_ids, default=-1) + 1
                detections.tracker_id[row] = new_track_id
                self.track_ids.append(new_track_id)
                new_detections += 1
                self.all_players_detected = len(self.track_ids) == self.num_players
            else:
                unmatched += 1

        self.save_statistics(
            detections,
            xy,
            mask_based_matches,
            pos_based_matches,
            appearance_based_matches,
            new_detections,
            unmatched,
            frame=frame,
        )

    def save_statistics(self, detections, xy, mask_based_matches, pos_based_matches=0, appearance_based_matches=0, new_detections=0, unmatched=0, frame=None):
        """Record per-frame statistics and roll the tracker state forward.

        Args:
            detections: Detections of the frame just processed.
            xy: Court coordinates aligned with ``detections``.
            mask_based_matches: Number of matches from the mask pass.
            pos_based_matches: Number of matches from the position pass.
            appearance_based_matches: Number of matches from the appearance pass.
            new_detections: Number of tracks created in this frame.
            unmatched: Number of detections left without a track.
            frame: Image of the processed frame. When given, the latest crop
                of every tracked detection is refreshed; when ``None`` the
                stored crops are left untouched.
        """
        self.frame_id += 1
        self.stats[self.frame_id] = {
            "detected_players": len(detections),
            "all_players_detected": self.all_players_detected,
            "mask_based_matches": mask_based_matches,
            "position_based_matches": pos_based_matches,
            "appearance_based_matches": appearance_based_matches,
            "new_detections": new_detections,
            "unmatched": unmatched,
        }

        if frame is not None:
            for i, track_id in enumerate(detections.tracker_id):
                if track_id == -1:
                    continue
                crop = get_crops_from_masks(frame, detections[i].mask)[0]
                crops = self.track_id_to_crop[track_id]
                if len(crops) < 2:
                    # First sighting of a track created after initialisation:
                    # fill the missing slots so [0] (base) and [1] (latest)
                    # both exist.
                    crops.extend([crop] * (2 - len(crops)))
                crops[1] = crop

        self.previous_detections = detections
        self.previous_xy = xy
|
|
if __name__ == "__main__":
    # Debug driver: runs segmentation + court mapping on a video, tracks the
    # players frame by frame and, at DEBUG_FRAME_ID, saves two annotated
    # frames and opens an interactive console.
    import os

    from basketball_analysis import Matcher
    from utils import show_annotations, annotate_frame
    from inference import get_model

    VIDEO_PATH = "DEN_SAC_1_2025.mp4"
    HUNGARIAN_MASK_THRESHOLD = 0.6
    HUNGARIAN_POS_THRESHOLD = 2.0
    MAX_PLAYERS = 10
    DEBUG_FRAME_ID = 150

    SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
    SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
    SEG_MODEL.optimize_for_inference()

    # SECURITY: an API credential is committed in source. Prefer supplying it
    # through the ROBOFLOW_API_KEY environment variable; the literal is kept
    # only as a fallback so existing runs keep working. Rotate the key and
    # drop the fallback.
    ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY", "PUNfWgLHrHDufisOOaZp")
    KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
    KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
    KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')

    matcher = Matcher(10, 8, "DINOv2_small")
    # NOTE: torch.load may unpickle arbitrary objects; only load checkpoints
    # from a trusted source.
    sd = torch.load("matcher_tuned.pt")
    matcher.load_state_dict(sd)

    # Inference only: freeze the weights and switch to eval mode.
    for p in matcher.parameters():
        p.requires_grad_(False)
    matcher.eval()

    def get_models_predictions(frame):
        """Return ``(detections, court_xy)`` for one RGB frame."""
        detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
        keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
        detections = detections[keep]
        if len(detections) > MAX_PLAYERS:
            # Keeps the first MAX_PLAYERS entries; this presumably relies on
            # the detections being ordered by confidence — TODO confirm.
            detections = detections[:MAX_PLAYERS]

        court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)
        return detections, court_xy

    video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
    first_frame = next(video_iterator, None)
    if first_frame is None:
        raise SystemExit(f"No frames could be read from {VIDEO_PATH}")
    frame = toRGB(first_frame)
    initial_detections, initial_xy = get_models_predictions(frame)

    # Every annotated frame is kept in memory for inspection in the console.
    history = []
    tracker = Tracker(initial_detections, initial_xy, frame, matcher, HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD)
    history.append(annotate_frame(frame, initial_detections))

    for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):
        frame = toRGB(frame)
        detections, xy = get_models_predictions(frame)
        tracker.update_tracks_with_new_detections(detections, xy, frame)
        history.append(annotate_frame(frame, detections))
        if frame_id == DEBUG_FRAME_ID:
            Image.fromarray(history[-1]).save("-1.png")
            Image.fromarray(history[0]).save("0.png")
            # Processing resumes with the next frame once the console exits.
            interact(local=locals())
| |