from pathlib import Path
from typing import Dict, Generator, Iterable, List, Literal, Optional, Tuple, TypeVar

import cv2
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as T
from numpy import ndarray
from pydantic import BaseModel
from sklearn.cluster import KMeans
from ultralytics import YOLO

V = TypeVar("V")


def create_batches(
    sequence: Iterable[V], batch_size: int
) -> Generator[List[V], None, None]:
    """
    Generate batches from a sequence with a specified batch size.

    Args:
        sequence (Iterable[V]): The input sequence to be batched.
        batch_size (int): The size of each batch.

    Yields:
        List[V]: Batches of the input sequence; the final batch may be
        smaller than `batch_size`.
    """
    batch_size = max(batch_size, 1)
    current_batch = []
    for element in sequence:
        if len(current_batch) == batch_size:
            yield current_batch
            current_batch = []
        current_batch.append(element)
    if current_batch:
        yield current_batch
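
# Illustrative sanity check for create_batches (a sketch, not part of the
# pipeline):
#
#     >>> list(create_batches(range(5), 2))
#     [[0, 1], [2, 3], [4]]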


class HSVTeamClassifier:
    """
    Enhanced HSV-based team classifier with temporal consistency and
    confidence weighting. Fast and lightweight, suitable for real-time
    processing.
    """

    def __init__(self, hue_pivot: float = 90.0, temporal_weight: float = 0.3):
        """
        Initialize the HSV-based team classifier.

        Args:
            hue_pivot: Hue threshold used when only a single player is
                visible (default: 90.0).
            temporal_weight: Weight for temporal consistency (0.0-1.0).
        """
        self.hue_pivot = hue_pivot
        self.temporal_weight = temporal_weight
        self.cluster_centers: np.ndarray | None = None
        # Team assignments and confidences from previous frames, keyed by a
        # coarse bounding-box hash (see predict). Entries are never pruned.
        self.previous_assignments: Dict[int, int] = {}
        self.assignment_confidence: Dict[int, float] = {}

    @staticmethod
    def _extract_hsv_features_from_crop(img_bgr: np.ndarray) -> Tuple[float, float]:
        """
        Extract mean hue and saturation from an image crop.

        Args:
            img_bgr: BGR image crop.

        Returns:
            Tuple of (mean_hue, mean_saturation).
        """
        if img_bgr.size == 0:
            return (0.0, 0.0)

        hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
        mean_hue = float(np.mean(hsv[:, :, 0]))
        mean_saturation = float(np.mean(hsv[:, :, 1]))
        return (mean_hue, mean_saturation)

    def _extract_hsv_features_with_green_filter(
        self, img_bgr: np.ndarray, box, img_width: int, img_height: int
    ) -> np.ndarray:
        """
        Extract HSV features from an ROI, filtering out green (grass) pixels.

        Args:
            img_bgr: Full frame image.
            box: Bounding box to extract the ROI from.
            img_width, img_height: Image dimensions.

        Returns:
            Array of [hue, saturation] features.
        """
        # Forward reference: Miner is defined later in this module, which is
        # fine because this method only runs after the module has loaded.
        x1, y1, x2, y2 = Miner._clip_box_to_image(
            box.x1, box.y1, box.x2, box.y2, img_width, img_height
        )
        roi = img_bgr[y1:y2, x1:x2]

        if roi.size == 0:
            return np.array([0.0, 0.0], dtype=np.float32)

        hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)

        # Mask out grass: OpenCV hue is in [0, 179], so 35-85 covers green.
        lower_green = np.array([35, 60, 60], dtype=np.uint8)
        upper_green = np.array([85, 255, 255], dtype=np.uint8)
        green_mask = cv2.inRange(hsv, lower_green, upper_green)
        non_green_mask = cv2.bitwise_not(green_mask)

        num_non_green = int(np.count_nonzero(non_green_mask))
        total_pixels = hsv.shape[0] * hsv.shape[1]

        # Use the filtered pixels only if enough of the ROI survives the
        # mask; otherwise fall back to plain crop statistics.
        if num_non_green > max(50, total_pixels // 20):
            h_vals = hsv[:, :, 0][non_green_mask > 0]
            s_vals = hsv[:, :, 1][non_green_mask > 0]
            h_mean = float(np.mean(h_vals)) if h_vals.size else 0.0
            s_mean = float(np.mean(s_vals)) if s_vals.size else 0.0
        else:
            h_mean, s_mean = self._extract_hsv_features_from_crop(roi)

        return np.array([h_mean, s_mean], dtype=np.float32)

    def _cluster_players_hsv(
        self, hsv_features: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Cluster players into two teams using K-means on HSV features.

        Args:
            hsv_features: Array of HSV features, shape (N, 2).

        Returns:
            Tuple of (labels, cluster_centers).
        """
        if len(hsv_features) < 2:
            return np.array([]), np.array([])

        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0)
        _, labels, centers = cv2.kmeans(
            np.float32(hsv_features),
            K=2,
            bestLabels=None,
            criteria=criteria,
            attempts=5,
            flags=cv2.KMEANS_PP_CENTERS,
        )

        # Sort the clusters by hue so that label 0 is always the lower-hue
        # team; this keeps team indices stable across frames.
        order = np.argsort(centers[:, 0])
        centers_sorted = centers[order]
        remap = {old_idx: new_idx for new_idx, old_idx in enumerate(order)}
        labels_remapped = np.vectorize(remap.get)(labels.reshape(-1))

        return labels_remapped, centers_sorted

    def _calculate_bbox_similarity(self, box1, box2) -> float:
        """Calculate similarity between two bounding boxes from their center distance."""
        center1 = ((box1.x1 + box1.x2) / 2, (box1.y1 + box1.y2) / 2)
        center2 = ((box2.x1 + box2.x2) / 2, (box2.y1 + box2.y2) / 2)
        distance = np.sqrt((center1[0] - center2[0]) ** 2 + (center1[1] - center2[1]) ** 2)
        # Normalize by the diagonal of an assumed 1920x1080 frame.
        max_distance = np.sqrt(1920 ** 2 + 1080 ** 2)
        return max(0, 1 - distance / max_distance)

    def _apply_temporal_consistency(
        self,
        current_labels: np.ndarray,
        boxes: List,
        hsv_features: np.ndarray,
    ) -> np.ndarray:
        """Apply temporal consistency to reduce team-assignment flickering."""
        if not self.previous_assignments:
            return current_labels

        adjusted_labels = current_labels.copy()

        for i, (box, current_label) in enumerate(zip(boxes, current_labels)):
            best_match_id = None
            best_similarity = 0.0

            for prev_id, prev_team in self.previous_assignments.items():
                # Placeholder similarity: previous box geometry is not stored,
                # so every remembered assignment is treated as an equally
                # plausible match. _calculate_bbox_similarity could replace
                # this if the previous boxes were kept.
                similarity = 0.8
                if similarity > best_similarity and similarity > 0.5:
                    best_similarity = similarity
                    best_match_id = prev_id

            # `best_match_id` can legitimately be 0, so compare against None.
            if best_match_id is not None and best_similarity > 0.7:
                prev_confidence = self.assignment_confidence.get(best_match_id, 0.5)
                current_confidence = 0.8

                if prev_confidence > current_confidence * 1.2:
                    adjusted_labels[i] = self.previous_assignments[best_match_id]

        return adjusted_labels

    def predict(
        self,
        crops: List[np.ndarray],
        boxes: List,
        frame_image: ndarray,
    ) -> Tuple[np.ndarray, np.ndarray | None]:
        """
        Predict team labels for player crops using HSV features with
        temporal consistency.

        Args:
            crops: List of player image crops (used only to detect the empty
                case; features are re-extracted from frame_image).
            boxes: List of corresponding bounding boxes.
            frame_image: Full frame image for feature extraction.

        Returns:
            Tuple of (team_labels, cluster_centers).
        """
        if len(crops) == 0:
            return np.array([]), None

        h, w = frame_image.shape[:2]
        hsv_features = []

        for box in boxes:
            features = self._extract_hsv_features_with_green_filter(
                frame_image, box, w, h
            )
            hsv_features.append(features)

        hsv_features = np.vstack(hsv_features)

        if len(hsv_features) >= 2:
            labels, centers = self._cluster_players_hsv(hsv_features)

            if self.temporal_weight > 0:
                labels = self._apply_temporal_consistency(labels, boxes, hsv_features)

            # Remember this frame's assignments for temporal consistency.
            for i, (box, label) in enumerate(zip(boxes, labels)):
                bbox_id = hash((box.x1, box.y1, box.x2, box.y2)) % 10000
                self.previous_assignments[bbox_id] = int(label)
                self.assignment_confidence[bbox_id] = 0.8

            self.cluster_centers = centers
            return labels, centers
        elif len(hsv_features) == 1:
            # A single player cannot be clustered; fall back to the hue pivot.
            hue = hsv_features[0, 0]
            label = 0 if float(hue) < self.hue_pivot else 1
            return np.array([label]), None
        else:
            return np.array([]), None
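
    # Usage sketch (hedged): `frame` is a BGR ndarray and `boxes` are objects
    # with .x1/.y1/.x2/.y2 attributes, e.g. the BoundingBox model below.
    #
    #     clf = HSVTeamClassifier()
    #     labels, centers = clf.predict(crops, boxes, frame)
    #     # labels[i] in {0, 1}; centers is a (2, 2) [hue, saturation] array
    #     # (or None when fewer than two players are visible).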


class ResNetTeamClassifier:
    """
    A classifier that uses ResNet18 for feature extraction and KMeans for
    clustering.
    """

    def __init__(self, device: str = 'cpu', batch_size: int = 32):
        """
        Initialize the TeamClassifier with device and batch size.

        Args:
            device (str): The device to run the model on ('cpu' or 'cuda').
            batch_size (int): The batch size for processing images.
        """
        self.device = device
        self.batch_size = batch_size

        # ImageNet-pretrained ResNet18 with the final FC layer removed, so
        # the model outputs 512-dimensional embeddings.
        self.features_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.features_model = torch.nn.Sequential(*list(self.features_model.children())[:-1])
        self.features_model.to(device)
        self.features_model.eval()

        # Standard ImageNet preprocessing (expects RGB input).
        self.transform = T.Compose([
            T.ToTensor(),
            T.Resize((224, 224)),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        self.cluster_model = KMeans(n_clusters=2, random_state=42)

    def extract_features(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Extract features from a list of image crops using ResNet18.

        Args:
            crops (List[np.ndarray]): List of image crops (CV2 numpy arrays,
                BGR format).

        Returns:
            np.ndarray: Extracted features as a numpy array (N, 512).
        """
        batches = create_batches(crops, self.batch_size)
        embeddings = []

        with torch.no_grad():
            for batch in batches:
                # Convert BGR (OpenCV) to RGB before the ImageNet transform;
                # the normalization constants assume RGB channel order.
                inputs = torch.stack([
                    self.transform(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
                    for crop in batch
                ]).to(self.device)
                features = self.features_model(inputs)
                # Flatten (N, 512, 1, 1) -> (N, 512).
                features = features.view(features.size(0), -1)
                embeddings.append(features.cpu().numpy())

        if not embeddings:
            return np.empty((0, 512), dtype=np.float32)
        return np.concatenate(embeddings)

    def fit(self, crops: List[np.ndarray], max_samples: int = 100) -> None:
        """
        Fit the classifier model on a list of image crops.

        Args:
            crops (List[np.ndarray]): List of image crops.
            max_samples (int): Maximum number of samples to use for fitting.
        """
        # Subsample to keep fitting fast on long clips.
        if len(crops) > max_samples:
            indices = np.random.choice(len(crops), max_samples, replace=False)
            crops = [crops[i] for i in indices]

        embeddings = self.extract_features(crops)
        self.cluster_model.fit(embeddings)

    def predict(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Predict the cluster labels for a list of image crops.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Returns:
            np.ndarray: Predicted cluster labels (0 or 1).
        """
        if len(crops) == 0:
            return np.array([])

        embeddings = self.extract_features(crops)
        return self.cluster_model.predict(embeddings)
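
# Usage sketch (hedged; `crops` is a list of BGR ndarrays):
#
#     clf = ResNetTeamClassifier(device="cuda" if torch.cuda.is_available() else "cpu")
#     clf.fit(crops)               # clusters ImageNet embeddings into two teams
#     labels = clf.predict(crops)  # ndarray of 0/1 team indices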


class BoundingBox(BaseModel):
    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int
    conf: float


class TVFrameResult(BaseModel):
    frame_id: int
    boxes: list[BoundingBox]
    keypoints: list[tuple[int, int]]
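
# Illustrative instance of the result schema (placeholder values):
#
#     TVFrameResult(
#         frame_id=0,
#         boxes=[BoundingBox(x1=10, y1=20, x2=60, y2=140, cls_id=2, conf=0.91)],
#         keypoints=[(0, 0)] * n_keypoints,  # missing points padded with (0, 0)
#     )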


class Miner:
    """
    Enhanced miner combining best practices from v1 and competitor's v3.

    Features:
    - Multiple team classification methods (HSV, ResNet, ensemble)
    - Two-stage box suppression (quasi-total containment + small contained)
    - Simplified multiple-goalkeeper handling (confidence-based)
    - Proper task_type support for selective processing
    - Boundary-aware box clipping
    """

    # Suppression and classification thresholds.
    QUASI_TOTAL_IOA: float = 0.90
    SMALL_CONTAINED_IOA: float = 0.85
    SMALL_RATIO_MAX: float = 0.50
    SINGLE_PLAYER_HUE_PIVOT: float = 90.0
    # Keypoint indices treated as pitch corners; these get a more permissive
    # confidence threshold in predict_batch.
    CORNER_INDICES = {0, 5, 24, 29}

    def __init__(
        self,
        path_hf_repo: Path,
        team_classification_method: Literal["hsv", "resnet", "ensemble"] = "hsv"
    ) -> None:
        """
        Load all ML models from the repository.

        Args:
            path_hf_repo (Path): Path to the downloaded HuggingFace Hub repository.
            team_classification_method (str): Method for team classification:
                - "hsv": Fast HSV-based classification (default)
                - "resnet": Robust ResNet18-based classification
                - "ensemble": Combine both methods (ResNet preferred once fitted)
        """
        self.bbox_model = YOLO(path_hf_repo / "detection.pt")
        print("✅ BBox Model Loaded")
        self.keypoints_model = YOLO(path_hf_repo / "keypoint.pt")
        print("✅ Keypoints Model Loaded")

        self.team_classification_method = team_classification_method

        if team_classification_method == "hsv":
            self.hsv_classifier = HSVTeamClassifier(hue_pivot=self.SINGLE_PLAYER_HUE_PIVOT)
            self.resnet_classifier = None
            self.team_classifier_fitted = False
            print("✅ HSV Team Classifier Initialized")
        elif team_classification_method == "resnet":
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"🔧 Using device: {device}")
            self.resnet_classifier = ResNetTeamClassifier(device=device, batch_size=32)
            self.hsv_classifier = None
            self.team_classifier_fitted = False
            print("✅ ResNet Team Classifier Loaded")
        elif team_classification_method == "ensemble":
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"🔧 Using device: {device}")
            self.hsv_classifier = HSVTeamClassifier(hue_pivot=self.SINGLE_PLAYER_HUE_PIVOT)
            self.resnet_classifier = ResNetTeamClassifier(device=device, batch_size=32)
            self.team_classifier_fitted = False
            print("✅ Ensemble Team Classifiers Loaded (HSV + ResNet)")
        else:
            raise ValueError(
                f"Invalid team_classification_method: {team_classification_method}. "
                "Must be 'hsv', 'resnet', or 'ensemble'"
            )

    def __repr__(self) -> str:
        """Information about the miner, returned by the health endpoint."""
        classifier_info = f"Team Classification: {self.team_classification_method}"
        if self.team_classification_method == "hsv":
            classifier_info += f" ({type(self.hsv_classifier).__name__})"
        elif self.team_classification_method == "resnet":
            classifier_info += f" ({type(self.resnet_classifier).__name__})"
        else:
            classifier_info += " (HSV + ResNet)"

        return (
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}\n"
            f"{classifier_info}"
        )

    @staticmethod
    def _map_yolo_to_validator_cls_id(yolo_cls_id: int) -> int | None:
        """
        Map a YOLO model class ID (new model format) to the validator format.

        YOLO model mapping:
            0: 'Player', 1: 'GoalKeeper', 2: 'Ball',
            3: 'Main Referee', 4: 'Side Referee', 5: 'Staff Member'

        Validator format:
            0: 'ball', 1: 'goalkeeper', 2: 'player', 3: 'referee',
            6: 'team1', 7: 'team2'

        Args:
            yolo_cls_id: Class ID from the YOLO model.

        Returns:
            Mapped class ID in validator format, or None if the detection
            should be skipped (e.g. staff members).
        """
        if yolo_cls_id == 0:
            return 2
        elif yolo_cls_id == 1:
            return 1
        elif yolo_cls_id == 2:
            return 0
        elif yolo_cls_id in [3, 4]:
            return 3
        else:
            return None

    @staticmethod
    def _clip_box_to_image(x1: int, y1: int, x2: int, y2: int, w: int, h: int) -> Tuple[int, int, int, int]:
        """
        Clip bounding-box coordinates to the image boundaries.
        (Adopted from competitor's approach - simpler and more efficient.)

        Args:
            x1, y1, x2, y2: Box coordinates.
            w, h: Image dimensions.

        Returns:
            Clipped coordinates (x1, y1, x2, y2); degenerate boxes are nudged
            to be at least one pixel wide/tall where possible.
        """
        x1 = max(0, min(int(x1), w - 1))
        y1 = max(0, min(int(y1), h - 1))
        x2 = max(0, min(int(x2), w - 1))
        y2 = max(0, min(int(y2), h - 1))
        if x2 <= x1:
            x2 = min(w - 1, x1 + 1)
        if y2 <= y1:
            y2 = min(h - 1, y1 + 1)
        return x1, y1, x2, y2

    @staticmethod
    def _area(bb: BoundingBox) -> int:
        """Calculate the area of a bounding box."""
        return max(0, bb.x2 - bb.x1) * max(0, bb.y2 - bb.y1)

    @staticmethod
    def _intersect_area(a: BoundingBox, b: BoundingBox) -> int:
        """Calculate the intersection area between two boxes."""
        ix1 = max(a.x1, b.x1)
        iy1 = max(a.y1, b.y1)
        ix2 = min(a.x2, b.x2)
        iy2 = min(a.y2, b.y2)
        if ix2 <= ix1 or iy2 <= iy1:
            return 0
        return (ix2 - ix1) * (iy2 - iy1)

    def _ioa(self, a: BoundingBox, b: BoundingBox) -> float:
        """
        Calculate Intersection over Area (IoA) of box a in box b: the
        fraction of a's area that lies inside b.
        (Adopted from competitor's approach.)
        """
        inter = self._intersect_area(a, b)
        aa = self._area(a)
        if aa <= 0:
            return 0.0
        return inter / aa
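
    # Worked example: for a = (0, 0, 10, 10) and b = (0, 0, 20, 20), the
    # intersection is 100 and area(a) is 100, so _ioa(a, b) = 1.0 (a lies
    # fully inside b), while _ioa(b, a) = 100 / 400 = 0.25.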

    def suppress_quasi_total_containment(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Remove boxes that are almost completely contained within another box.
        (Adopted from competitor's approach - cleaner separation of concerns.)

        Strategy: If box_i is >= 90% contained in box_j, remove box_i.
        This handles cases where one box is a near-duplicate of another.
        """
        if len(boxes) <= 1:
            return boxes

        keep = [True] * len(boxes)
        for i in range(len(boxes)):
            if not keep[i]:
                continue
            for j in range(len(boxes)):
                if i == j or not keep[j]:
                    continue
                ioa_i_in_j = self._ioa(boxes[i], boxes[j])
                if ioa_i_in_j >= self.QUASI_TOTAL_IOA:
                    keep[i] = False
                    break

        return [bb for bb, k in zip(boxes, keep) if k]

    def suppress_small_contained(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Remove small boxes that are significantly contained within larger boxes.
        (Adopted from competitor's approach - cleaner separation of concerns.)

        Strategy: If a small box (<= 50% of the larger box's area) is >= 85%
        contained in the larger box, remove the small box (likely a duplicate
        detection).
        """
        if len(boxes) <= 1:
            return boxes

        keep = [True] * len(boxes)
        areas = [self._area(bb) for bb in boxes]

        for i in range(len(boxes)):
            if not keep[i]:
                continue
            for j in range(len(boxes)):
                if i == j or not keep[j]:
                    continue
                ai, aj = areas[i], areas[j]
                if ai == 0 or aj == 0:
                    continue

                if ai <= aj:
                    ratio = ai / aj
                    if ratio <= self.SMALL_RATIO_MAX:
                        ioa_i_in_j = self._ioa(boxes[i], boxes[j])
                        if ioa_i_in_j >= self.SMALL_CONTAINED_IOA:
                            keep[i] = False
                            break
                else:
                    ratio = aj / ai
                    if ratio <= self.SMALL_RATIO_MAX:
                        ioa_j_in_i = self._ioa(boxes[j], boxes[i])
                        if ioa_j_in_i >= self.SMALL_CONTAINED_IOA:
                            keep[j] = False

        return [bb for bb, k in zip(boxes, keep) if k]

    def _handle_multiple_balls(
        self, all_boxes: List[BoundingBox]
    ) -> List[BoundingBox]:
        """
        When multiple footballs are detected, keep only the one with the
        highest confidence (validator class 0 == 'ball').
        """
        ball_detections = [box for box in all_boxes if box.cls_id == 0]

        if len(ball_detections) <= 1:
            return all_boxes

        best_ball = max(ball_detections, key=lambda b: b.conf)

        # Drop every ball, then re-append only the most confident one.
        filtered_boxes = [box for box in all_boxes if box.cls_id != 0]
        filtered_boxes.append(best_ball)

        return filtered_boxes

    def _reclass_extra_goalkeepers(
        self,
        img_bgr: np.ndarray,
        boxes: List[BoundingBox],
        cluster_centers: Optional[np.ndarray],
    ) -> None:
        """
        When multiple goalkeepers are detected, keep the one with the highest
        confidence and reclassify the rest as team players (class 6 or 7).
        (Adopted from competitor's simpler approach - confidence-based selection.)

        Args:
            img_bgr: Current frame image
            boxes: List of all detected boxes (modified in-place)
            cluster_centers: Pre-computed team cluster centers (if available)
        """
        gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
        if len(gk_idxs) <= 1:
            return

        # Keep the most confident goalkeeper and reclassify the rest.
        gk_idxs_sorted = sorted(gk_idxs, key=lambda i: boxes[i].conf, reverse=True)
        to_reclass = gk_idxs_sorted[1:]

        h, w = img_bgr.shape[:2]

        for gki in to_reclass:
            # HSV features for this box, when an HSV classifier exists.
            hs_gk = self.hsv_classifier._extract_hsv_features_with_green_filter(
                img_bgr, boxes[gki], w, h
            ) if self.hsv_classifier else None

            if cluster_centers is not None and len(cluster_centers) >= 2:
                # Ensemble mode also passes ResNet centers once fitted, so
                # route both "resnet" and "ensemble" through the embedding path.
                use_resnet = (
                    self.team_classification_method in ("resnet", "ensemble")
                    and self.team_classifier_fitted
                    and self.resnet_classifier is not None
                )
                assign_cls = None
                if use_resnet:
                    # Assign to the nearest ResNet-embedding cluster.
                    try:
                        x1, y1, x2, y2 = self._clip_box_to_image(
                            boxes[gki].x1, boxes[gki].y1, boxes[gki].x2, boxes[gki].y2, w, h
                        )
                        gk_crop = img_bgr[y1:y2, x1:x2]
                        if gk_crop.size > 0:
                            gk_features = self.resnet_classifier.extract_features([gk_crop])[0]
                            d0 = float(np.linalg.norm(gk_features - cluster_centers[0]))
                            d1 = float(np.linalg.norm(gk_features - cluster_centers[1]))
                            assign_cls = 6 if d0 <= d1 else 7
                    except Exception:
                        assign_cls = None
                if assign_cls is None:
                    # Nearest [hue, saturation] cluster; guard the shape since
                    # the centers may be 512-dim ResNet embeddings instead.
                    if hs_gk is not None and cluster_centers[0].shape == hs_gk.shape:
                        d0 = float(np.linalg.norm(hs_gk - cluster_centers[0]))
                        d1 = float(np.linalg.norm(hs_gk - cluster_centers[1]))
                        assign_cls = 6 if d0 <= d1 else 7
                    elif hs_gk is not None:
                        assign_cls = 6 if float(hs_gk[0]) < self.SINGLE_PLAYER_HUE_PIVOT else 7
                    else:
                        assign_cls = 6
            else:
                # No cluster centers available: fall back to the hue pivot.
                if hs_gk is not None:
                    assign_cls = 6 if float(hs_gk[0]) < self.SINGLE_PLAYER_HUE_PIVOT else 7
                else:
                    assign_cls = 6

            boxes[gki].cls_id = int(assign_cls)

    def _multi_scale_detection(self, img_bgr: np.ndarray) -> List[BoundingBox]:
        """
        Multi-scale object detection for improved small-object detection.
        Runs the detector at several image scales and merges the results with
        a scale-aware NMS.
        """
        H, W = img_bgr.shape[:2]
        scales = [1.0, 1.15, 0.85]
        all_detections = []

        for scale in scales:
            if scale != 1.0:
                new_h, new_w = int(H * scale), int(W * scale)
                # Skip scales that are too large or too small for the model.
                if new_h > 2048 or new_w > 2048 or new_h < 320 or new_w < 320:
                    continue
                scaled_img = cv2.resize(img_bgr, (new_w, new_h))
            else:
                scaled_img = img_bgr
                new_h, new_w = H, W

            results = self.bbox_model.predict([scaled_img], verbose=False)

            if results and hasattr(results[0], "boxes") and results[0].boxes is not None:
                for box in results[0].boxes.data:
                    x1, y1, x2, y2, conf, yolo_cls_id = box.tolist()

                    validator_cls_id = self._map_yolo_to_validator_cls_id(int(yolo_cls_id))
                    if validator_cls_id is None:
                        continue

                    # Map coordinates back to the original resolution.
                    if scale != 1.0:
                        x1 = x1 / scale
                        y1 = y1 / scale
                        x2 = x2 / scale
                        y2 = y2 / scale

                    x1, y1, x2, y2 = self._clip_box_to_image(x1, y1, x2, y2, W, H)

                    # Mild confidence boosts: the upscaled pass is better at
                    # small objects, the downscaled pass at large ones. Clamp
                    # so the boost never pushes confidence above 1.0.
                    box_area = (x2 - x1) * (y2 - y1)
                    if scale == 1.15 and box_area < 2500:
                        conf = min(conf * 1.08, 1.0)
                    elif scale == 0.85 and box_area > 8000:
                        conf = min(conf * 1.03, 1.0)

                    all_detections.append(BoundingBox(
                        x1=int(x1), y1=int(y1), x2=int(x2), y2=int(y2),
                        cls_id=validator_cls_id, conf=float(conf)
                    ))

        return self._multi_scale_nms(all_detections)

    def _multi_scale_nms(self, boxes: List[BoundingBox], iou_threshold: float = 0.45) -> List[BoundingBox]:
        """
        Multi-scale non-maximum suppression that preserves detections from
        different scales when their confidences are close.
        """
        if not boxes:
            return []

        # Group detections by class so NMS never suppresses across classes.
        boxes_by_class = {}
        for box in boxes:
            if box.cls_id not in boxes_by_class:
                boxes_by_class[box.cls_id] = []
            boxes_by_class[box.cls_id].append(box)

        final_boxes = []

        for cls_id, class_boxes in boxes_by_class.items():
            class_boxes_sorted = sorted(class_boxes, key=lambda x: x.conf, reverse=True)
            keep = []

            while class_boxes_sorted:
                current = class_boxes_sorted.pop(0)
                keep.append(current)

                # Keep boxes that either barely overlap, or overlap strongly
                # but have a confidence close to the current box (likely the
                # same object seen at a different scale).
                remaining = []
                for box in class_boxes_sorted:
                    iou = self._calculate_iou(current, box)
                    if iou < iou_threshold:
                        remaining.append(box)
                    elif box.conf > current.conf * 0.92:
                        remaining.append(box)

                class_boxes_sorted = remaining

            final_boxes.extend(keep)

        return final_boxes

    def _calculate_iou(self, box1: BoundingBox, box2: BoundingBox) -> float:
        """Calculate Intersection over Union (IoU) between two bounding boxes."""
        x1 = max(box1.x1, box2.x1)
        y1 = max(box1.y1, box2.y1)
        x2 = min(box1.x2, box2.x2)
        y2 = min(box1.y2, box2.y2)

        if x2 <= x1 or y2 <= y1:
            return 0.0

        intersection = (x2 - x1) * (y2 - y1)

        area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1)
        area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1)
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0

    def fit_team_classifier(
        self,
        batch_images: list[ndarray],
        player_class_id: int = 0
    ) -> None:
        """
        Fit the team classifier on player crops from batch images.
        Only needed for the ResNet or ensemble methods.

        Args:
            batch_images: List of images to extract player crops from.
            player_class_id: YOLO class ID that represents players
                (default: 0 for the new model).
        """
        if self.team_classification_method == "hsv":
            print("ℹ️ HSV classifier doesn't require fitting")
            return

        player_crops = []

        bbox_model_results = self.bbox_model.predict(batch_images)
        if bbox_model_results is not None:
            for frame_idx, detection in enumerate(bbox_model_results):
                if not hasattr(detection, "boxes") or detection.boxes is None:
                    continue

                frame_image = batch_images[frame_idx]
                h, w = frame_image.shape[:2]

                for box in detection.boxes.data:
                    x1, y1, x2, y2, conf, yolo_cls_id = box.tolist()

                    # Collect only player crops for team clustering.
                    if int(yolo_cls_id) == player_class_id:
                        x1_clip, y1_clip, x2_clip, y2_clip = self._clip_box_to_image(
                            int(x1), int(y1), int(x2), int(y2), w, h
                        )
                        crop = frame_image[y1_clip:y2_clip, x1_clip:x2_clip]
                        if crop.size > 0:
                            player_crops.append(crop)

        if len(player_crops) > 0:
            if self.team_classification_method == "resnet":
                self.resnet_classifier.fit(player_crops)
                self.team_classifier_fitted = True
                print(f"✅ ResNet team classifier fitted on {len(player_crops)} player crops")
            elif self.team_classification_method == "ensemble":
                self.resnet_classifier.fit(player_crops)
                self.team_classifier_fitted = True
                print(f"✅ ResNet classifier (in ensemble) fitted on {len(player_crops)} player crops")
        else:
            print("⚠️ No player crops found to fit team classifier")

    def predict_batch(
        self,
        batch_images: list[ndarray],
        offset: int,
        n_keypoints: int,
        task_type: Optional[str] = None,
    ) -> list[TVFrameResult]:
        """
        Miner prediction for a batch of images with enhanced post-processing.

        Args:
            batch_images (list[np.ndarray]): A list of images to process.
            offset (int): Frame number of the first image in the batch.
            n_keypoints (int): Number of keypoints expected per frame.
            task_type (str | None):
                - None: Process both object and keypoint detection
                - "object": Only process object detection
                - "keypoint": Only process keypoint detection

        Returns:
            list[TVFrameResult]: Predictions for each image in the batch.
        """
        process_objects = task_type is None or task_type == "object"
        process_keypoints = task_type is None or task_type == "keypoint"

        bboxes: dict[int, list[BoundingBox]] = {}

        if process_objects:
            for frame_idx, frame_image in enumerate(batch_images):
                # Detect across scales, then clean up the raw detections.
                boxes = self._multi_scale_detection(frame_image)
                boxes = self._handle_multiple_balls(boxes)
                boxes = self.suppress_quasi_total_containment(boxes)
                boxes = self.suppress_small_contained(boxes)

                # Team classification for player boxes (validator class 2).
                player_boxes = [box for box in boxes if box.cls_id == 2]
                player_indices = [idx for idx, box in enumerate(boxes) if box.cls_id == 2]

                team_cluster_centers = None
                team_labels = None

                if len(player_boxes) > 0:
                    player_crops = [
                        frame_image[box.y1:box.y2, box.x1:box.x2]
                        for box in player_boxes
                    ]

                    if self.team_classification_method == "hsv":
                        team_labels, team_cluster_centers = self.hsv_classifier.predict(
                            player_crops, player_boxes, frame_image
                        )

                    elif self.team_classification_method == "resnet":
                        if self.team_classifier_fitted:
                            team_labels = self.resnet_classifier.predict(player_crops)
                            if hasattr(self.resnet_classifier.cluster_model, 'cluster_centers_'):
                                team_cluster_centers = self.resnet_classifier.cluster_model.cluster_centers_

                    elif self.team_classification_method == "ensemble":
                        # HSV always runs (it also maintains temporal state).
                        hsv_labels, hsv_centers = self.hsv_classifier.predict(
                            player_crops, player_boxes, frame_image
                        )

                        resnet_labels = None
                        resnet_centers = None
                        if self.team_classifier_fitted:
                            resnet_labels = self.resnet_classifier.predict(player_crops)
                            if hasattr(self.resnet_classifier.cluster_model, 'cluster_centers_'):
                                resnet_centers = self.resnet_classifier.cluster_model.cluster_centers_

                        # The two methods' cluster indices are not aligned, so
                        # element-wise voting is not meaningful: prefer the
                        # ResNet labels when available, else fall back to HSV.
                        if resnet_labels is not None and len(resnet_labels) == len(hsv_labels):
                            team_labels = resnet_labels
                            team_cluster_centers = resnet_centers
                        else:
                            team_labels = hsv_labels
                            team_cluster_centers = hsv_centers

                # Rewrite player class IDs to team IDs (6 = team1, 7 = team2).
                if team_labels is not None and len(team_labels) == len(player_indices):
                    for idx, team_label in zip(player_indices, team_labels):
                        boxes[idx].cls_id = 6 + int(team_label)

                # Keep at most one goalkeeper; extras become team players.
                self._reclass_extra_goalkeepers(
                    frame_image, boxes, team_cluster_centers
                )

                bboxes[offset + frame_idx] = boxes

        keypoints: dict[int, list[tuple[int, int]]] = {}

        if process_keypoints:
            keypoints_model_results = self.keypoints_model.predict(batch_images)
        else:
            keypoints_model_results = None

        if keypoints_model_results is not None:
            for frame_idx, detection in enumerate(keypoints_model_results):
                if not hasattr(detection, "keypoints") or detection.keypoints is None:
                    continue

                # Flatten all detected keypoints into (x, y, confidence) triples.
                frame_keypoints: list[tuple[int, int, float]] = []
                for i, part_points in enumerate(detection.keypoints.data):
                    for k_id, (x, y, _) in enumerate(part_points):
                        confidence = detection.keypoints.conf[i][k_id]
                        frame_keypoints.append((int(x), int(y), float(confidence)))

                # Pad with placeholders or truncate to exactly n_keypoints.
                if len(frame_keypoints) < n_keypoints:
                    frame_keypoints.extend(
                        [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints))
                    )
                else:
                    frame_keypoints = frame_keypoints[:n_keypoints]

                # Zero out low-confidence points; pitch corners get a more
                # permissive threshold because they are harder to detect.
                filtered_keypoints = []
                for idx, (x, y, confidence) in enumerate(frame_keypoints):
                    threshold = 0.3 if idx in self.CORNER_INDICES else 0.5
                    if confidence < threshold:
                        filtered_keypoints.append((0, 0))
                    else:
                        filtered_keypoints.append((int(x), int(y)))

                keypoints[offset + frame_idx] = filtered_keypoints

        # Assemble per-frame results; frames without detections get empty
        # boxes and all-zero keypoints.
        results: list[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            results.append(
                TVFrameResult(
                    frame_id=frame_number,
                    boxes=bboxes.get(frame_number, []),
                    keypoints=keypoints.get(
                        frame_number,
                        [(0, 0) for _ in range(n_keypoints)]
                    ),
                )
            )

        return results
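

# End-to-end usage sketch (hedged; the repo path, keypoint count, and batch
# contents are illustrative, not prescribed by this module):
#
#     miner = Miner(Path("./hf_repo"), team_classification_method="hsv")
#     frames = [...]  # list of BGR ndarrays decoded from a broadcast clip
#     results = miner.predict_batch(frames, offset=0, n_keypoints=32)
#     # For "resnet" or "ensemble", call miner.fit_team_classifier(frames) first.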