File size: 10,964 Bytes
bbc0514 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 | import supervision as sv
import torch
import numpy as np
from collections import defaultdict
from rfdetr import RFDETRSeg2XLarge
from PIL import Image
import cv2
from scipy.optimize import linear_sum_assignment
from .utils import (
mask_nms,
toRGB,
matcher_probs_custom_argmax,
get_distance_cost_matrix,
mask_iou,
get_crops_from_masks
)
from .view_transformer import (
get_players_court_xy
)
from tqdm import tqdm
from code import interact
# Console formatting used while debugging: fixed-point output with 4 decimals
# for NumPy arrays and no scientific notation for torch tensors.
# NOTE: these are process-wide settings, applied as soon as this module is imported.
np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(sci_mode=False)
def indices_to_matches(
    cost_matrix, indices, thresh: float
):
    """Split an assignment into accepted matches and unmatched rows/columns.

    Args:
        cost_matrix: 2-D cost matrix (rows = set A, columns = set B).
        indices: ``(k, 2)`` integer array of assigned ``(row, col)`` pairs,
            e.g. the stacked output of ``scipy.optimize.linear_sum_assignment``.
            May be empty.
        thresh: an assigned pair is accepted when its cost is ``<= thresh``.

    Returns:
        ``(matches, unmatched_a, unmatched_b)``: the accepted ``(k', 2)`` pairs,
        and the sorted row / column indices not part of any accepted pair.
    """
    cost_matrix = np.asarray(cost_matrix)
    # reshape(-1, 2) keeps an empty assignment 2-D so the column slices work.
    indices = np.asarray(indices, dtype=np.intp).reshape(-1, 2)
    # Explicit row/column fancy indexing. `tuple(zip(*indices))` collapses to
    # `()` for an empty assignment, which selects the WHOLE matrix and then
    # makes the boolean indexing below fail.
    matched_cost = cost_matrix[indices[:, 0], indices[:, 1]]
    matches = indices[matched_cost <= thresh]
    unmatched_a = sorted(set(range(cost_matrix.shape[0])) - set(matches[:, 0].tolist()))
    unmatched_b = sorted(set(range(cost_matrix.shape[1])) - set(matches[:, 1].tolist()))
    return matches, unmatched_a, unmatched_b
def linear_assignment(
    cost_matrix, thresh
):
    """Solve the (possibly rectangular) assignment problem, then gate by cost.

    Args:
        cost_matrix: 2-D cost matrix; rows and columns are the two sets to pair.
        thresh: maximum cost an assigned pair may have to be kept.

    Returns:
        ``(matches, unmatched_rows, unmatched_cols)`` as returned by
        ``indices_to_matches``.
    """
    # linear_sum_assignment yields (row_indices, col_indices); stacking them
    # along axis 1 gives one (k, 2) array of (row, col) pairs.
    pairs = np.stack(linear_sum_assignment(cost_matrix), axis=1)
    return indices_to_matches(cost_matrix, pairs, thresh)
