# NOTE(review): the three lines below were HuggingFace upload-page residue
# ("poolay2's picture / Upload folder using huggingface_hub / bbc0514 verified")
# accidentally pasted into the module; they made the file unimportable, so they
# are preserved here as a comment only.
import supervision as sv
import torch
import numpy as np
from collections import defaultdict
from rfdetr import RFDETRSeg2XLarge
from PIL import Image
import cv2
from scipy.optimize import linear_sum_assignment
from .utils import (
mask_nms,
toRGB,
matcher_probs_custom_argmax,
get_distance_cost_matrix,
mask_iou,
get_crops_from_masks
)
from .view_transformer import (
get_players_court_xy
)
from tqdm import tqdm
from code import interact
# Readable console/debug output: fixed-point notation (no scientific form)
# when printing NumPy arrays and torch tensors.
np.set_printoptions(suppress=True, precision=4)
torch.set_printoptions(sci_mode=False)
def indices_to_matches(
    cost_matrix, indices, thresh: float
):
    """Split candidate assignment pairs into accepted matches and leftovers.

    Args:
        cost_matrix: (n_a, n_b) array of assignment costs.
        indices: (k, 2) int array of candidate (row, col) pairs.
        thresh: pairs with cost strictly greater than this are rejected.

    Returns:
        matches: (m, 2) int array of accepted (row, col) pairs.
        unmatched_a: sorted list of row indices with no accepted match.
        unmatched_b: sorted list of column indices with no accepted match.
    """
    if len(indices) == 0:
        # BUGFIX: for empty `indices`, `cost_matrix[tuple(zip(*indices))]`
        # evaluates to `cost_matrix[()]`, which returns the WHOLE matrix
        # instead of an empty selection and corrupts the mask/matches below.
        matches = np.empty((0, 2), dtype=int)
    else:
        matched_cost = cost_matrix[indices[:, 0], indices[:, 1]]
        matches = indices[matched_cost <= thresh]
    # Sorted for deterministic ordering (set difference is unordered).
    unmatched_a = sorted(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
    unmatched_b = sorted(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
    return matches, unmatched_a, unmatched_b
def linear_assignment(
    cost_matrix, thresh
):
    """Solve the assignment problem on `cost_matrix` and threshold the result.

    Runs the Hungarian algorithm, then delegates to `indices_to_matches` so
    that only pairs whose cost is <= `thresh` are kept as matches.

    Returns:
        (matches, unmatched_a, unmatched_b) — see `indices_to_matches`.
    """
    assignment = np.column_stack(linear_sum_assignment(cost_matrix))
    return indices_to_matches(cost_matrix, assignment, thresh)
class Tracker:
    """Tracks up to 10 basketball players across frames with a 3-layer cascade.

    Per frame, new detections are matched to the previous frame's tracks by:
      1. Mask overlap (Hungarian assignment on ``1 - mask IoU``);
      2. Court-plane position (Hungarian assignment on Euclidean distance);
      3. Appearance (a learned matcher compares player crops).
    Detections that survive all three layers either spawn a new track (while
    fewer than 10 players have ever been seen) or are counted as unmatched.
    Per-frame statistics are accumulated in ``self.stats``.
    """

    def __init__(
        self,
        initial_detections: sv.Detections,
        initial_xy: np.ndarray,
        initial_frame: np.ndarray,
        matcher,
        hungarian_mask_threshold: float,
        hungarian_pos_threshold: float
    ):
        """
        Args:
            initial_detections: segmentation detections of frame 0 (must carry `.mask`).
            initial_xy: court-plane (x, y) coordinates, one row per detection.
            initial_frame: RGB frame 0, used to cut the initial player crops.
            matcher: appearance model exposing
                ``predict(query_crop, candidate_crops) -> probs``.
            hungarian_mask_threshold: max accepted ``1 - mask IoU`` cost.
            hungarian_pos_threshold: max accepted court-space distance.
        """
        self.frame_id = 0
        self.track_ids = list(range(len(initial_detections)))
        self.previous_detections = initial_detections
        self.previous_xy = initial_xy
        self.hungarian_mask_threshold = hungarian_mask_threshold
        self.hungarian_pos_threshold = hungarian_pos_threshold
        self.matcher = matcher
        # A basketball game has 10 players on court; once all 10 have been
        # seen, no further track ids are ever created.
        self.all_players_detected = len(initial_detections) == 10
        initial_detections.tracker_id = np.array(self.track_ids)
        # Per-frame record: frame_id -> {track_id: court (x, y)}.
        self.frame_id_to_xy = {
            self.frame_id: dict(zip(initial_detections.tracker_id, initial_xy))
        }
        # Keep one "base selfie" (index 0, frozen) and one "latest selfie"
        # (index 1, refreshed every frame) per track for appearance matching.
        self.track_id_to_crop = defaultdict(list)
        crops = get_crops_from_masks(initial_frame, initial_detections.mask)
        for track_id, crop in zip(initial_detections.tracker_id, crops):
            self.track_id_to_crop[track_id] = [crop, crop]
        self.stats = {
            self.frame_id: {
                "detected_players": len(initial_detections),
                "new_detections": None,
                "all_players_detected": self.all_players_detected,
                "mask_based_matches": None,
                "position_based_matches": None,
                "appearance_based_matches": None,
                "unmatched": None
            }
        }

    def update_tracks_with_new_detections(self, detections: sv.Detections, xy: np.ndarray, frame: np.ndarray):
        """Assign tracker ids to `detections` (mutated in place) for one frame.

        Args:
            detections: this frame's segmentation detections (with `.mask`).
            xy: court-plane (x, y) coordinates, one row per detection.
            frame: the RGB frame the detections come from (crops source).
        """
        detections.tracker_id = np.full(len(detections), -1, dtype=np.int64)
        masks = detections.mask
        # Only previously *tracked* detections are valid match candidates.
        null_track = self.previous_detections.tracker_id == -1
        prev_tracked = self.previous_detections[~null_track]

        '''First Layer | Mask-based tracking:
        Safely track players based on their masks coordinates. When in doubt, leave the detections untracked'''
        # cost[i, j] = 1 - IoU(mask_i(t), mask_j(t-1)); low cost = same player.
        mask_cost_matrix = 1.0 - mask_iou(masks, prev_tracked.mask)
        matches, unmatched_rows_t, _ = linear_assignment(mask_cost_matrix, self.hungarian_mask_threshold)
        detections.tracker_id[matches[:, 0]] = prev_tracked.tracker_id[matches[:, 1]]
        mask_based_matches = len(matches)
        if len(unmatched_rows_t) == 0:
            self.save_statistics(detections, xy, mask_based_matches, frame=frame)
            return

        '''Second Layer | Court-position-based tracking:
        Safely track remaining un-matched player based on their court (x,y) coordinates.
        '''
        pos_based_matches = 0
        dist_cost_matrix = get_distance_cost_matrix(
            xy,
            self.previous_xy[~null_track],
            ord=2,  # EUCLIDIAN DISTANCE
        )
        # Rows/cols already consumed by the mask layer get a prohibitive cost
        # (1e3 >> hungarian_pos_threshold) so they cannot be re-matched.
        dist_cost_matrix[matches[:, 0], :] = 1e3
        dist_cost_matrix[:, matches[:, 1]] = 1e3
        matches_, _, _ = linear_assignment(dist_cost_matrix, self.hungarian_pos_threshold)
        for match_ in matches_:
            if match_[0] in matches[:, 0]:
                continue  # already matched by masks
            detections.tracker_id[match_[0]] = prev_tracked.tracker_id[match_[1]]
            pos_based_matches += 1
        # Remainder after layers 1 + 2.
        unmatched_rows_t = [i for i in range(len(detections)) if detections.tracker_id[i] == -1]
        unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
        if len(unmatched_rows_t) == 0:
            self.save_statistics(detections, xy, mask_based_matches, pos_based_matches, frame=frame)
            return

        '''Third Layer | Appearance-based tracking:
        Use a vision model to match remaining player crops to their corresponding crop at t-1
        '''
        unmatched = 0
        appearance_based_matches = 0
        new_detections = 0
        while len(unmatched_rows_t) > 0:
            unmatched_row_t = unmatched_rows_t.pop(0)
            # If there is only one un-matched mask at t-1 and t, they must
            # correspond to the same player (assuming all 10 players have been
            # detected once, so no new player can appear).
            if self.all_players_detected and len(unmatched_track_ids_t_1) == 1 and len(unmatched_rows_t) == 0:
                detections.tracker_id[unmatched_row_t] = unmatched_track_ids_t_1.pop(0)
                break
            prediction = None
            if unmatched_track_ids_t_1:
                # Crop the unmatched player at time t and compare it against
                # both remembered crops of every still-unmatched track.
                query_crop = get_crops_from_masks(frame, detections[unmatched_row_t].mask)[0]
                base_candidate_crops = [self.track_id_to_crop[t_id][0] for t_id in unmatched_track_ids_t_1]
                latest_candidate_crops = [self.track_id_to_crop[t_id][1] for t_id in unmatched_track_ids_t_1]
                probs = self.matcher.predict(query_crop, base_candidate_crops)
                probs = (probs + self.matcher.predict(query_crop, latest_candidate_crops)) / 2
                prediction = matcher_probs_custom_argmax(probs)
            # The matcher's index == len(candidates) means "none of the above".
            # BUGFIX: with no candidates left (prediction is None) the matcher
            # is not queried at all and we fall through to the new-player path.
            if prediction is not None and prediction != len(unmatched_track_ids_t_1):
                detections.tracker_id[unmatched_row_t] = unmatched_track_ids_t_1.pop(prediction)
                appearance_based_matches += 1
            elif not self.all_players_detected:
                # Still unmatched -> (likely) a new player entering the court.
                new_track_id = max(self.track_ids) + 1
                detections.tracker_id[unmatched_row_t] = new_track_id
                new_detections += 1
                self.track_ids.append(new_track_id)
                self.all_players_detected = len(self.track_ids) == 10
            else:
                unmatched += 1
        self.save_statistics(detections, xy, mask_based_matches, pos_based_matches,
                             appearance_based_matches, new_detections, unmatched, frame=frame)

    def save_statistics(self, detections, xy, mask_based_matches, pos_based_matches=0,
                        appearance_based_matches=0, new_detections=0, unmatched=0, frame=None):
        '''Advance the frame counter, record statistics and refresh per-track state.

        Args:
            frame: the RGB frame `detections` come from, used to refresh the
                "latest selfie" crops. BUGFIX: the previous implementation read
                an undefined name `frame` (it only worked via a script-level
                global); it is now an explicit, backward-compatible keyword
                argument. When None, crops are not refreshed.
        '''
        self.frame_id += 1
        self.stats[self.frame_id] = {
            "detected_players": len(detections),
            "all_players_detected": self.all_players_detected,
            "mask_based_matches": mask_based_matches,
            "position_based_matches": pos_based_matches,
            "appearance_based_matches": appearance_based_matches,
            "new_detections": new_detections,
            "unmatched": unmatched
        }
        # Keep the per-frame court positions in sync with the tracks
        # (previously only frame 0 was ever recorded).
        self.frame_id_to_xy[self.frame_id] = dict(zip(detections.tracker_id, xy))
        if frame is not None:
            for i in range(len(detections)):
                track_id = detections.tracker_id[i]
                if track_id == -1:
                    continue
                crop = get_crops_from_masks(frame, detections[i].mask)[0]
                crops = self.track_id_to_crop[track_id]
                if len(crops) < 2:
                    # BUGFIX: a track created this frame has no stored crops
                    # yet; seed both the base and latest selfies instead of
                    # indexing an empty list (IndexError).
                    self.track_id_to_crop[track_id] = [crop, crop]
                else:
                    crops[1] = crop  # refresh the "latest selfie" only
        self.previous_detections = detections
        self.previous_xy = xy
if __name__ == "__main__":
from basketball_analysis import Matcher
from utils import show_annotations, annotate_frame
from inference import get_model
VIDEO_PATH = "DEN_SAC_1_2025.mp4"
HUNGARIAN_MASK_THRESHOLD = 0.6
HUNGARIAN_POS_THRESHOLD = 2.0
SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
SEG_MODEL.optimize_for_inference()
ROBOFLOW_API_KEY = "PUNfWgLHrHDufisOOaZp"
KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')
matcher = Matcher(10,8, "DINOv2_small")
sd = torch.load("matcher_tuned.pt")
matcher.load_state_dict(sd)
for p in matcher.parameters():
p.requires_grad_(False)
matcher.eval();
def get_models_predictions(frame):
# Segmentation
detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
detections = detections[keep]
if len(detections) > 10:
# keep first 10 detections (10 highest confidence detections)
detections = detections[:10]
# X,Y coordinates retrieval
court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)
return detections, court_xy
video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
frame = toRGB(next(video_iterator))
initial_detections, initial_xy = get_models_predictions(frame)
history = []
tracker = Tracker(initial_detections, initial_xy, frame, matcher, HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD)
history.append(annotate_frame(frame, initial_detections))
for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):
frame = toRGB(frame)
detections, xy = get_models_predictions(frame)
tracker.update_tracks_with_new_detections(detections, xy, frame)
history.append(annotate_frame(frame, detections))
if frame_id == 150:
Image.fromarray(history[-1]).save("-1.png")
Image.fromarray(history[0]).save("0.png")
interact(local=locals())