Upload folder using huggingface_hub
- basketball_analysis/__init__.py +13 -0
- basketball_analysis/matcherBeta.py +197 -0
- basketball_analysis/tracking.py +270 -0
- basketball_analysis/utils.py +347 -0
- basketball_analysis/view_transformer.py +61 -0
basketball_analysis/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .matcherBeta import Matcher
from .tracking import Tracker
from .utils import (
    get_crops_from_masks,
    toRGB,
    xywhn_to_xywh,
    mask_nms,
    mask_iou,
    matcher_probs_custom_argmax,
    show_annotations,
    annotate_frame,
    COURT_KEYPOINT_COORDINATES,
)
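
For orientation, a minimal sketch of how these exports compose ("DINOv2_small" and "matcher_tuned.pt" are the asset names used in tracking.py below; point them at your local copies):

import torch
from basketball_analysis import Matcher

# Sketch only: assumes a local DINOv2 checkpoint directory and tuned matcher weights.
matcher = Matcher(max_candidates=10, num_layers=8, dino_dir="DINOv2_small")
matcher.load_state_dict(torch.load("matcher_tuned.pt"))
matcher.eval()
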
basketball_analysis/matcherBeta.py
ADDED
@@ -0,0 +1,197 @@
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from transformers import Dinov2Model, Dinov2Config
from torchvision.transforms import v2
from code import interact
import json
import os
from PIL import Image
import numpy as np
from typing import Union

# DINOv2 pre-processing: 224x224 crops, ImageNet normalization
transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize((224, 224)),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


class CrossAttention(nn.Module):

    def __init__(self, d_model: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)

    def forward(self, queries, candidates):
        # The candidate tokens attend to the query token(s): candidates supply the
        # attention queries, while the query image supplies keys and values.
        # (Note: the Wq/Wk projection names are swapped relative to these roles;
        # the wiring is kept as-is for compatibility with trained checkpoints.)
        Q = self.Wk(candidates)  # (B, num_candidates, d_model)
        K = self.Wq(queries)     # (B, num_queries, d_model)
        V = self.Wv(queries)     # (B, num_queries, d_model)
        attn_out = F.scaled_dot_product_attention(Q, K, V)  # (B, num_candidates, d_model)

        return attn_out


class JointTransformer(nn.Module):

    def __init__(
        self,
        d_model=384,
        nhead=4,
        num_layers=4,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            batch_first=True,
            dropout=0.0
        )

        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, query: Tensor, candidates: Tensor) -> tuple[Tensor, Tensor]:
        Q = query.size(1)
        assert Q == 1

        x = torch.cat((query, candidates), dim=1)  # (B, Q+C, D)
        x = self.transformer(x)                    # (B, Q+C, D)
        query = x[:, :Q, :]                        # (B, Q, D)
        candidates = x[:, Q:, :]                   # (B, C, D)

        return query, candidates


class MLP(nn.Module):

    def __init__(self, emb_dim, expand_factor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lin1 = nn.Linear(emb_dim, emb_dim * expand_factor)
        self.gelu = nn.GELU("tanh")
        self.lin2 = nn.Linear(emb_dim * expand_factor, emb_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.lin1(x)
        x = self.gelu(x)
        x = self.lin2(x)
        return x


class Matcher(nn.Module):

    def __init__(self, max_candidates, num_layers, dino_dir, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # -------------- Pre-trained encoder (frozen) -----------------
        assert isinstance(dino_dir, str)
        with open(os.path.join(dino_dir, "config.json"), "r") as f:
            dino_cfg = json.load(f)

        self.encoder = Dinov2Model.from_pretrained(dino_dir, config=Dinov2Config(**dino_cfg))
        self.freeze_encoder()

        # ----------------- Embeddings to distinguish queries and candidates ---------------------
        self.query_image_embed = nn.Parameter(torch.randn(1, 1, dino_cfg["hidden_size"]))
        self.candidates_image_embed = nn.Embedding(max_candidates, dino_cfg["hidden_size"])
        self.null_candidate = nn.Parameter(torch.randn(1, 1, dino_cfg["hidden_size"]))  # null-candidate embedding

        # ---------------- Joint transformer (trained) ----------------------
        self.max_candidates = max_candidates
        self.num_layers = num_layers
        self.joint_transformer = JointTransformer(
            d_model=dino_cfg["hidden_size"],
            nhead=dino_cfg["num_attention_heads"],
            num_layers=num_layers,
        )
        self.lnormq = nn.LayerNorm(dino_cfg["hidden_size"])
        self.lnormc = nn.LayerNorm(dino_cfg["hidden_size"])

        # ------------------------ Final operation ---------------------------
        self.cross_attn = CrossAttention(dino_cfg["hidden_size"])
        self.lnormc2 = nn.LayerNorm(dino_cfg["hidden_size"])
        self.classification_layer = nn.Linear(dino_cfg["hidden_size"], 1)

    def freeze_encoder(self) -> None:
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    def pre_process_img(self, image: Union[Image.Image, np.ndarray, str]):

        if isinstance(image, str):
            image = Image.open(image)

        return transforms(image)

    @torch.inference_mode()
    def predict(self, query_crop: np.ndarray, candidate_crops: list[np.ndarray]):

        query = transforms(query_crop)[None, None, ...]  # (1, 1, 3, 224, 224)
        candidates = torch.stack([transforms(candidate_crop) for candidate_crop in candidate_crops]).unsqueeze(0)
        probs = self.forward(query, candidates).softmax(dim=-1)

        return probs.numpy()

    def forward(self, query: Tensor, candidates: Tensor) -> Tensor:
        # query (B, 1, 3, H, W), candidates (B, C, 3, H, W)
        B, C, _, H, W = candidates.shape

        query = self.encoder(
            query.view(B, 3, H, W)
        )['last_hidden_state']  # (B, T, D)

        # pick the CLS token
        query = query[:, 0, :].view(B, 1, -1)  # (B, 1, D)

        candidates = self.encoder(
            candidates.view(B * C, 3, H, W)
        )['last_hidden_state']  # (B*C, T, D)

        # pick the CLS token
        candidates = candidates[:, 0, :].view(B, C, -1)  # (B, C, D)

        # Add embeddings
        query = query + self.query_image_embed.repeat(B, 1, 1)  # (B, 1, D)
        candidate_ids = torch.arange(C, device=query.device).view(1, C)
        candidates = candidates + self.candidates_image_embed(candidate_ids)  # (B, C, D)
        candidates = torch.cat(
            (
                candidates,
                self.null_candidate.repeat(B, 1, 1)
            ),
            dim=1)  # (B, C+1, D)

        # Joint transformer: candidate and query tokens attend to each other
        q, c = self.joint_transformer(query, candidates)
        # skip connections
        query = self.lnormq(query + q)
        candidates = self.lnormc(candidates + c)

        # Cross attention: the candidates attend to the query
        c = self.cross_attn(query, candidates)  # (B, C+1, D)
        candidates = self.lnormc2(candidates + c)
        candidates = candidates + c  # second residual, applied again after the layer norm
        logits = self.classification_layer(candidates)  # (B, C+1, 1)

        return logits.squeeze(-1)


if __name__ == "__main__":

    import random

    B, H, W = 1, 224, 224
    max_candidates = 10
    num_layers = 4

    query = torch.randn((B, 1, 3, H, W))
    candidates = torch.randn((B, random.randint(2, max_candidates), 3, H, W))

    matcher = Matcher(max_candidates, num_layers, "DINOv2_base")
    out = matcher(query, candidates)

    interact(local=locals())
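
A minimal standalone sketch of one matching call (assuming the "DINOv2_base" checkpoint directory from the __main__ block above is available locally; the crops are random noise, so only the shapes are meaningful):

import numpy as np

matcher = Matcher(max_candidates=10, num_layers=4, dino_dir="DINOv2_base").eval()

query_crop = np.random.randint(0, 255, (120, 60, 3), dtype=np.uint8)  # one player crop, (H, W, 3) uint8
candidate_crops = [np.random.randint(0, 255, (110, 55, 3), dtype=np.uint8) for _ in range(5)]

probs = matcher.predict(query_crop, candidate_crops)
print(probs.shape)  # (1, 6): 5 candidates plus the trailing null candidate
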
basketball_analysis/tracking.py
ADDED
@@ -0,0 +1,270 @@
import supervision as sv
import torch
import numpy as np
from collections import defaultdict
from rfdetr import RFDETRSeg2XLarge
from PIL import Image
import cv2
from scipy.optimize import linear_sum_assignment
from .utils import (
    mask_nms,
    toRGB,
    matcher_probs_custom_argmax,
    get_distance_cost_matrix,
    mask_iou,
    get_crops_from_masks
)
from .view_transformer import (
    get_players_court_xy
)
from tqdm import tqdm
from code import interact

np.set_printoptions(suppress=True, precision=4)
torch.set_printoptions(sci_mode=False)


def indices_to_matches(
    cost_matrix, indices, thresh: float
):
    # Keep only the assignments whose cost is at or below the threshold.
    matched_cost = cost_matrix[tuple(zip(*indices))]
    matched_mask = matched_cost <= thresh

    matches = indices[matched_mask]
    unmatched_a = list(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
    unmatched_b = list(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
    return matches, unmatched_a, unmatched_b


def linear_assignment(
    cost_matrix, thresh
):
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    indices = np.column_stack((row_ind, col_ind))

    return indices_to_matches(cost_matrix, indices, thresh)


class Tracker:

    def __init__(
        self,
        initial_detections: sv.Detections,
        initial_xy: np.ndarray,
        initial_frame: np.ndarray,
        matcher,
        hungarian_mask_threshold: float,
        hungarian_pos_threshold: float
    ):

        self.frame_id = 0
        self.track_ids = list(range(len(initial_detections)))
        self.previous_detections = initial_detections
        self.previous_xy = initial_xy
        self.hungarian_mask_threshold = hungarian_mask_threshold
        self.hungarian_pos_threshold = hungarian_pos_threshold
        self.matcher = matcher

        '''Initialize track_ids of all 10 players'''
        self.all_players_detected = len(initial_detections) == 10
        initial_detections.tracker_id = np.array(self.track_ids)
        self.frame_id_to_xy = {
            self.frame_id: dict(zip(initial_detections.tracker_id, initial_xy))
        }

        # Keep one "base selfie" and one "latest selfie" of each player in memory.
        self.track_id_to_crop = defaultdict(list)
        for track_id, crop in zip(initial_detections.tracker_id, get_crops_from_masks(initial_frame, initial_detections.mask)):
            for _ in range(2):
                self.track_id_to_crop[track_id].append(crop)

        self.stats = {
            self.frame_id: {
                "detected_players": len(initial_detections),
                "new_detections": None,
                "all_players_detected": self.all_players_detected,
                "mask_based_matches": None,
                "position_based_matches": None,
                "appearance_based_matches": None,
                "unmatched": None
            }
        }

    def update_tracks_with_new_detections(self, detections: sv.Detections, xy: np.ndarray, frame: np.ndarray):

        detections.tracker_id = -np.ones(shape=(len(detections)), dtype=np.int64)
        masks = detections.mask

        '''First Layer | Mask-based tracking:
        Safely track players based on their mask coordinates. When in doubt, leave the detections untracked.'''
        # cost_matrix[i, j] = 1 - IoU(mask_i, mask_j)
        null_track = self.previous_detections.tracker_id == -1
        mask_cost_matrix = 1.0 - mask_iou(masks, self.previous_detections[~null_track].mask)
        matches, unmatched_rows_t, _ = linear_assignment(mask_cost_matrix, self.hungarian_mask_threshold)

        # Apply results
        detections.tracker_id[matches[:, 0]] = self.previous_detections[~null_track].tracker_id[matches[:, 1]]

        # Remainder
        unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))
        mask_based_matches = len(matches)

        if len(unmatched_rows_t) == 0:
            self.save_statistics(detections, xy, frame, mask_based_matches)
            return

        '''Second Layer | Court-position-based tracking:
        Safely track the remaining unmatched players based on their court (x, y) coordinates.
        '''
        pos_based_matches = 0
        dist_cost_matrix = get_distance_cost_matrix(
            xy,
            self.previous_xy[~null_track],
            ord=2,  # Euclidean distance
        )
        # Block rows/columns that were already matched in the first layer
        dist_cost_matrix[matches[:, 0], :] = 1e3
        dist_cost_matrix[:, matches[:, 1]] = 1e3

        matches_, _, _ = linear_assignment(dist_cost_matrix, self.hungarian_pos_threshold)

        # Apply results
        for match_ in matches_:
            if match_[0] in matches[:, 0]:
                continue
            detections.tracker_id[match_[0]] = self.previous_detections[~null_track].tracker_id[match_[1]]
            pos_based_matches += 1

        # Remainder
        unmatched_rows_t = [i for i in range(len(detections)) if detections.tracker_id[i] == -1]
        unmatched_track_ids_t_1 = list(set(self.track_ids) - set(detections.tracker_id[detections.tracker_id != -1]))

        if len(unmatched_rows_t) == 0:
            self.save_statistics(detections, xy, frame, mask_based_matches, pos_based_matches)
            return

        '''Third Layer | Appearance-based tracking:
        Use a vision model to match the remaining player crops to their corresponding crops at t-1.
        '''

        unmatched = 0
        appearance_based_matches = 0
        new_detections = 0

        while len(unmatched_rows_t) > 0:

            unmatched_row_t = unmatched_rows_t.pop(0)

            # If there is exactly one unmatched mask at t-1 and at t, they must correspond to the
            # same player (assuming all players have been detected once, so there is no new player).
            if self.all_players_detected and len(unmatched_track_ids_t_1) == 1 and len(unmatched_rows_t) == 0:
                detections.tracker_id[unmatched_row_t] = unmatched_track_ids_t_1[0]
                unmatched_track_ids_t_1.pop(0)
                break

            '''Appearance-based tracking: track the remaining unmatched players'''
            query_crop = get_crops_from_masks(frame, detections[unmatched_row_t].mask)[0]  # crop of the unmatched player at time t
            base_candidate_crops = [self.track_id_to_crop[t_id][0] for t_id in unmatched_track_ids_t_1]    # base crops of unmatched players
            latest_candidate_crops = [self.track_id_to_crop[t_id][1] for t_id in unmatched_track_ids_t_1]  # latest crops of unmatched players

            probs = self.matcher.predict(query_crop, base_candidate_crops)
            probs = (probs + self.matcher.predict(query_crop, latest_candidate_crops)) / 2
            prediction = matcher_probs_custom_argmax(probs)

            if prediction != len(base_candidate_crops):
                pred_track_id = unmatched_track_ids_t_1[prediction]
                detections.tracker_id[unmatched_row_t] = pred_track_id

                unmatched_track_ids_t_1.pop(prediction)
                appearance_based_matches += 1

            # still unmatched -> (likely) a new player
            elif not self.all_players_detected:
                new_track_id = max(self.track_ids) + 1
                detections.tracker_id[unmatched_row_t] = new_track_id

                new_detections += 1
                self.track_ids.append(new_track_id)
                self.all_players_detected = len(self.track_ids) == 10

            else:
                unmatched += 1

        self.save_statistics(detections, xy, frame, mask_based_matches, pos_based_matches, appearance_based_matches, new_detections, unmatched)

    def save_statistics(self, detections, xy, frame, mask_based_matches, pos_based_matches=0, appearance_based_matches=0, new_detections=0, unmatched=0):
        '''Update tracking statistics'''
        self.frame_id += 1
        self.stats[self.frame_id] = {
            "detected_players": len(detections),
            "all_players_detected": self.all_players_detected,
            "mask_based_matches": mask_based_matches,
            "position_based_matches": pos_based_matches,
            "appearance_based_matches": appearance_based_matches,
            "new_detections": new_detections,
            "unmatched": unmatched
        }

        for i in range(len(detections)):
            track_id = detections.tracker_id[i]
            if track_id != -1:
                # refresh the "latest selfie" of each tracked player
                self.track_id_to_crop[track_id][1] = get_crops_from_masks(frame, detections[i].mask)[0]
        self.previous_detections = detections
        self.previous_xy = xy


if __name__ == "__main__":

    from basketball_analysis import Matcher
    from utils import show_annotations, annotate_frame
    from inference import get_model

    VIDEO_PATH = "DEN_SAC_1_2025.mp4"
    HUNGARIAN_MASK_THRESHOLD = 0.6
    HUNGARIAN_POS_THRESHOLD = 2.0

    SEGMENTATION_CONFIDENCE_THRESHOLD = 0.4
    SEG_MODEL = RFDETRSeg2XLarge(resolution=1008, pretrain_weights="checkpoint_best_ema.pth")
    SEG_MODEL.optimize_for_inference()

    ROBOFLOW_API_KEY = "PUNfWgLHrHDufisOOaZp"
    KEYPOINT_DETECTION_MODEL_ID = "basketball-court-detection-2/14"
    KEYPOINT_MODEL = get_model(model_id=KEYPOINT_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)
    KEYPOINT_COLOR = sv.Color.from_hex('#FF1493')

    matcher = Matcher(10, 8, "DINOv2_small")
    sd = torch.load("matcher_tuned.pt")
    matcher.load_state_dict(sd)

    for p in matcher.parameters():
        p.requires_grad_(False)
    matcher.eval()

    def get_models_predictions(frame):

        # Segmentation
        detections = SEG_MODEL.predict(frame, threshold=SEGMENTATION_CONFIDENCE_THRESHOLD)
        keep = mask_nms(detections.mask, detections.confidence, iou_thresh=0.2)
        detections = detections[keep]
        if len(detections) > 10:
            # keep the first 10 detections (the 10 highest-confidence detections)
            detections = detections[:10]

        # (x, y) coordinate retrieval
        court_xy = get_players_court_xy(frame, detections, KEYPOINT_MODEL)

        return detections, court_xy

    video_iterator = sv.get_video_frames_generator(VIDEO_PATH)
    frame = toRGB(next(video_iterator))
    initial_detections, initial_xy = get_models_predictions(frame)

    history = []
    tracker = Tracker(initial_detections, initial_xy, frame, matcher, HUNGARIAN_MASK_THRESHOLD, HUNGARIAN_POS_THRESHOLD)
    history.append(annotate_frame(frame, initial_detections))

    for frame_id, frame in tqdm(enumerate(video_iterator, start=1)):

        frame = toRGB(frame)
        detections, xy = get_models_predictions(frame)
        tracker.update_tracks_with_new_detections(detections, xy, frame)
        history.append(annotate_frame(frame, detections))
        if frame_id == 150:
            Image.fromarray(history[-1]).save("-1.png")
            Image.fromarray(history[0]).save("0.png")
            interact(local=locals())
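
A small self-contained illustration of the thresholded Hungarian matching used by the first two layers: linear_sum_assignment finds the globally cheapest pairing, and the gate in indices_to_matches then drops any assignment whose cost exceeds the threshold.

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([
    [0.9, 0.2],  # detection 0 is cheap to pair with previous track 1
    [0.3, 0.8],  # detection 1 is cheap to pair with previous track 0
])
rows, cols = linear_sum_assignment(cost)  # optimal pairing: (0, 1) and (1, 0)
indices = np.column_stack((rows, cols))
keep = cost[rows, cols] <= 0.6            # the same gate indices_to_matches applies
print(indices[keep])                      # [[0 1] [1 0]] -- both pairs pass the 0.6 threshold
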
basketball_analysis/utils.py
ADDED
@@ -0,0 +1,347 @@
from __future__ import annotations
import torch
import numpy as np
import supervision as sv
from pycocotools import mask as mask_utils
import cv2
import ffmpeg
from PIL import Image
from typing import List, Iterable
from matplotlib import pyplot as plt


class SAM2Tracker:
    def __init__(self, predictor) -> None:
        self.predictor = predictor
        self._prompted = False

    def prompt_first_frame(self, frame: np.ndarray, detections: sv.Detections) -> None:
        if len(detections) == 0:
            raise ValueError("detections must contain at least one box")

        if detections.tracker_id is None:
            detections.tracker_id = list(range(1, len(detections) + 1))

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            self.predictor.load_first_frame(frame)
            for xyxy, obj_id in zip(detections.xyxy, detections.tracker_id):
                bbox = np.asarray([xyxy], dtype=np.float32)
                self.predictor.add_new_prompt(
                    frame_idx=0,
                    obj_id=int(obj_id),
                    bbox=bbox,
                )

        self._prompted = True

    def propagate(self, frame: np.ndarray) -> sv.Detections:
        if not self._prompted:
            raise RuntimeError("Call prompt_first_frame before propagate")

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            tracker_ids, mask_logits = self.predictor.track(frame)

        tracker_ids = np.asarray(tracker_ids, dtype=np.int32)
        masks = (mask_logits > 0.0).cpu().numpy()
        masks = np.squeeze(masks).astype(bool)

        if masks.ndim == 2:
            masks = masks[None, ...]

        masks = np.array([
            sv.filter_segments_by_distance(mask, relative_distance=0.03, mode="edge")
            for mask in masks
        ])

        xyxy = sv.mask_to_xyxy(masks=masks)
        detections = sv.Detections(xyxy=xyxy, mask=masks, tracker_id=tracker_ids)
        return detections

    def reset(self) -> None:
        self._prompted = False


def get_crops_from_masks(frame: np.ndarray, masks: np.ndarray) -> list[np.ndarray]:
    """
    Args:
        frame: (H, W, 3) image
        masks: (N, H, W) binary masks

    Returns:
        List of cropped images, one per mask. Each crop is a rectangular
        bounding box around the mask, with black pixels outside the mask.
    """
    crops = []

    for mask in masks:

        # Find the bounding box of the mask
        ys, xs = np.where(mask)
        if len(xs) == 0 or len(ys) == 0:
            # Empty mask -> return an empty crop
            crops.append(np.zeros((0, 0, 3), dtype=frame.dtype))
            continue

        y_min, y_max = ys.min(), ys.max() + 1
        x_min, x_max = xs.min(), xs.max() + 1

        # Crop the frame and the mask
        frame_crop = frame[y_min:y_max, x_min:x_max]
        mask_crop = mask[y_min:y_max, x_min:x_max]

        # Apply the mask: keep pixels where the mask is True, else black
        crop = np.zeros_like(frame_crop)
        crop[mask_crop] = frame_crop[mask_crop]

        crops.append(crop)

    return crops


def f(detections: sv.Detections, track_history: dict, frame_index):
    # Append the RLE-encoded mask of every detection to its track's history.
    for i in range(len(detections)):

        mask = detections.mask[i]
        rle = mask_utils.encode(np.asfortranarray(mask))
        track_history[int(detections.tracker_id[i])].append((frame_index, rle['counts']))


def toRGB(img: np.ndarray):
    return cv2.cvtColor(img, code=cv2.COLOR_BGR2RGB)


def read_frame_from_video(in_filename, frame_num):
    raw_bytes, err = (
        ffmpeg
        .input(in_filename)
        .filter('select', 'gte(n,{})'.format(frame_num))
        .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
        .global_args('-loglevel', 'error')
        .run(capture_stdout=True)
    )
    assert len(raw_bytes) == 1080 * 1920 * 3  # expects 1080p input
    return np.frombuffer(raw_bytes, np.uint8).reshape(1, 1080, 1920, 3).copy()


def read_consecutive_frames_from_video(in_filename, start_frame, num_frames) -> np.ndarray:

    out, err = ffmpeg.input(in_filename)\
        .output(
            'pipe:1',
            vf=f'select=between(n\\,{start_frame}\\,{start_frame + num_frames - 1})',
            vsync=0,
            vframes=num_frames,
            format='rawvideo',
            pix_fmt='rgb24'
        ).global_args('-loglevel', 'error')\
        .run(capture_stdout=True, capture_stderr=True)

    W, H = 1920, 1080
    frame_size = W * H * 3
    frames = np.frombuffer(out, np.uint8)

    if frames.size != num_frames * frame_size:
        raise RuntimeError(
            f'Expected {num_frames * frame_size} bytes, got {frames.size}\n'
            f'ffmpeg stderr:\n{err.decode()}'
        )

    return frames.reshape(num_frames, H, W, 3).copy()


def xywhn_to_xywh(xywhn: list, height: int, width: int):

    x, y, w, h = xywhn

    return [int(x * width), int(y * height), int(w * width), int(h * height)]


def crop_frame_at_mask_from_bbox(frame: np.ndarray, mask: np.ndarray, bbox: list) -> np.ndarray:

    x, y, w, h = bbox
    crop = frame[y: y + h, x: x + w]
    cropped_mask = mask[y: y + h, x: x + w]
    crop[~cropped_mask] = np.array([0, 0, 0], dtype=np.uint8)

    return crop


def find_consecutive_streaks(nums: list | Iterable):

    nums = list(nums)
    if not nums:
        return []

    streaks = []
    start = nums[0]
    for i in range(1, len(nums)):
        if nums[i] != nums[i - 1] + 1:
            stop = nums[i - 1]
            streaks.append(range(start, stop + 1))
            start = nums[i]

    streaks.append(range(start, nums[-1] + 1))
    return streaks


def save_loss_history(fpath, loss: float):

    with open(fpath, "a+") as f:
        f.write(f"{loss:.6f}\n")


def save_loss_history_plot(loss_history: list[float], fpath):

    plt.plot(loss_history)
    plt.savefig(fpath)


def save_checkpoint(
    path,
    model,
    optimizer,
    epoch,
    step,
):

    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
    }
    torch.save(ckpt, path)


def load_checkpoint(
    path,
    model,
    optimizer,
    device="cuda"
):
    ckpt = torch.load(path, map_location=device)

    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])

    epoch = ckpt.get("epoch", 0)
    step = ckpt.get("step", 0)

    return epoch, step


def mask_iou_pair(m1, m2):
    inter = np.logical_and(m1, m2).sum()
    if inter == 0:
        return 0.0
    union = m1.sum() + m2.sum() - inter
    return inter / (union + 1e-6)


def mask_nms(masks, scores, iou_thresh=0.6):
    # Greedy NMS on masks: visit masks in descending score order and suppress
    # any later (lower-scored) mask whose IoU with a kept mask exceeds the threshold.
    order = np.argsort(-scores)
    keep = []
    suppressed = np.zeros(len(masks), dtype=bool)

    for pos, i in enumerate(order):
        if suppressed[i]:
            continue

        keep.append(i)

        for j in order[pos + 1:]:
            if suppressed[j]:
                continue

            iou = mask_iou_pair(masks[i], masks[j])
            if iou > iou_thresh:
                suppressed[j] = True

    return keep


def mask_iou(masks_t: np.ndarray, masks_t1):
    # Flatten
    N, H, W = masks_t.shape
    M = masks_t1.shape[0]

    masks_t = masks_t.reshape(N, -1).astype(float)    # (N, HW)
    masks_t1 = masks_t1.reshape(M, -1).astype(float)  # (M, HW)

    # Intersection: (N, M)
    intersection = masks_t @ masks_t1.T

    # Areas
    area_t = masks_t.sum(1, keepdims=True)    # (N, 1)
    area_t1 = masks_t1.sum(1, keepdims=True)  # (M, 1)

    # Union
    union = area_t + area_t1.T - intersection

    iou = intersection / (union + 1e-6)
    return iou  # (N, M)


COURT_KEYPOINT_COORDINATES = np.array([
    (0.0, 0.0),
    (0.0, 2.99),
    (0.0, 17.0),
    (0.0, 33.01),
    (0.0, 47.02),
    (0.0, 50.0),
    (5.25, 25.0),
    (13.92, 2.99),
    (13.92, 47.02),
    (19.0, 17.0),
    (19.0, 25.0),
    (19.0, 33.01),
    (27.4, 0.0),
    (29.01, 25.0),
    (27.4, 50.0),
    (46.99, 0.0),
    (46.99, 25.0),
    (46.99, 50.0),
    (66.61, 0.0),
    (65.0, 25.0),
    (66.61, 50.0),
    (75.0, 17.0),
    (75.0, 25.0),
    (75.0, 33.01),
    (80.09, 2.99),
    (80.09, 47.02),
    (88.75, 25.0),
    (94.0, 0.0),
    (94.0, 2.99),
    (94.0, 17.0),
    (94.0, 33.01),
    (94.0, 47.02),
    (94.0, 50.0)
])


def get_distance_cost_matrix(arr1: np.ndarray, arr2: np.ndarray, ord=1):

    cost_matrix = np.empty(shape=(len(arr1), len(arr2)), dtype=np.float64)

    for i in range(len(arr1)):
        cost_matrix[i] = np.linalg.norm(arr1[i] - arr2, ord=ord, axis=-1)

    return cost_matrix


def matcher_probs_custom_argmax(probs: np.ndarray, confidence_threshold=0.7):
    probs = probs.squeeze(0)
    pred = probs.argmax()
    # if the matcher predicts the null candidate but is not confident,
    if pred == len(probs) - 1 and probs[pred] < confidence_threshold:
        # fall back to the second-best prediction if it carries enough weight
        second_best = probs[:-1].argmax()
        if probs[second_best] > 1.0 - confidence_threshold - 0.05:
            pred = second_best

    return pred


def show_annotations(frame_, detections_):
    annotated_frame = frame_.copy()
    annotated_frame = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK).annotate(annotated_frame, detections_)
    annotated_frame = sv.LabelAnnotator(smart_position=True).annotate(annotated_frame, detections_, labels=list(str(i) for i in detections_.tracker_id))
    return Image.fromarray(annotated_frame)


def annotate_frame(frame_, detections_):
    annotated_frame = frame_.copy()
    annotated_frame = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK).annotate(annotated_frame, detections_)
    annotated_frame = sv.LabelAnnotator(smart_position=True).annotate(annotated_frame, detections_, labels=list(str(i) for i in detections_.tracker_id))
    return annotated_frame


if __name__ == "__main__":
    from code import interact
    frames = read_consecutive_frames_from_video("nba_sample_videos/batch2/SAC_LAL_1.mp4", 199, 1)
    interact(local=locals())
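
A toy check of the vectorized mask_iou above (numpy only, with mask_iou from this file in scope): two 2x2 squares that overlap in a single pixel give IoU = 1/7.

import numpy as np

a = np.zeros((1, 4, 4), dtype=bool); a[0, :2, :2] = True    # top-left 2x2 square, area 4
b = np.zeros((1, 4, 4), dtype=bool); b[0, 1:3, 1:3] = True  # same square shifted by one pixel, area 4
print(mask_iou(a, b))  # intersection 1, union 7 -> ~[[0.1429]]
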
basketball_analysis/view_transformer.py
ADDED
@@ -0,0 +1,61 @@
from sports import MeasurementUnit
from sports.basketball import CourtConfiguration, League, draw_court, draw_points_on_court
import numpy as np
import supervision as sv
import cv2

CONFIG = CourtConfiguration(league=League.NBA, measurement_unit=MeasurementUnit.FEET)


def frame_xy_to_court_xy(frame_xy: np.ndarray, H: np.ndarray):

    assert frame_xy.shape[1] == 2
    n_points = frame_xy.shape[0]

    # Apply the homography in homogeneous coordinates, then de-homogenize
    court_xy = np.hstack((frame_xy, np.ones(shape=(n_points, 1)))) @ H.T
    court_xy_norm = court_xy[:, :2] / court_xy[:, [-1]]
    return court_xy_norm


def get_players_court_xy(frame, detections, model, use_bottom_center=True, normalize=False):
    KEYPOINT_DETECTION_MODEL_CONFIDENCE = 0.3
    KEYPOINT_DETECTION_MODEL_ANCHOR_CONFIDENCE = 0.5

    # Locate court keypoints (reference points)
    result = model.infer(frame, confidence=KEYPOINT_DETECTION_MODEL_CONFIDENCE)[0]
    key_points = sv.KeyPoints.from_inference(result)
    filter_mask = key_points.confidence[0] > KEYPOINT_DETECTION_MODEL_ANCHOR_CONFIDENCE

    # Compute the homography matrix H
    court_landmarks = np.array(CONFIG.vertices)[filter_mask]
    frame_landmarks = key_points[:, filter_mask].xy[0]
    H, _ = cv2.findHomography(frame_landmarks, court_landmarks)

    # From the player detections, retrieve their positions on the court
    x1 = detections.xyxy[:, 0]
    x2 = detections.xyxy[:, 2]
    y1 = detections.xyxy[:, 1]
    y2 = detections.xyxy[:, 3]
    if use_bottom_center:
        # Take the bottom center of the bounding box as the (x, y) coordinate
        frame_xy = np.vstack(
            (x1 + (x2 - x1) / 2, y2)
        ).T
    else:
        # Take the center of the bounding box
        frame_xy = np.vstack(
            (x1 + (x2 - x1) / 2, y1 + (y2 - y1) / 2)
        ).T
    # Apply the homographic transformation
    court_xy = frame_xy_to_court_xy(frame_xy, H)

    if normalize:
        # NBA court dimensions in feet: 94 x 50
        court_xy = court_xy / np.array([94.0, 50.0])

    return court_xy


def show_positions_on_court(court_xy):
    court = draw_court(config=CONFIG)
    court = draw_points_on_court(
        config=CONFIG,
        xy=court_xy,
        court=court
    )
    sv.plot_image(court)
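
As a sanity check of frame_xy_to_court_xy, a synthetic round trip through a homography fitted with OpenCV; the four frame/court correspondences below are made up for illustration only.

import numpy as np
import cv2

frame_pts = np.array([[100., 800.], [1800., 820.], [600., 400.], [1300., 390.]])  # pixel coordinates (illustrative)
court_pts = np.array([[0., 50.], [94., 50.], [19., 17.], [75., 17.]])             # court coordinates in feet

H, _ = cv2.findHomography(frame_pts, court_pts)  # exact fit with 4 correspondences

pts_h = np.hstack((frame_pts, np.ones((4, 1)))) @ H.T  # the same math as frame_xy_to_court_xy
print(pts_h[:, :2] / pts_h[:, [2]])                    # recovers court_pts up to numerical error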