poolay2 commited on
Commit
6eb0dec
·
verified ·
1 Parent(s): 136d68f

Delete tracking.py

Browse files
Files changed (1) hide show
  1. tracking.py +0 -270
tracking.py DELETED
@@ -1,270 +0,0 @@
1
- import supervision as sv
2
- import torch
3
- import numpy as np
4
- from collections import defaultdict
5
- from rfdetr import RFDETRSeg2XLarge
6
- from PIL import Image
7
- import cv2
8
- from scipy.optimize import linear_sum_assignment
9
- from .utils import (
10
- mask_nms,
11
- toRGB,
12
- matcher_probs_custom_argmax,
13
- get_distance_cost_matrix,
14
- mask_iou,
15
- get_crops_from_masks
16
- )
17
- from .view_transformer import (
18
- get_players_court_xy
19
- )
20
- from tqdm import tqdm
21
- from code import interact
22
-
23
- np.set_printoptions(suppress=True, precision=4)
24
- torch.set_printoptions(sci_mode=False)
25
-
26
- def indices_to_matches(
27
- cost_matrix, indices, thresh: float
28
- ):
29
- matched_cost = cost_matrix[tuple(zip(*indices))]
30
- matched_mask = matched_cost <= thresh
31
-
32
- matches = indices[matched_mask]
33
- unmatched_a = list(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
34
- unmatched_b = list(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
35
- return matches, unmatched_a, unmatched_b
36
-
37
- def linear_assignment(
38
- cost_matrix, thresh
39
- ):
40
- row_ind, col_ind = linear_sum_assignment(cost_matrix)
41
- indices = np.column_stack((row_ind, col_ind))
42
-
43
- return indices_to_matches(cost_matrix, indices, thresh)
44
-
45
- class Tracker:
46
-
47
- def __init__(
48
- self,
49
- initial_detections:sv.Detections,
50
- initial_xy: np.ndarray,
51
- initial_frame: np.ndarray,
52
- matcher,
53
- hungarian_mask_threshold: float,
54
- hungarian_pos_threshold: float
55
- ):
56
-
57
- self.frame_id = 0
58
- self.track_ids = list(range(len(initial_detections)))
59
- self.previous_detections = initial_detections
60
- self.previous_xy = initial_xy
61
- self.hungarian_mask_threshold = hungarian_mask_threshold
62
- self.hungarian_pos_threshold = hungarian_pos_threshold
63
- self.matcher = matcher
64
-
65
- '''Initialize track_ids of all 10 players'''
66
- self.all_players_detected = len(initial_detections) == 10
67
- initial_detections.tracker_id = np.array(self.track_ids)
68
- self.frame_id_to_xy = {
69
- self.frame_id : dict(zip(initial_detections.tracker_id, initial_xy))
70
- }
71
-
72
- # Keep one "base selfie" and one "latest selfie" of all players in memory.
73
- self.track_id_to_crop = defaultdict(list)
74
- for track_id, crop in zip(initial_detections.tracker_id, get_crops_from_masks(initial_frame, initial_detections.mask)):
75
- for _ in range(2):
76
- self.track_id_to_crop[track_id].append(crop)
77
-
78
- self.stats = {
79
- self.frame_id : {
80
- "detected_players" : len(initial_detections),
81
- "new_detections" : None,
82
- "all_players_detected" : self.all_players_detected,
83
- "mask_based_matches" : None,
84
- "position_based_matches" : None,
85
- "appearance_based_matches" : None,
86
- "unmatched" : None
87
- }
88
- }
89
-
90
- def update_tracks_with_new_detections(self, detections: sv.Detections, xy: np.ndarray, frame: np.ndarray):
91
-
92
- detections.tracker_id = -np.ones(shape=(len(detections)), dtype=np.int64)
93
- masks = detections.mask
94
-
95
- '''First Layer | Mask-based tracking:
96
- Safely track players based on their masks coordinates. When in doubt, leave the detections untracked'''
97
- # Cost_matrix_ij = 1 - IoU(mask_i, mask_j)
98
- null_track = self.previous_detections.tracker_id == -1
99
- mask_cost_matrix = 1.0 - mask_iou(masks, self.previous_detections[~null_track].mask)
100
- matches, unmatched_rows_t, _ = linear_assignment(mask_cost_matrix, self.hungarian_mask_threshold)
101
-
102
- # Apply results
103
- detections.tracker_id[matches[:,0]] = self.previous_detections[~null_track].tracker_id[matches[:,1]]
104
-
105
- # Remainder
106
- unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
107
- mask_based_matches = len(matches)
108
-
109
- if len(unmatched_rows_t) == 0:
110
- self.save_statistics(detections, xy, mask_based_matches)
111
- return
112
-
113
- '''Second Layer | Court-position-based tracking:
114
- Safely track remaining un-matched player based on their court (x,y) coordinates.
115
- '''
116
- pos_based_matches = 0
117
- dist_cost_matrix = get_distance_cost_matrix(
118
- xy,
119
- self.previous_xy[~null_track],
120
- ord = 2, # EUCLIDIAN DISTANCE
121
- )
122
- dist_cost_matrix[matches[:,0], :] = 1e3
123
- dist_cost_matrix[:, matches[:,1]] = 1e3
124
-
125
- matches_, _, _ = linear_assignment(dist_cost_matrix, self.hungarian_pos_threshold)
126
-
127
- # Apply results
128
- for match_ in matches_:
129
- if match_[0] in matches[:,0]:
130
- continue
131
- detections.tracker_id[match_[0]] = self.previous_detections[~null_track].tracker_id[match_[1]]
132
- pos_based_matches += 1
133
-
134
- # Remainder
135
- unmatched_rows_t = [i for i in range(len(detections)) if detections.tracker_id[i] == -1]
136
- unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
137
-
138
- if len(unmatched_rows_t) == 0:
139
- self.save_statistics(detections, xy, mask_based_matches, pos_based_matches)
140
- return
141
-
142
- '''Third Layer | Appearance-based tracking:
143
- Use a vision model to match remaining player crops to their corresponding crop at t-1
144
- '''
145
-
146
- unmatched = 0
147
- appearance_based_matches = 0
148
- new_detections = 0
149
-
150
- while len(unmatched_rows_t) > 0:
151
-
152
- unmatched_row_t = unmatched_rows_t.pop(0)
153
-
154
- # If there is only one un-matched mask at t-1 and t, they must correspond to the same player (assuming all players have been detected once, so there's no new player)
155
- if self.all_players_detected and len(unmatched_track_ids_t_1) == 1 and len(unmatched_rows_t) == 0:
156
- detections.tracker_id[unmatched_row_t] = unmatched_track_ids_t_1[0]
157
- unmatched_track_ids_t_1.pop(0)
158
- break
159
-
160
- '''Appearance-based tracking: track remaining un-matched players'''
161
- query_crop = get_crops_from_masks(frame, detections[unmatched_row_t].mask)[0] # Crop unmatched player at time t
162
- base_candidate_crops = [self.track_id_to_crop[t_id][0] for t_id in unmatched_track_ids_t_1] # Previous crops of unmatched players
163
- latest_candidate_crops = [self.track_id_to_crop[t_id][1] for t_id in unmatched_track_ids_t_1] # Previous crops of unmatched players
164
-
165
- probs = self.matcher.predict(query_crop, base_candidate_crops)
166
- probs = (probs + self.matcher.predict(query_crop, latest_candidate_crops)) / 2
167
- prediction = matcher_probs_custom_argmax(probs)
168
-
169
- if prediction != len(base_candidate_crops):
170
- pred_track_id = unmatched_track_ids_t_1[prediction]
171
- detections.tracker_id[unmatched_row_t] = pred_track_id
172
-
173
- unmatched_track_ids_t_1.pop(prediction)
174
- appearance_based_matches += 1
175
-
176
- # still unmatched -> (likely) a new player
177
- elif not(self.all_players_detected):
178
- new_track_id = max(self.track_ids) + 1
179
- detections.tracker_id[unmatched_row_t] = new_track_id
180
-
181
- new_detections += 1
182
- self.track_ids.append(new_track_id)
183
- self.all_players_detected = len(self.track_ids) == 10
184
-
185
- else:
186
- unmatched += 1
187
-
188
- self.save_statistics(detections, xy, mask_based_matches, pos_based_matches, appearance_based_matches, new_detections, unmatched)
189
-
190
- def save_statistics(self, detections, xy, mask_based_matches, pos_based_matches=0, appearance_based_matches=0, new_detections=0, unmatched=0):
191
- '''Update tracking statistics'''
192
- self.frame_id += 1
193
- self.stats[self.frame_id] = {
194
- "detected_players" : len(detections),
195
- "all_players_detected" : self.all_players_detected,
196
- "mask_based_matches" : mask_based_matches,
197
- "position_based_matches" : pos_based_matches,
198
- "appearance_based_matches" : appearance_based_matches,
199
- "new_detections" : new_detections,
200
- "unmatched" : unmatched
201
- }
202
-
203
- for i in range(len(detections)):
204
- track_id = detections.tracker_id[i]
205
- if track_id != -1:
206
- self.track_id_to_crop[track_id][1] = get_crops_from_masks(frame, detections[i].mask)[0]
207
- self.previous_detections = detections
208
- self.previous_xy = xy
209
-
210
- if __name__ == "__main__":
211
-
212
- from basketball_analysis import Matcher
213
- from utils import show_annotations, annotate_frame
214
- from inference import get_model
215
-
216
- VIDEO_PATH = "DEN_SAC_1_2025.mp4"
217
- HUNGARIAN_MASK_THRESHOLD = 0.6
218
- HUNGARIAN_POS_THRESHOLD = 2.0
219
-
220
- SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
221
- SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
222
- SEG_MODEL.optimize_for_inference()
223
-
224
- ROBOFLOW_API_KEY = "PUNfWgLHrHDufisOOaZp"
225
- KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
226
- KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
227
- KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')
228
-
229
- matcher = Matcher(10,8, "DINOv2_small")
230
- sd = torch.load("matcher_tuned.pt")
231
- matcher.load_state_dict(sd)
232
-
233
- for p in matcher.parameters():
234
- p.requires_grad_(False)
235
- matcher.eval();
236
-
237
- def get_models_predictions(frame):
238
-
239
- # Segmentation
240
- detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
241
- keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
242
- detections = detections[keep]
243
- if len(detections) > 10:
244
- # keep first 10 detections (10 highest confidence detections)
245
- detections = detections[:10]
246
-
247
- # X,Y coordinates retrieval
248
- court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)
249
-
250
- return detections, court_xy
251
-
252
- video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
253
- frame = toRGB(next(video_iterator))
254
- initial_detections, initial_xy = get_models_predictions(frame)
255
-
256
- history = []
257
- tracker = Tracker(initial_detections, initial_xy, frame, matcher, HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD)
258
- history.append(annotate_frame(frame, initial_detections))
259
-
260
- for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):
261
-
262
- frame = toRGB(frame)
263
- detections, xy = get_models_predictions(frame)
264
- tracker.update_tracks_with_new_detections(detections, xy, frame)
265
- history.append(annotate_frame(frame, detections))
266
- if frame_id == 150:
267
- Image.fromarray(history[-1]).save("-1.png")
268
- Image.fromarray(history[0]).save("0.png")
269
- interact(local=locals())
270
-