poolay2 committed on
Commit bbc0514 · verified · 1 Parent(s): 1c3b35d

Upload folder using huggingface_hub

basketball_analysis/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .matcherBeta import Matcher
+ from .tracking import Tracker
+ from .utils import (
+     get_crops_from_masks,
+     toRGB,
+     xywhn_to_xywh,
+     mask_nms,
+     mask_iou,
+     matcher_probs_custom_argmax,
+     show_annotations,
+     annotate_frame,
+     COURT_KEYPOINT_COORDINATES,
+ )
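
For orientation, the re-exports above define the package's public surface. A minimal usage sketch (assuming the package is importable and that a local DINOv2 weights folder exists; the "DINOv2_small" path is a placeholder):

    from basketball_analysis import Matcher, Tracker, COURT_KEYPOINT_COORDINATES

    matcher = Matcher(max_candidates=10, num_layers=8, dino_dir="DINOv2_small")  # placeholder weights dir
    print(COURT_KEYPOINT_COORDINATES.shape)  # (33, 2): court landmarks in feet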
basketball_analysis/matcherBeta.py ADDED
@@ -0,0 +1,197 @@
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ from transformers import Dinov2Model, Dinov2Config
+ from torchvision.transforms import v2
+ from code import interact
+ import json
+ import os
+ from PIL import Image
+ import numpy as np
+ from typing import Union
+
+ # Standard ImageNet preprocessing for the DINOv2 encoder.
+ transforms = v2.Compose([
+     v2.ToImage(),
+     v2.ToDtype(torch.float32, scale=True),
+     v2.Resize((224, 224)),
+     v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ class CrossAttention(nn.Module):
+
+     def __init__(self, d_model: int, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.Wq = nn.Linear(d_model, d_model)
+         self.Wk = nn.Linear(d_model, d_model)
+         self.Wv = nn.Linear(d_model, d_model)
+
+     def forward(self, queries, candidates):
+         # The candidates act as the attention queries here (the output is one
+         # vector per candidate) while the query tokens supply keys and values.
+         # Note that Wk projects the candidates and Wq the queries; the naming is
+         # unconventional but is kept as-is for compatibility with trained checkpoints.
+         Q = self.Wk(candidates)  # (B, num_candidates, d_model)
+         K = self.Wq(queries)     # (B, num_queries, d_model)
+         V = self.Wv(queries)     # (B, num_queries, d_model)
+         attn_out = F.scaled_dot_product_attention(Q, K, V)  # (B, num_candidates, d_model)
+
+         return attn_out
+
+ class JointTransformer(nn.Module):
+
+     def __init__(
+         self,
+         d_model=384,
+         nhead=4,
+         num_layers=4,
+         *args, **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+
+         # Transformer encoder
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model,
+             nhead=nhead,
+             dim_feedforward=4 * d_model,
+             batch_first=True,
+             dropout=0.0
+         )
+
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
+
+     def forward(self, query: Tensor, candidates: Tensor) -> tuple[Tensor, Tensor]:
+         Q = query.size(1)
+         assert Q == 1
+
+         x = torch.cat((query, candidates), dim=1)  # (B, Q+C, D)
+         x = self.transformer(x)                    # (B, Q+C, D)
+         query = x[:, :Q, :]                        # (B, Q, D)
+         candidates = x[:, Q:, :]                   # (B, C, D)
+
+         return query, candidates
+
+ class MLP(nn.Module):
+
+     def __init__(self, emb_dim, expand_factor, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.lin1 = nn.Linear(emb_dim, emb_dim * expand_factor)
+         self.gelu = nn.GELU("tanh")
+         self.lin2 = nn.Linear(emb_dim * expand_factor, emb_dim)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.lin1(x)
+         x = self.gelu(x)
+         x = self.lin2(x)
+         return x
+
+ class Matcher(nn.Module):
+     """Scores each candidate crop (plus a learned null candidate) against a query crop."""
+
+     def __init__(self, max_candidates, num_layers, dino_dir, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # -------------- Pre-trained encoder (frozen) -----------------
+         assert isinstance(dino_dir, str)
+         with open(os.path.join(dino_dir, "config.json"), "r") as f:
+             dino_cfg = json.load(f)
+
+         self.encoder = Dinov2Model.from_pretrained(dino_dir, config=Dinov2Config(**dino_cfg))
+         self.freeze_encoder()
+
+         # -------- Embeddings to distinguish queries and candidates ---------
+         self.query_image_embed = nn.Parameter(torch.randn(1, 1, dino_cfg["hidden_size"]))
+         self.candidates_image_embed = nn.Embedding(max_candidates, dino_cfg["hidden_size"])
+         self.null_candidate = nn.Parameter(torch.randn(1, 1, dino_cfg["hidden_size"]))  # null candidate embedding
+
+         # ---------------- Joint transformer (trained) ----------------------
+         self.max_candidates = max_candidates
+         self.num_layers = num_layers
+         self.joint_transformer = JointTransformer(
+             d_model=dino_cfg["hidden_size"],
+             nhead=dino_cfg["num_attention_heads"],
+             num_layers=num_layers,
+         )
+         self.lnormq = nn.LayerNorm(dino_cfg["hidden_size"])
+         self.lnormc = nn.LayerNorm(dino_cfg["hidden_size"])
+
+         # ------------------------ Final operation ---------------------------
+         self.cross_attn = CrossAttention(dino_cfg["hidden_size"])
+         self.lnormc2 = nn.LayerNorm(dino_cfg["hidden_size"])
+         self.classification_layer = nn.Linear(dino_cfg["hidden_size"], 1)
+
+     def freeze_encoder(self) -> None:
+         for p in self.encoder.parameters():
+             p.requires_grad_(False)
+
+     def pre_process_img(self, image: Union[Image.Image, np.ndarray, str]):
+         if isinstance(image, str):
+             image = Image.open(image)
+         return transforms(image)
+
+     @torch.inference_mode()
+     def predict(self, query_crop: np.ndarray, candidate_crops: list[np.ndarray]):
+         query = transforms(query_crop)[None, None, ...]  # (1, 1, 3, 224, 224)
+         candidates = torch.stack([transforms(candidate_crop) for candidate_crop in candidate_crops]).unsqueeze(0)  # (1, C, 3, 224, 224)
+         probs = self.forward(query, candidates).softmax(dim=-1)  # (1, C+1)
+         return probs.numpy()
+
+     def forward(self, query: Tensor, candidates: Tensor) -> Tensor:
+         # query (B, 1, 3, H, W), candidates (B, C, 3, H, W)
+         B, C, _, H, W = candidates.shape
+
+         query = self.encoder(
+             query.view(B, 3, H, W)
+         )['last_hidden_state']  # (B, T, D)
+
+         # pick the CLS token
+         query = query[:, 0, :].view(B, 1, -1)  # (B, 1, D)
+
+         candidates = self.encoder(
+             candidates.view(B * C, 3, H, W)
+         )['last_hidden_state']  # (B*C, T, D)
+
+         # pick the CLS token
+         candidates = candidates[:, 0, :].view(B, C, -1)  # (B, C, D)
+
+         # Add embeddings
+         query = query + self.query_image_embed.repeat(B, 1, 1)  # (B, 1, D)
+         candidate_ids = torch.arange(C, device=query.device).view(1, C)
+         candidates = candidates + self.candidates_image_embed(candidate_ids)  # (B, C, D)
+         candidates = torch.cat(
+             (
+                 candidates,
+                 self.null_candidate.repeat(B, 1, 1)
+             ),
+             dim=1)  # (B, C+1, D)
+
+         # Joint transformer: candidate and query tokens attend to each other
+         q, c = self.joint_transformer(query, candidates)
+         # skip connections
+         query = self.lnormq(query + q)
+         candidates = self.lnormc(candidates + c)
+
+         # Cross attention: each candidate attends to the query token
+         c = self.cross_attn(query, candidates)  # (B, C+1, D)
+         candidates = self.lnormc2(candidates + c)
+         candidates = candidates + c  # second residual after the norm, kept as trained
+         logits = self.classification_layer(candidates)  # (B, C+1, 1)
+
+         return logits.squeeze(-1)
+
+ if __name__ == "__main__":
+
+     import random
+
+     B, H, W = 1, 224, 224
+     max_candidates = 10
+     num_layers = 4
+
+     query = torch.randn((B, 1, 3, H, W))
+     candidates = torch.randn((B, random.randint(2, max_candidates), 3, H, W))
+
+     matcher = Matcher(max_candidates, num_layers, "DINOv2_base")
+     out = matcher(query, candidates)
+
+     interact(local=locals())
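
To make the inference contract concrete: Matcher.predict takes one query crop and a list of candidate crops (RGB uint8 arrays of any size; the transforms resize them to 224x224) and returns a (1, C+1) probability row whose last entry is the learned null candidate. A minimal sketch with random crops in place of real player crops ("DINOv2_base" stands for a local DINOv2 directory containing config.json and the weights):

    import numpy as np

    query_crop = np.random.randint(0, 255, (64, 32, 3), dtype=np.uint8)
    candidate_crops = [np.random.randint(0, 255, (64, 32, 3), dtype=np.uint8) for _ in range(3)]

    matcher = Matcher(max_candidates=10, num_layers=4, dino_dir="DINOv2_base")
    probs = matcher.predict(query_crop, candidate_crops)  # shape (1, 4): 3 candidates + 1 null
    print(probs.argmax(axis=-1))                          # index 3 would mean "none of the candidates"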
basketball_analysis/tracking.py ADDED
@@ -0,0 +1,270 @@
+ import supervision as sv
+ import torch
+ import numpy as np
+ from collections import defaultdict
+ from rfdetr import RFDETRSeg2XLarge
+ from PIL import Image
+ from scipy.optimize import linear_sum_assignment
+ from .utils import (
+     mask_nms,
+     toRGB,
+     matcher_probs_custom_argmax,
+     get_distance_cost_matrix,
+     mask_iou,
+     get_crops_from_masks
+ )
+ from .view_transformer import (
+     get_players_court_xy
+ )
+ from tqdm import tqdm
+ from code import interact
+
+ np.set_printoptions(suppress=True, precision=4)
+ torch.set_printoptions(sci_mode=False)
+
+ def indices_to_matches(
+     cost_matrix, indices, thresh: float
+ ):
+     matched_cost = cost_matrix[tuple(zip(*indices))]
+     matched_mask = matched_cost <= thresh
+
+     matches = indices[matched_mask]
+     unmatched_a = list(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
+     unmatched_b = list(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
+     return matches, unmatched_a, unmatched_b
+
+ def linear_assignment(
+     cost_matrix, thresh
+ ):
+     row_ind, col_ind = linear_sum_assignment(cost_matrix)
+     indices = np.column_stack((row_ind, col_ind))
+
+     return indices_to_matches(cost_matrix, indices, thresh)
+
+ class Tracker:
+
+     def __init__(
+         self,
+         initial_detections: sv.Detections,
+         initial_xy: np.ndarray,
+         initial_frame: np.ndarray,
+         matcher,
+         hungarian_mask_threshold: float,
+         hungarian_pos_threshold: float
+     ):
+
+         self.frame_id = 0
+         self.track_ids = list(range(len(initial_detections)))
+         self.previous_detections = initial_detections
+         self.previous_xy = initial_xy
+         self.hungarian_mask_threshold = hungarian_mask_threshold
+         self.hungarian_pos_threshold = hungarian_pos_threshold
+         self.matcher = matcher
+
+         '''Initialize the track_ids of the initially detected players (ideally all 10)'''
+         self.all_players_detected = len(initial_detections) == 10
+         initial_detections.tracker_id = np.array(self.track_ids)
+         self.frame_id_to_xy = {
+             self.frame_id: dict(zip(initial_detections.tracker_id, initial_xy))
+         }
+
+         # Keep one "base selfie" and one "latest selfie" of each player in memory.
+         self.track_id_to_crop = defaultdict(list)
+         for track_id, crop in zip(initial_detections.tracker_id, get_crops_from_masks(initial_frame, initial_detections.mask)):
+             for _ in range(2):
+                 self.track_id_to_crop[track_id].append(crop)
+
+         self.stats = {
+             self.frame_id: {
+                 "detected_players": len(initial_detections),
+                 "new_detections": None,
+                 "all_players_detected": self.all_players_detected,
+                 "mask_based_matches": None,
+                 "position_based_matches": None,
+                 "appearance_based_matches": None,
+                 "unmatched": None
+             }
+         }
+
+     def update_tracks_with_new_detections(self, detections: sv.Detections, xy: np.ndarray, frame: np.ndarray):
+
+         detections.tracker_id = -np.ones(shape=(len(detections)), dtype=np.int64)
+         masks = detections.mask
+
+         '''First layer | Mask-based tracking:
+         Safely track players based on the overlap of their masks. When in doubt, leave the detection untracked.'''
+         # cost_matrix_ij = 1 - IoU(mask_i, mask_j)
+         null_track = self.previous_detections.tracker_id == -1
+         mask_cost_matrix = 1.0 - mask_iou(masks, self.previous_detections[~null_track].mask)
+         matches, unmatched_rows_t, _ = linear_assignment(mask_cost_matrix, self.hungarian_mask_threshold)
+
+         # Apply the results
+         detections.tracker_id[matches[:, 0]] = self.previous_detections[~null_track].tracker_id[matches[:, 1]]
+
+         # Remainder
+         unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
+         mask_based_matches = len(matches)
+
+         if len(unmatched_rows_t) == 0:
+             self.save_statistics(detections, xy, frame, mask_based_matches)
+             return
+
+         '''Second layer | Court-position-based tracking:
+         Safely track the remaining unmatched players based on their court (x, y) coordinates.
+         '''
+         pos_based_matches = 0
+         dist_cost_matrix = get_distance_cost_matrix(
+             xy,
+             self.previous_xy[~null_track],
+             ord=2,  # Euclidean distance
+         )
+         # Exclude pairs already matched in the first layer
+         dist_cost_matrix[matches[:, 0], :] = 1e3
+         dist_cost_matrix[:, matches[:, 1]] = 1e3
+
+         matches_, _, _ = linear_assignment(dist_cost_matrix, self.hungarian_pos_threshold)
+
+         # Apply the results
+         for match_ in matches_:
+             if match_[0] in matches[:, 0]:
+                 continue
+             detections.tracker_id[match_[0]] = self.previous_detections[~null_track].tracker_id[match_[1]]
+             pos_based_matches += 1
+
+         # Remainder
+         unmatched_rows_t = [i for i in range(len(detections)) if detections.tracker_id[i] == -1]
+         unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
+
+         if len(unmatched_rows_t) == 0:
+             self.save_statistics(detections, xy, frame, mask_based_matches, pos_based_matches)
+             return
+
+         '''Third layer | Appearance-based tracking:
+         Use a vision model to match the remaining player crops to their corresponding crops at t-1.
+         '''
+
+         unmatched = 0
+         appearance_based_matches = 0
+         new_detections = 0
+
+         while len(unmatched_rows_t) > 0:
+
+             unmatched_row_t = unmatched_rows_t.pop(0)
+
+             # If there is exactly one unmatched mask at t-1 and at t, they must belong to the same player
+             # (assuming all players have been detected once, so no new player can appear)
+             if self.all_players_detected and len(unmatched_track_ids_t_1) == 1 and len(unmatched_rows_t) == 0:
+                 detections.tracker_id[unmatched_row_t] = unmatched_track_ids_t_1[0]
+                 unmatched_track_ids_t_1.pop(0)
+                 break
+
+             '''Appearance-based tracking: track the remaining unmatched players'''
+             if len(unmatched_track_ids_t_1) > 0:
+                 query_crop = get_crops_from_masks(frame, detections[unmatched_row_t].mask)[0]  # crop of the unmatched player at time t
+                 base_candidate_crops = [self.track_id_to_crop[t_id][0] for t_id in unmatched_track_ids_t_1]    # base crops of the unmatched tracks
+                 latest_candidate_crops = [self.track_id_to_crop[t_id][1] for t_id in unmatched_track_ids_t_1]  # latest crops of the unmatched tracks
+
+                 probs = self.matcher.predict(query_crop, base_candidate_crops)
+                 probs = (probs + self.matcher.predict(query_crop, latest_candidate_crops)) / 2
+                 prediction = matcher_probs_custom_argmax(probs)
+             else:
+                 # No candidate tracks left: force the "null candidate" outcome below
+                 base_candidate_crops = []
+                 prediction = 0
+
+             if prediction != len(base_candidate_crops):
+                 pred_track_id = unmatched_track_ids_t_1[prediction]
+                 detections.tracker_id[unmatched_row_t] = pred_track_id
+
+                 unmatched_track_ids_t_1.pop(prediction)
+                 appearance_based_matches += 1
+
+             # still unmatched -> (likely) a new player
+             elif not self.all_players_detected:
+                 new_track_id = max(self.track_ids) + 1
+                 detections.tracker_id[unmatched_row_t] = new_track_id
+
+                 new_detections += 1
+                 self.track_ids.append(new_track_id)
+                 self.all_players_detected = len(self.track_ids) == 10
+
+             else:
+                 unmatched += 1
+
+         self.save_statistics(detections, xy, frame, mask_based_matches, pos_based_matches, appearance_based_matches, new_detections, unmatched)
+
+     def save_statistics(self, detections, xy, frame, mask_based_matches, pos_based_matches=0, appearance_based_matches=0, new_detections=0, unmatched=0):
+         '''Update the tracking statistics and per-track crops, then roll the state forward'''
+         self.frame_id += 1
+         self.stats[self.frame_id] = {
+             "detected_players": len(detections),
+             "all_players_detected": self.all_players_detected,
+             "mask_based_matches": mask_based_matches,
+             "position_based_matches": pos_based_matches,
+             "appearance_based_matches": appearance_based_matches,
+             "new_detections": new_detections,
+             "unmatched": unmatched
+         }
+
+         for i in range(len(detections)):
+             track_id = detections.tracker_id[i]
+             if track_id != -1:
+                 crop = get_crops_from_masks(frame, detections[i].mask)[0]
+                 if len(self.track_id_to_crop[track_id]) < 2:
+                     # First appearance of this track: use the crop as both base and latest selfie
+                     self.track_id_to_crop[track_id] = [crop, crop]
+                 else:
+                     self.track_id_to_crop[track_id][1] = crop
+         self.previous_detections = detections
+         self.previous_xy = xy
+
+ if __name__ == "__main__":
+
+     from basketball_analysis import Matcher
+     from utils import show_annotations, annotate_frame
+     from inference import get_model
+
+     VIDEO_PATH = "DEN_SAC_1_2025.mp4"
+     HUNGARIAN_MASK_THRESHOLD = 0.6
+     HUNGARIAN_POS_THRESHOLD = 2.0
+
+     SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
+     SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
+     SEG_MODEL.optimize_for_inference()
+
+     ROBOFLOW_API_KEY = "PUNfWgLHrHDufisOOaZp"
+     KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
+     KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
+     KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')
+
+     matcher = Matcher(10, 8, "DINOv2_small")
+     sd = torch.load("matcher_tuned.pt")
+     matcher.load_state_dict(sd)
+
+     for p in matcher.parameters():
+         p.requires_grad_(False)
+     matcher.eval()
+
+     def get_models_predictions(frame):
+
+         # Segmentation
+         detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
+         keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
+         detections = detections[keep]
+         if len(detections) > 10:
+             # keep the 10 highest-confidence detections
+             detections = detections[:10]
+
+         # (x, y) court-coordinate retrieval
+         court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)
+
+         return detections, court_xy
+
+     video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
+     frame = toRGB(next(video_iterator))
+     initial_detections, initial_xy = get_models_predictions(frame)
+
+     history = []
+     tracker = Tracker(initial_detections, initial_xy, frame, matcher, HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD)
+     history.append(annotate_frame(frame, initial_detections))
+
+     for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):
+
+         frame = toRGB(frame)
+         detections, xy = get_models_predictions(frame)
+         tracker.update_tracks_with_new_detections(detections, xy, frame)
+         history.append(annotate_frame(frame, detections))
+         if frame_id == 150:
+             Image.fromarray(history[-1]).save("-1.png")
+             Image.fromarray(history[0]).save("0.png")
+             interact(local=locals())
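
To illustrate the thresholded Hungarian matching that all three tracking layers rely on, here is a toy run of linear_assignment (a sketch, assuming the function above is in scope): rows are detections at time t, columns are tracks at t-1, and any assigned pair whose cost exceeds the threshold is dropped back to unmatched.

    import numpy as np

    cost = np.array([
        [0.1, 0.9, 0.8],   # detection 0 clearly matches track 0
        [0.9, 0.2, 0.8],   # detection 1 clearly matches track 1
        [0.9, 0.9, 0.7],   # detection 2: its best cost (0.7) exceeds the threshold
    ])
    matches, unmatched_rows, unmatched_cols = linear_assignment(cost, thresh=0.6)
    print(matches)         # [[0 0] [1 1]]
    print(unmatched_rows)  # [2]
    print(unmatched_cols)  # [2]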
basketball_analysis/utils.py ADDED
@@ -0,0 +1,347 @@
+ from __future__ import annotations
+ import torch
+ import numpy as np
+ import supervision as sv
+ from pycocotools import mask as mask_utils
+ import cv2
+ import ffmpeg
+ from PIL import Image
+ from typing import Iterable
+ from matplotlib import pyplot as plt
+
+ class SAM2Tracker:
+     def __init__(self, predictor) -> None:
+         self.predictor = predictor
+         self._prompted = False
+
+     def prompt_first_frame(self, frame: np.ndarray, detections: sv.Detections) -> None:
+         if len(detections) == 0:
+             raise ValueError("detections must contain at least one box")
+
+         if detections.tracker_id is None:
+             detections.tracker_id = np.arange(1, len(detections) + 1)
+
+         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+             self.predictor.load_first_frame(frame)
+             for xyxy, obj_id in zip(detections.xyxy, detections.tracker_id):
+                 bbox = np.asarray([xyxy], dtype=np.float32)
+                 self.predictor.add_new_prompt(
+                     frame_idx=0,
+                     obj_id=int(obj_id),
+                     bbox=bbox,
+                 )
+
+         self._prompted = True
+
+     def propagate(self, frame: np.ndarray) -> sv.Detections:
+         if not self._prompted:
+             raise RuntimeError("Call prompt_first_frame before propagate")
+
+         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+             tracker_ids, mask_logits = self.predictor.track(frame)
+
+         tracker_ids = np.asarray(tracker_ids, dtype=np.int32)
+         masks = (mask_logits > 0.0).cpu().numpy()
+         masks = np.squeeze(masks).astype(bool)
+
+         if masks.ndim == 2:
+             masks = masks[None, ...]
+
+         masks = np.array([
+             sv.filter_segments_by_distance(mask, relative_distance=0.03, mode="edge")
+             for mask in masks
+         ])
+
+         xyxy = sv.mask_to_xyxy(masks=masks)
+         detections = sv.Detections(xyxy=xyxy, mask=masks, tracker_id=tracker_ids)
+         return detections
+
+     def reset(self) -> None:
+         self._prompted = False
+
+ def get_crops_from_masks(frame: np.ndarray, masks: np.ndarray) -> list[np.ndarray]:
+     """
+     Args:
+         frame: (H, W, 3) image
+         masks: (N, H, W) binary masks
+
+     Returns:
+         List of cropped images, one per mask. Each crop is a rectangular
+         bounding box around the mask, with black pixels outside the mask.
+     """
+     crops = []
+
+     for mask in masks:
+
+         # Find the bounding box of the mask
+         ys, xs = np.where(mask)
+         if len(xs) == 0 or len(ys) == 0:
+             # Empty mask → return an empty crop
+             crops.append(np.zeros((0, 0, 3), dtype=frame.dtype))
+             continue
+
+         y_min, y_max = ys.min(), ys.max() + 1
+         x_min, x_max = xs.min(), xs.max() + 1
+
+         # Crop the frame and the mask
+         frame_crop = frame[y_min:y_max, x_min:x_max]
+         mask_crop = mask[y_min:y_max, x_min:x_max]
+
+         # Apply the mask: keep pixels where the mask is True, else black
+         crop = np.zeros_like(frame_crop)
+         crop[mask_crop] = frame_crop[mask_crop]
+
+         crops.append(crop)
+
+     return crops
+
+ def f(detections: sv.Detections, track_history: dict, frame_index):
+     '''Append each detection's RLE-encoded mask to its track's history'''
+     for i in range(len(detections)):
+         mask = detections.mask[i]
+         rle = mask_utils.encode(np.asfortranarray(mask))
+         track_history[int(detections.tracker_id[i])].append((frame_index, rle['counts']))
+
+ def toRGB(img: np.ndarray):
+     return cv2.cvtColor(img, code=cv2.COLOR_BGR2RGB)
+
+ def read_frame_from_video(in_filename, frame_num):
+     raw_bytes, err = (
+         ffmpeg
+         .input(in_filename)
+         .filter('select', 'gte(n,{})'.format(frame_num))
+         .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
+         .global_args('-loglevel', 'error')
+         .run(capture_stdout=True)
+     )
+     # Assumes a 1920x1080 video
+     assert len(raw_bytes) == 1080 * 1920 * 3
+     return np.frombuffer(raw_bytes, np.uint8).reshape(1, 1080, 1920, 3).copy()
+
+ def read_consecutive_frames_from_video(in_filename, start_frame, num_frames) -> np.ndarray:
+
+     out, err = ffmpeg.input(in_filename)\
+         .output(
+             'pipe:1',
+             vf=f'select=between(n\\,{start_frame}\\,{start_frame + num_frames - 1})',
+             vsync=0,
+             vframes=num_frames,
+             format='rawvideo',
+             pix_fmt='rgb24'
+         ).global_args('-loglevel', 'error')\
+         .run(capture_stdout=True, capture_stderr=True)
+
+     W, H = 1920, 1080
+     frame_size = W * H * 3
+     frames = np.frombuffer(out, np.uint8)
+
+     if frames.size != num_frames * frame_size:
+         raise RuntimeError(
+             f'Expected {num_frames * frame_size} bytes, got {frames.size}\n'
+             f'ffmpeg stderr:\n{err.decode()}'
+         )
+
+     return frames.reshape(num_frames, H, W, 3).copy()
+
+ def xywhn_to_xywh(xywhn: list, height: int, width: int):
+     x, y, w, h = xywhn
+     return [int(x * width), int(y * height), int(w * width), int(h * height)]
+
+ def crop_frame_at_mask_from_bbox(frame: np.ndarray, mask: np.ndarray, bbox: list) -> np.ndarray:
+     x, y, w, h = bbox
+     # Copy so that blacking out the background does not modify the original frame
+     crop = frame[y: y + h, x: x + w].copy()
+     cropped_mask = mask[y: y + h, x: x + w]
+     crop[~cropped_mask] = np.array([0, 0, 0], dtype=np.uint8)
+     return crop
+
+ def find_consecutive_streaks(nums: list | Iterable):
+
+     nums = list(nums)
+     if not nums:
+         return []
+
+     streaks = []
+     start = nums[0]
+     for i in range(1, len(nums)):
+         if nums[i] != nums[i - 1] + 1:
+             stop = nums[i - 1]
+             streaks.append(range(start, stop + 1))
+             start = nums[i]
+
+     streaks.append(range(start, nums[-1] + 1))
+     return streaks
+
+ def save_loss_history(fpath, loss: float):
+     with open(fpath, "a+") as f:
+         f.write(f"{loss:.6f}\n")
+
+ def save_loss_history_plot(loss_history: list[float], fpath):
+     plt.plot(loss_history)
+     plt.savefig(fpath)
+
+ def save_checkpoint(
+     path,
+     model,
+     optimizer,
+     epoch,
+     step,
+ ):
+     ckpt = {
+         "model": model.state_dict(),
+         "optimizer": optimizer.state_dict(),
+         "epoch": epoch,
+         "step": step,
+     }
+     torch.save(ckpt, path)
+
+ def load_checkpoint(
+     path,
+     model,
+     optimizer,
+     device="cuda"
+ ):
+     ckpt = torch.load(path, map_location=device)
+
+     model.load_state_dict(ckpt["model"])
+     optimizer.load_state_dict(ckpt["optimizer"])
+
+     epoch = ckpt.get("epoch", 0)
+     step = ckpt.get("step", 0)
+
+     return epoch, step
+
+ def mask_iou_pair(m1, m2):
+     inter = np.logical_and(m1, m2).sum()
+     if inter == 0:
+         return 0.0
+     union = m1.sum() + m2.sum() - inter
+     return inter / (union + 1e-6)
+
+ def mask_nms(masks, scores, iou_thresh=0.6):
+     # Greedy NMS over masks, highest score first
+     order = np.argsort(-scores)
+     keep = []
+     suppressed = np.zeros(len(masks), dtype=bool)
+
+     for pos, i in enumerate(order):
+         if suppressed[i]:
+             continue
+
+         keep.append(i)
+
+         # Only lower-scored masks (later in `order`) can be suppressed by mask i
+         for j in order[pos + 1:]:
+             if suppressed[j]:
+                 continue
+
+             iou = mask_iou_pair(masks[i], masks[j])
+             if iou > iou_thresh:
+                 suppressed[j] = True
+
+     return keep
+
+ def mask_iou(masks_t: np.ndarray, masks_t1):
+     # Flatten
+     N, H, W = masks_t.shape
+     M = masks_t1.shape[0]
+
+     masks_t = masks_t.reshape(N, -1).astype(float)    # (N, HW)
+     masks_t1 = masks_t1.reshape(M, -1).astype(float)  # (M, HW)
+
+     # Intersection: (N, M)
+     intersection = masks_t @ masks_t1.T
+
+     # Areas
+     area_t = masks_t.sum(1, keepdims=True)    # (N, 1)
+     area_t1 = masks_t1.sum(1, keepdims=True)  # (M, 1)
+
+     # Union
+     union = area_t + area_t1.T - intersection
+
+     iou = intersection / (union + 1e-6)
+     return iou  # (N, M)
+
+ # NBA court landmarks in feet (94 ft x 50 ft court)
+ COURT_KEYPOINT_COORDINATES = np.array([
+     (0.0, 0.0),
+     (0.0, 2.99),
+     (0.0, 17.0),
+     (0.0, 33.01),
+     (0.0, 47.02),
+     (0.0, 50.0),
+     (5.25, 25.0),
+     (13.92, 2.99),
+     (13.92, 47.02),
+     (19.0, 17.0),
+     (19.0, 25.0),
+     (19.0, 33.01),
+     (27.4, 0.0),
+     (29.01, 25.0),
+     (27.4, 50.0),
+     (46.99, 0.0),
+     (46.99, 25.0),
+     (46.99, 50.0),
+     (66.61, 0.0),
+     (65.0, 25.0),
+     (66.61, 50.0),
+     (75.0, 17.0),
+     (75.0, 25.0),
+     (75.0, 33.01),
+     (80.09, 2.99),
+     (80.09, 47.02),
+     (88.75, 25.0),
+     (94.0, 0.0),
+     (94.0, 2.99),
+     (94.0, 17.0),
+     (94.0, 33.01),
+     (94.0, 47.02),
+     (94.0, 50.0)
+ ])
+
+ def get_distance_cost_matrix(arr1: np.ndarray, arr2: np.ndarray, ord=1):
+
+     cost_matrix = np.empty(shape=(len(arr1), len(arr2)), dtype=np.float64)
+
+     for i in range(len(arr1)):
+         cost_matrix[i] = np.linalg.norm(arr1[i] - arr2, ord=ord, axis=-1)
+
+     return cost_matrix
+
+ def matcher_probs_custom_argmax(probs: np.ndarray, confidence_threshold=0.7):
+     probs = probs.squeeze(0)
+     pred = probs.argmax()
+     # if the matcher predicts the null candidate but is not confident about it
+     if pred == len(probs) - 1 and probs[pred] < confidence_threshold:
+         # fall back to the best real candidate if it carries enough weight
+         second_best = probs[:-1].argmax()
+         if probs[second_best] > 1.0 - confidence_threshold - 0.05:
+             pred = second_best
+
+     return pred
+
+ def annotate_frame(frame_, detections_):
+     annotated_frame = frame_.copy()
+     annotated_frame = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK).annotate(annotated_frame, detections_)
+     annotated_frame = sv.LabelAnnotator(smart_position=True).annotate(annotated_frame, detections_, labels=list(str(i) for i in detections_.tracker_id))
+     return annotated_frame
+
+ def show_annotations(frame_, detections_):
+     return Image.fromarray(annotate_frame(frame_, detections_))
+
+ if __name__ == "__main__":
+     from code import interact
+     frames = read_consecutive_frames_from_video("nba_sample_videos/batch2/SAC_LAL_1.mp4", 199, 1)
+     interact(local=locals())
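
As a quick sanity check of the mask utilities above (a self-contained sketch with synthetic masks, assuming mask_iou and mask_nms are in scope):

    import numpy as np

    masks = np.zeros((3, 10, 10), dtype=bool)
    masks[0, 0:4, 0:4] = True   # mask 0
    masks[1, 0:4, 2:6] = True   # mask 1 overlaps mask 0 (IoU = 8/24 ≈ 0.33)
    masks[2, 6:9, 6:9] = True   # mask 2 is disjoint

    print(np.round(mask_iou(masks, masks), 2))                                 # ~1.0 diagonal, ~0.33 for the pair
    print(mask_nms(masks, scores=np.array([0.9, 0.8, 0.7]), iou_thresh=0.25))  # keeps masks 0 and 2, suppresses 1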
basketball_analysis/view_transformer.py ADDED
@@ -0,0 +1,61 @@
+ from sports import MeasurementUnit
+ from sports.basketball import CourtConfiguration, League, draw_court, draw_points_on_court
+ import numpy as np
+ import supervision as sv
+ import cv2
+
+ CONFIG = CourtConfiguration(league=League.NBA, measurement_unit=MeasurementUnit.FEET).vertices
+
+ def frame_xy_to_court_xy(frame_xy: np.ndarray, H: np.ndarray):
+     '''Map pixel coordinates to court coordinates using the homography H'''
+     assert frame_xy.shape[1] == 2
+     n_points = frame_xy.shape[0]
+
+     # Lift to homogeneous coordinates, apply H, then divide by the last component
+     court_xy = np.hstack((frame_xy, np.ones(shape=(n_points, 1)))) @ H.T
+     court_xy_norm = court_xy[:, :2] / court_xy[:, [-1]]
+     return court_xy_norm
+
+ def get_players_court_xy(frame, detections, model, use_bottom_center=True, normalize=False):
+     KEYPOINT_DETECTION_MODEL_CONFIDENCE = 0.3
+     KEYPOINT_DETECTION_MODEL_ANCHOR_CONFIDENCE = 0.5
+
+     # Locate the court keypoints (reference points)
+     result = model.infer(frame, confidence=KEYPOINT_DETECTION_MODEL_CONFIDENCE)[0]
+     key_points = sv.KeyPoints.from_inference(result)
+     filter_mask = key_points.confidence[0] > KEYPOINT_DETECTION_MODEL_ANCHOR_CONFIDENCE
+
+     # Compute the homography matrix H
+     court_landmarks = np.array(CONFIG)[filter_mask]
+     frame_landmarks = key_points[:, filter_mask].xy[0]
+     H, _ = cv2.findHomography(frame_landmarks, court_landmarks)
+
+     # From the player detections, retrieve their positions on the court
+     x1 = detections.xyxy[:, 0]
+     x2 = detections.xyxy[:, 2]
+     y1 = detections.xyxy[:, 1]
+     y2 = detections.xyxy[:, 3]
+     if use_bottom_center:
+         # Take the bottom center of the bounding box as the (x, y) coordinate
+         frame_xy = np.vstack(
+             (x1 + (x2 - x1) / 2, y2)
+         ).T
+     else:
+         frame_xy = np.vstack(
+             (x1 + (x2 - x1) / 2, y1 + (y2 - y1) / 2)
+         ).T
+     # Apply the homographic transformation
+     court_xy = frame_xy_to_court_xy(frame_xy, H)
+
+     if normalize:
+         # Normalize by the NBA court dimensions (94 ft x 50 ft)
+         court_xy = court_xy / np.array([94.0, 50.0])
+
+     return court_xy
+
+ def show_positions_on_court(court_xy):
+     court = draw_court(config=CONFIG)
+     court = draw_points_on_court(
+         config=CONFIG,
+         xy=court_xy,
+         court=court
+     )
+     sv.plot_image(court)
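
The projective step in frame_xy_to_court_xy is worth spelling out: a pixel (u, v) is lifted to (u, v, 1), multiplied by H, and divided by the resulting last component. A minimal sketch with a toy scaling homography, no keypoint model needed:

    import numpy as np

    H = np.diag([0.5, 0.5, 1.0])              # toy homography: scale pixel coordinates by 0.5
    frame_xy = np.array([[100.0, 200.0],
                         [300.0, 400.0]])
    print(frame_xy_to_court_xy(frame_xy, H))  # [[ 50. 100.] [150. 200.]]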