poolay2 commited on
Commit
1e857bf
·
verified ·
1 Parent(s): 8235ecf

Upload tracking.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tracking.py +270 -0
tracking.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import supervision as sv
2
+ import torch
3
+ import numpy as np
4
+ from collections import defaultdict
5
+ from rfdetr import RFDETRSeg2XLarge
6
+ from PIL import Image
7
+ import cv2
8
+ from scipy.optimize import linear_sum_assignment
9
+ from .utils import (
10
+ mask_nms,
11
+ toRGB,
12
+ matcher_probs_custom_argmax,
13
+ get_distance_cost_matrix,
14
+ mask_iou,
15
+ get_crops_from_masks
16
+ )
17
+ from .view_transformer import (
18
+ get_players_court_xy
19
+ )
20
+ from tqdm import tqdm
21
+ from code import interact
22
+
23
+ np.set_printoptions(suppress=True, precision=4)
24
+ torch.set_printoptions(sci_mode=False)
25
+
26
+ def indices_to_matches(
27
+ cost_matrix, indices, thresh: float
28
+ ):
29
+ matched_cost = cost_matrix[tuple(zip(*indices))]
30
+ matched_mask = matched_cost <= thresh
31
+
32
+ matches = indices[matched_mask]
33
+ unmatched_a = list(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
34
+ unmatched_b = list(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
35
+ return matches, unmatched_a, unmatched_b
36
+
37
+ def linear_assignment(
38
+ cost_matrix, thresh
39
+ ):
40
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
41
+ indices = np.column_stack((row_ind, col_ind))
42
+
43
+ return indices_to_matches(cost_matrix, indices, thresh)
44
+
45
+ class Tracker:
46
+
47
+ def __init__(
48
+ self,
49
+ initial_detections:sv.Detections,
50
+ initial_xy: np.ndarray,
51
+ initial_frame: np.ndarray,
52
+ matcher,
53
+ hungarian_mask_threshold: float,
54
+ hungarian_pos_threshold: float
55
+ ):
56
+
57
+ self.frame_id = 0
58
+ self.track_ids = list(range(len(initial_detections)))
59
+ self.previous_detections = initial_detections
60
+ self.previous_xy = initial_xy
61
+ self.hungarian_mask_threshold = hungarian_mask_threshold
62
+ self.hungarian_pos_threshold = hungarian_pos_threshold
63
+ self.matcher = matcher
64
+
65
+ '''Initialize track_ids of all 10 players'''
66
+ self.all_players_detected = len(initial_detections) == 10
67
+ initial_detections.tracker_id = np.array(self.track_ids)
68
+ self.frame_id_to_xy = {
69
+ self.frame_id : dict(zip(initial_detections.tracker_id, initial_xy))
70
+ }
71
+
72
+ # Keep one "base selfie" and one "latest selfie" of all players in memory.
73
+ self.track_id_to_crop = defaultdict(list)
74
+ for track_id, crop in zip(initial_detections.tracker_id, get_crops_from_masks(initial_frame, initial_detections.mask)):
75
+ for _ in range(2):
76
+ self.track_id_to_crop[track_id].append(crop)
77
+
78
+ self.stats = {
79
+ self.frame_id : {
80
+ "detected_players" : len(initial_detections),
81
+ "new_detections" : None,
82
+ "all_players_detected" : self.all_players_detected,
83
+ "mask_based_matches" : None,
84
+ "position_based_matches" : None,
85
+ "appearance_based_matches" : None,
86
+ "unmatched" : None
87
+ }
88
+ }
89
+
90
+ def update_tracks_with_new_detections(self, detections: sv.Detections, xy: np.ndarray, frame: np.ndarray):
91
+
92
+ detections.tracker_id = -np.ones(shape=(len(detections)), dtype=np.int64)
93
+ masks = detections.mask
94
+
95
+ '''First Layer | Mask-based tracking:
96
+ Safely track players based on their masks coordinates. When in doubt, leave the detections untracked'''
97
+ # Cost_matrix_ij = 1 - IoU(mask_i, mask_j)
98
+ null_track = self.previous_detections.tracker_id == -1
99
+ mask_cost_matrix = 1.0 - mask_iou(masks, self.previous_detections[~null_track].mask)
100
+ matches, unmatched_rows_t, _ = linear_assignment(mask_cost_matrix, self.hungarian_mask_threshold)
101
+
102
+ # Apply results
103
+ detections.tracker_id[matches[:,0]] = self.previous_detections[~null_track].tracker_id[matches[:,1]]
104
+
105
+ # Remainder
106
+ unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
107
+ mask_based_matches = len(matches)
108
+
109
+ if len(unmatched_rows_t) == 0:
110
+ self.save_statistics(detections, xy, mask_based_matches)
111
+ return
112
+
113
+ '''Second Layer | Court-position-based tracking:
114
+ Safely track remaining un-matched player based on their court (x,y) coordinates.
115
+ '''
116
+ pos_based_matches = 0
117
+ dist_cost_matrix = get_distance_cost_matrix(
118
+ xy,
119
+ self.previous_xy[~null_track],
120
+ ord = 2, # EUCLIDIAN DISTANCE
121
+ )
122
+ dist_cost_matrix[matches[:,0], :] = 1e3
123
+ dist_cost_matrix[:, matches[:,1]] = 1e3
124
+
125
+ matches_, _, _ = linear_assignment(dist_cost_matrix, self.hungarian_pos_threshold)
126
+
127
+ # Apply results
128
+ for match_ in matches_:
129
+ if match_[0] in matches[:,0]:
130
+ continue
131
+ detections.tracker_id[match_[0]] = self.previous_detections[~null_track].tracker_id[match_[1]]
132
+ pos_based_matches += 1
133
+
134
+ # Remainder
135
+ unmatched_rows_t = [i for i in range(len(detections)) if detections.tracker_id[i] == -1]
136
+ unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
137
+
138
+ if len(unmatched_rows_t) == 0:
139
+ self.save_statistics(detections, xy, mask_based_matches, pos_based_matches)
140
+ return
141
+
142
+ '''Third Layer | Appearance-based tracking:
143
+ Use a vision model to match remaining player crops to their corresponding crop at t-1
144
+ '''
145
+
146
+ unmatched = 0
147
+ appearance_based_matches = 0
148
+ new_detections = 0
149
+
150
+ while len(unmatched_rows_t) > 0:
151
+
152
+ unmatched_row_t = unmatched_rows_t.pop(0)
153
+
154
+ # If there is only one un-matched mask at t-1 and t, they must correspond to the same player (assuming all players have been detected once, so there's no new player)
155
+ if self.all_players_detected and len(unmatched_track_ids_t_1) == 1 and len(unmatched_rows_t) == 0:
156
+ detections.tracker_id[unmatched_row_t] = unmatched_track_ids_t_1[0]
157
+ unmatched_track_ids_t_1.pop(0)
158
+ break
159
+
160
+ '''Appearance-based tracking: track remaining un-matched players'''
161
+ query_crop = get_crops_from_masks(frame, detections[unmatched_row_t].mask)[0] # Crop unmatched player at time t
162
+ base_candidate_crops = [self.track_id_to_crop[t_id][0] for t_id in unmatched_track_ids_t_1] # Previous crops of unmatched players
163
+ latest_candidate_crops = [self.track_id_to_crop[t_id][1] for t_id in unmatched_track_ids_t_1] # Previous crops of unmatched players
164
+
165
+ probs = self.matcher.predict(query_crop, base_candidate_crops)
166
+ probs = (probs + self.matcher.predict(query_crop, latest_candidate_crops)) / 2
167
+ prediction = matcher_probs_custom_argmax(probs)
168
+
169
+ if prediction != len(base_candidate_crops):
170
+ pred_track_id = unmatched_track_ids_t_1[prediction]
171
+ detections.tracker_id[unmatched_row_t] = pred_track_id
172
+
173
+ unmatched_track_ids_t_1.pop(prediction)
174
+ appearance_based_matches += 1
175
+
176
+ # still unmatched -> (likely) a new player
177
+ elif not(self.all_players_detected):
178
+ new_track_id = max(self.track_ids) + 1
179
+ detections.tracker_id[unmatched_row_t] = new_track_id
180
+
181
+ new_detections += 1
182
+ self.track_ids.append(new_track_id)
183
+ self.all_players_detected = len(self.track_ids) == 10
184
+
185
+ else:
186
+ unmatched += 1
187
+
188
+ self.save_statistics(detections, xy, mask_based_matches, pos_based_matches, appearance_based_matches, new_detections, unmatched)
189
+
190
+ def save_statistics(self, detections, xy, mask_based_matches, pos_based_matches=0, appearance_based_matches=0, new_detections=0, unmatched=0):
191
+ '''Update tracking statistics'''
192
+ self.frame_id += 1
193
+ self.stats[self.frame_id] = {
194
+ "detected_players" : len(detections),
195
+ "all_players_detected" : self.all_players_detected,
196
+ "mask_based_matches" : mask_based_matches,
197
+ "position_based_matches" : pos_based_matches,
198
+ "appearance_based_matches" : appearance_based_matches,
199
+ "new_detections" : new_detections,
200
+ "unmatched" : unmatched
201
+ }
202
+
203
+ for i in range(len(detections)):
204
+ track_id = detections.tracker_id[i]
205
+ if track_id != -1:
206
+ self.track_id_to_crop[track_id][1] = get_crops_from_masks(frame, detections[i].mask)[0]
207
+ self.previous_detections = detections
208
+ self.previous_xy = xy
209
+
210
+ if __name__ == "__main__":
211
+
212
+ from basketball_analysis import Matcher
213
+ from utils import show_annotations, annotate_frame
214
+ from inference import get_model
215
+
216
+ VIDEO_PATH = "DEN_SAC_1_2025.mp4"
217
+ HUNGARIAN_MASK_THRESHOLD = 0.6
218
+ HUNGARIAN_POS_THRESHOLD = 2.0
219
+
220
+ SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
221
+ SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
222
+ SEG_MODEL.optimize_for_inference()
223
+
224
+ ROBOFLOW_API_KEY = "PUNfWgLHrHDufisOOaZp"
225
+ KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
226
+ KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
227
+ KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')
228
+
229
+ matcher = Matcher(10,8, "DINOv2_small")
230
+ sd = torch.load("matcher_tuned.pt")
231
+ matcher.load_state_dict(sd)
232
+
233
+ for p in matcher.parameters():
234
+ p.requires_grad_(False)
235
+ matcher.eval();
236
+
237
+ def get_models_predictions(frame):
238
+
239
+ # Segmentation
240
+ detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
241
+ keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
242
+ detections = detections[keep]
243
+ if len(detections) > 10:
244
+ # keep first 10 detections (10 highest confidence detections)
245
+ detections = detections[:10]
246
+
247
+ # X,Y coordinates retrieval
248
+ court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)
249
+
250
+ return detections, court_xy
251
+
252
+ video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
253
+ frame = toRGB(next(video_iterator))
254
+ initial_detections, initial_xy = get_models_predictions(frame)
255
+
256
+ history = []
257
+ tracker = Tracker(initial_detections, initial_xy, frame, matcher, HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD)
258
+ history.append(annotate_frame(frame, initial_detections))
259
+
260
+ for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):
261
+
262
+ frame = toRGB(frame)
263
+ detections, xy = get_models_predictions(frame)
264
+ tracker.update_tracks_with_new_detections(detections, xy, frame)
265
+ history.append(annotate_frame(frame, detections))
266
+ if frame_id == 150:
267
+ Image.fromarray(history[-1]).save("-1.png")
268
+ Image.fromarray(history[0]).save("0.png")
269
+ interact(local=locals())
270
+