class Tracker:
    """Track players across frames with three successive matching layers.

    For every new frame, ``update_tracks_with_new_detections`` fills
    ``detections.tracker_id`` by trying, in this order:

    1. mask overlap: Hungarian assignment on ``1 - IoU`` between current and
       previous masks, accepted when the cost is <= ``hungarian_mask_threshold``;
    2. court position: Hungarian assignment on the distance between court
       coordinates, accepted when it is <= ``hungarian_pos_threshold``;
    3. appearance: ``matcher`` compares a crop of each remaining player with
       the stored crops of the remaining tracks.

    Detections that no layer could match keep ``tracker_id == -1``.
    """

    def __init__(
        self,
        initial_detections: sv.Detections,
        initial_xy: np.ndarray,
        initial_frame: np.ndarray,
        matcher,
        hungarian_mask_threshold: float,
        hungarian_pos_threshold: float,
        num_players: int = 10,
    ):
        """Seed one track per initial detection.

        Args:
            initial_detections: detections of the first frame; their
                ``tracker_id`` is overwritten with ``0 .. len - 1``.
            initial_xy: court coordinates of the initial detections (assumed
                one row per detection, in the same order — confirm).
            initial_frame: image the initial masks were computed on.
            matcher: appearance model exposing ``predict(query, candidates)``.
            hungarian_mask_threshold: max ``1 - IoU`` accepted by layer 1.
            hungarian_pos_threshold: max court distance accepted by layer 2.
            num_players: number of distinct players expected; once that many
                tracks exist, no new track id is created.
        """
        self.frame_id = 0
        self.num_players = num_players
        self.track_ids = list(range(len(initial_detections)))
        self.previous_detections = initial_detections
        self.previous_xy = initial_xy
        self.hungarian_mask_threshold = hungarian_mask_threshold
        self.hungarian_pos_threshold = hungarian_pos_threshold
        self.matcher = matcher
        # ">=" so that more initial detections than expected cannot leave the
        # tracker creating new ids forever.
        self.all_players_detected = len(self.track_ids) >= num_players
        # Same integer dtype as the ids assigned on later frames
        # (np.array([]) would otherwise be float64).
        initial_detections.tracker_id = np.array(self.track_ids, dtype=np.int64)
        # NOTE(review): only frame 0 is recorded here; later frames never add to it.
        self.frame_id_to_xy = {
            self.frame_id: dict(zip(initial_detections.tracker_id, initial_xy))
        }
        # Per track: [base crop (first seen), latest crop (refreshed every frame)].
        self.track_id_to_crop = defaultdict(list)
        if len(initial_detections):
            crops = get_crops_from_masks(initial_frame, initial_detections.mask)
            for track_id, crop in zip(initial_detections.tracker_id, crops):
                self.track_id_to_crop[track_id].extend([crop, crop])
        self.stats = {
            self.frame_id: {
                "detected_players": len(initial_detections),
                "new_detections": None,
                "all_players_detected": self.all_players_detected,
                "mask_based_matches": None,
                "position_based_matches": None,
                "appearance_based_matches": None,
                "unmatched": None,
            }
        }

    def update_tracks_with_new_detections(self, detections: sv.Detections, xy: np.ndarray, frame: np.ndarray):
        """Assign ``detections.tracker_id`` for a new frame and roll state forward.

        Args:
            detections: detections of the current frame; ``tracker_id`` is
                overwritten in place (``-1`` = untracked).
            xy: court coordinates of ``detections`` (assumed one row per
                detection, in the same order — confirm).
            frame: current image; used to crop players for the appearance
                layer and to refresh each track's latest crop.
        """
        detections.tracker_id = -np.ones(len(detections), dtype=np.int64)
        # Only previous detections that carried a track id can be matched.
        active = self.previous_detections.tracker_id != -1
        prev_detections = self.previous_detections[active]
        prev_ids = prev_detections.tracker_id
        prev_xy = self.previous_xy[active]

        # Layer 1 | mask-based: cost_ij = 1 - IoU(mask_i(t), mask_j(t-1)).
        # Pairs above the threshold are left untracked for the next layers.
        if len(detections) and len(prev_ids):
            mask_cost = 1.0 - mask_iou(detections.mask, prev_detections.mask)
            matches, _, _ = linear_assignment(mask_cost, self.hungarian_mask_threshold)
        else:
            matches = np.empty((0, 2), dtype=np.int64)
        detections.tracker_id[matches[:, 0]] = prev_ids[matches[:, 1]]
        mask_based_matches = len(matches)

        free_rows = np.flatnonzero(detections.tracker_id == -1)
        if len(free_rows) == 0:
            self.save_statistics(detections, xy, mask_based_matches, frame=frame)
            return

        # Layer 2 | court-position-based: Hungarian assignment on the court
        # distance, restricted to the rows and columns layer 1 left unmatched.
        pos_based_matches = 0
        free_cols = np.setdiff1d(np.arange(len(prev_ids)), matches[:, 1])
        if len(free_cols):
            dist_cost = get_distance_cost_matrix(
                xy[free_rows],
                prev_xy[free_cols],
                ord=2,  # Euclidean distance
            )
            pos_matches, _, _ = linear_assignment(dist_cost, self.hungarian_pos_threshold)
            detections.tracker_id[free_rows[pos_matches[:, 0]]] = prev_ids[free_cols[pos_matches[:, 1]]]
            pos_based_matches = len(pos_matches)

        unmatched_rows = [int(i) for i in np.flatnonzero(detections.tracker_id == -1)]
        if not unmatched_rows:
            self.save_statistics(detections, xy, mask_based_matches, pos_based_matches, frame=frame)
            return
        assigned = set(detections.tracker_id[detections.tracker_id != -1].tolist())
        unmatched_track_ids = [t for t in self.track_ids if t not in assigned]

        # Layer 3 | appearance-based: compare remaining player crops with the
        # stored crops of the tracks not yet claimed in this frame.
        appearance_based_matches = 0
        new_detections = 0
        unmatched = 0
        for position, row in enumerate(unmatched_rows):
            is_last_row = position == len(unmatched_rows) - 1
            # One row and one track left, and no new player is possible:
            # they must be the same player.
            if self.all_players_detected and is_last_row and len(unmatched_track_ids) == 1:
                detections.tracker_id[row] = unmatched_track_ids.pop()
                break
            choice = self._appearance_match(frame, detections, row, unmatched_track_ids)
            if choice is not None:
                detections.tracker_id[row] = unmatched_track_ids.pop(choice)
                appearance_based_matches += 1
            elif not self.all_players_detected:
                # Still unmatched while players are missing: (likely) a new player.
                new_track_id = max(self.track_ids, default=-1) + 1
                detections.tracker_id[row] = new_track_id
                self.track_ids.append(new_track_id)
                self.all_players_detected = len(self.track_ids) >= self.num_players
                new_detections += 1
            else:
                unmatched += 1
        self.save_statistics(
            detections, xy, mask_based_matches, pos_based_matches,
            appearance_based_matches, new_detections, unmatched, frame=frame,
        )

    def _appearance_match(self, frame, detections, row, candidate_ids):
        """Return the index in ``candidate_ids`` the matcher picks for ``row``.

        Returns ``None`` when there is no candidate or when the matcher's
        custom argmax returns ``len(candidate_ids)`` (treated as "no match").
        """
        if not candidate_ids:
            return None
        query_crop = get_crops_from_masks(frame, detections[row].mask)[0]
        base_crops = [self.track_id_to_crop[t_id][0] for t_id in candidate_ids]
        latest_crops = [self.track_id_to_crop[t_id][1] for t_id in candidate_ids]
        # Average the matcher's output against both reference crops.
        probs = self.matcher.predict(query_crop, base_crops)
        probs = (probs + self.matcher.predict(query_crop, latest_crops)) / 2
        prediction = int(matcher_probs_custom_argmax(probs))
        return None if prediction == len(candidate_ids) else prediction

    def save_statistics(self, detections, xy, mask_based_matches, pos_based_matches=0, appearance_based_matches=0, new_detections=0, unmatched=0, frame=None):
        """Record this frame's counters, refresh latest crops and roll state.

        Args:
            detections: tracked detections of the current frame.
            xy: their court coordinates.
            mask_based_matches, pos_based_matches, appearance_based_matches,
            new_detections, unmatched: per-layer counters for this frame.
            frame: image ``detections`` were computed on. It is required to
                refresh each track's latest crop; when ``None`` the crops are
                left unchanged. (Previously this read an undefined name and
                only worked through a module-level ``frame`` global.)
        """
        self.frame_id += 1
        self.stats[self.frame_id] = {
            "detected_players": len(detections),
            "all_players_detected": self.all_players_detected,
            "mask_based_matches": mask_based_matches,
            "position_based_matches": pos_based_matches,
            "appearance_based_matches": appearance_based_matches,
            "new_detections": new_detections,
            "unmatched": unmatched,
        }
        if frame is not None:
            tracked = np.flatnonzero(detections.tracker_id != -1)
            if len(tracked):
                crops = get_crops_from_masks(frame, detections.mask[tracked])
                for track_id, crop in zip(detections.tracker_id[tracked], crops):
                    track_crops = self.track_id_to_crop[track_id]
                    if track_crops:
                        track_crops[1] = crop
                    else:
                        # Track created this frame: seed base and latest crop.
                        track_crops.extend([crop, crop])
        self.previous_detections = detections
        self.previous_xy = xy
if __name__ == "__main__":
    import os

    from basketball_analysis import Matcher
    from utils import show_annotations, annotate_frame
    from inference import get_model

    # The API key is a credential: read it from the environment rather than
    # committing it to source control. The key previously hard-coded here
    # should be treated as leaked and rotated. Checked first to fail fast,
    # before any model is loaded.
    ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
    if not ROBOFLOW_API_KEY:
        raise SystemExit("Set the ROBOFLOW_API_KEY environment variable to run this script.")

    VIDEO_PATH = "DEN_SAC_1_2025.mp4"
    MAX_PLAYERS = 10  # at most this many detections are kept per frame
    HUNGARIAN_MASK_THRESHOLD = 0.6
    HUNGARIAN_POS_THRESHOLD = 2.0
    SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
    SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
    SEG_MODEL.optimize_for_inference()
    KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
    KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
    KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')

    matcher = Matcher(10, 8, "DINOv2_small")
    sd = torch.load("matcher_tuned.pt")
    matcher.load_state_dict(sd)
    for p in matcher.parameters():
        p.requires_grad_(False)
    matcher.eval()

    def get_models_predictions(frame):
        """Segment players in ``frame`` and compute their court (x, y) coordinates."""
        # Segmentation
        detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
        keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
        detections = detections[keep]
        if len(detections) > MAX_PLAYERS:
            # Keep the MAX_PLAYERS highest-confidence detections. Sort
            # explicitly: nothing here guarantees the NMS output is ordered
            # by confidence, so taking the first N could drop better ones.
            top = np.argsort(-detections.confidence, kind="stable")[:MAX_PLAYERS]
            detections = detections[top]
        # X,Y coordinates retrieval
        court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)
        return detections, court_xy

    video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
    frame = toRGB(next(video_iterator))
    initial_detections, initial_xy = get_models_predictions(frame)
    history = []
    tracker = Tracker(initial_detections, initial_xy, frame, matcher, HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD)
    history.append(annotate_frame(frame, initial_detections))
    for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):
        frame = toRGB(frame)
        detections, xy = get_models_predictions(frame)
        tracker.update_tracks_with_new_detections(detections, xy, frame)
        history.append(annotate_frame(frame, detections))
        if frame_id == 150:
            # Debug checkpoint: dump the first and latest annotated frames,
            # then open an interactive console with the current locals.
            Image.fromarray(history[-1]).save("-1.png")
            Image.fromarray(history[0]).save("0.png")
            interact(local=locals())
|