| from pathlib import Path |
| from typing import List, Tuple, Dict, Optional |
| import sys |
| import os |
|
|
| from numpy import ndarray |
| from pydantic import BaseModel |
|
|
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| from keypoint_helper import run_keypoints_post_processing |
| from keypoint_helper_v2 import run_keypoints_post_processing as run_keypoints_post_processing_v2 |
|
|
| from ultralytics import YOLO |
| from team_cluster import TeamClassifier |
| from utils import ( |
| BoundingBox, |
| Constants, |
| ) |
|
|
| import time |
| import torch |
| import gc |
| import cv2 |
| import numpy as np |
| from collections import defaultdict |
| from pitch import process_batch_input, get_cls_net |
| from keypoint_evaluation import ( |
| evaluate_keypoints_for_frame, |
| evaluate_keypoints_for_frame_gpu, |
| load_template_from_file, |
| evaluate_keypoints_for_frame_opencv_cuda, |
| evaluate_keypoints_batch_for_frame, |
| ) |
|
|
| import yaml |
|
|
|
|
class BoundingBox(BaseModel):
    """Axis-aligned detection box with a class id and confidence score.

    NOTE(review): this local definition shadows the ``BoundingBox`` imported
    from ``utils`` above — confirm the shadowing is intentional.
    """

    # (x1, y1) / (x2, y2) are pixel corners — presumably top-left and
    # bottom-right; TODO confirm against the detector's output convention.
    x1: int
    y1: int
    x2: int
    y2: int
    # Class id; may be a team id remapped to 6+cluster by Miner._team_classify.
    cls_id: int
    # Detector confidence for this box.
    conf: float
|
|
|
|
class TVFrameResult(BaseModel):
    """Per-frame output bundle: detections plus pitch keypoints."""

    # Absolute frame index (batch offset already applied in predict_batch).
    frame_id: int
    # All surviving detection boxes for this frame.
    boxes: List[BoundingBox]
    # Fixed-length list of (x, y) pixel keypoints; (0, 0) marks a
    # missing / low-confidence keypoint (see _detect_keypoints_batch).
    keypoints: List[Tuple[int, int]]
|
|
|
|
class Miner:
    """Video-analysis pipeline: object detection, team classification,
    and pitch keypoint detection over batches of decoded frames.

    Model weights are loaded from a local HuggingFace-style repo directory
    in ``__init__``; ``predict_batch`` is the main entry point.
    """

    # Tunables sourced from the shared Constants module (utils).
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    # Keypoint indices treated as pitch corners (looser confidence gate).
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    # NOTE(review): _team_classify hardcodes 16 / 600 instead of reading
    # these two constants — confirm and deduplicate.
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 600
|
|
| def __init__(self, path_hf_repo: Path) -> None: |
| try: |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| model_path = path_hf_repo / "detection.onnx" |
| self.bbox_model = YOLO(model_path) |
| |
| print(f"BBox Model Loaded: class name {self.bbox_model.names}") |
|
|
| team_model_path = path_hf_repo / "osnet_model.pth.tar-100" |
| self.team_classifier = TeamClassifier( |
| device=device, |
| batch_size=32, |
| model_name=str(team_model_path) |
| ) |
| print("Team Classifier Loaded") |
| |
| |
| self.team_classifier_fitted = False |
| self.player_crops_for_fit = [] |
|
|
| self.keypoints_model_yolo = YOLO(path_hf_repo / "keypoint.pt") |
|
|
| model_kp_path = path_hf_repo / 'keypoint' |
| config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml' |
| cfg_kp = yaml.safe_load(open(config_kp_path, 'r')) |
| |
| loaded_state_kp = torch.load(model_kp_path, map_location=device) |
| model = get_cls_net(cfg_kp) |
| model.load_state_dict(loaded_state_kp) |
| model.to(device) |
| model.eval() |
|
|
| self.keypoints_model = model |
| print("Keypoints Model (keypoint.pt) Loaded") |
|
|
| template_image_path = path_hf_repo / "football_pitch_template.png" |
| self.template_image, self.template_keypoints = load_template_from_file(str(template_image_path)) |
|
|
| self.kp_threshold = 0.1 |
| self.pitch_batch_size = 4 |
| self.health = "healthy" |
|
|
| print("✅ Keypoints Model Loaded") |
| except Exception as e: |
| self.health = "❌ Miner initialization failed: " + str(e) |
| print(self.health) |
|
|
| def __repr__(self) -> str: |
| if self.health == 'healthy': |
| return ( |
| f"health: {self.health}\n" |
| f"BBox Model: {type(self.bbox_model).__name__}\n" |
| f"Keypoints Model: {type(self.keypoints_model).__name__}" |
| ) |
| else: |
| return self.health |
|
|
| def _calculate_iou(self, box1: Tuple[float, float, float, float], |
| box2: Tuple[float, float, float, float]) -> float: |
| """ |
| Calculate Intersection over Union (IoU) between two bounding boxes. |
| Args: |
| box1: (x1, y1, x2, y2) |
| box2: (x1, y1, x2, y2) |
| Returns: |
| IoU score (0-1) |
| """ |
| x1_1, y1_1, x2_1, y2_1 = box1 |
| x1_2, y1_2, x2_2, y2_2 = box2 |
|
|
| |
| x_left = max(x1_1, x1_2) |
| y_top = max(y1_1, y1_2) |
| x_right = min(x2_1, x2_2) |
| y_bottom = min(y2_1, y2_2) |
|
|
| if x_right < x_left or y_bottom < y_top: |
| return 0.0 |
|
|
| intersection_area = (x_right - x_left) * (y_bottom - y_top) |
|
|
| |
| box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) |
| box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) |
| union_area = box1_area + box2_area - intersection_area |
|
|
| if union_area == 0: |
| return 0.0 |
|
|
| return intersection_area / union_area |
|
|
| def _extract_jersey_region(self, crop: ndarray) -> ndarray: |
| """ |
| Extract jersey region (upper body) from player crop. |
| For close-ups, focuses on upper 60%, for distant shots uses full crop. |
| """ |
| if crop is None or crop.size == 0: |
| return crop |
| |
| h, w = crop.shape[:2] |
| if h < 10 or w < 10: |
| return crop |
| |
| |
| is_closeup = h > 100 or (h * w) > 12000 |
| if is_closeup: |
| |
| jersey_top = 0 |
| jersey_bottom = int(h * 0.60) |
| jersey_left = max(0, int(w * 0.05)) |
| jersey_right = min(w, int(w * 0.95)) |
| return crop[jersey_top:jersey_bottom, jersey_left:jersey_right] |
| return crop |
|
|
| def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]: |
| """ |
| Extract color signature from jersey region using HSV and LAB color spaces. |
| Returns a feature vector with dominant colors and color statistics. |
| """ |
| if crop is None or crop.size == 0: |
| return None |
| |
| jersey_region = self._extract_jersey_region(crop) |
| if jersey_region.size == 0: |
| return None |
| |
| try: |
| |
| hsv = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV) |
| lab = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB) |
| |
| |
| hsv_flat = hsv.reshape(-1, 3).astype(np.float32) |
| lab_flat = lab.reshape(-1, 3).astype(np.float32) |
| |
| |
| hsv_mean = np.mean(hsv_flat, axis=0) / 255.0 |
| hsv_std = np.std(hsv_flat, axis=0) / 255.0 |
| |
| |
| lab_mean = np.mean(lab_flat, axis=0) / 255.0 |
| lab_std = np.std(lab_flat, axis=0) / 255.0 |
| |
| |
| hue_hist, _ = np.histogram(hsv_flat[:, 0], bins=36, range=(0, 180)) |
| dominant_hue = np.argmax(hue_hist) * 5 |
| |
| |
| color_features = np.concatenate([ |
| hsv_mean, |
| hsv_std, |
| lab_mean[:2], |
| lab_std[:2], |
| [dominant_hue / 180.0] |
| ]) |
| |
| return color_features |
| except Exception as e: |
| print(f"Error extracting color signature: {e}") |
| return None |
|
|
| def _get_spatial_position(self, bbox: Tuple[float, float, float, float], |
| frame_width: int, frame_height: int) -> Tuple[float, float]: |
| """ |
| Get normalized spatial position of player on the pitch. |
| Returns (x_normalized, y_normalized) where 0,0 is top-left. |
| """ |
| x1, y1, x2, y2 = bbox |
| center_x = (x1 + x2) / 2.0 |
| center_y = (y1 + y2) / 2.0 |
| |
| |
| x_norm = center_x / frame_width if frame_width > 0 else 0.5 |
| y_norm = center_y / frame_height if frame_height > 0 else 0.5 |
| |
| return (x_norm, y_norm) |
|
|
| def _find_best_match(self, target_box: Tuple[float, float, float, float], |
| predicted_frame_data: Dict[int, Tuple[Tuple, str]], |
| iou_threshold: float) -> Tuple[Optional[str], float]: |
| """ |
| Find best matching box in predicted frame data using IoU. |
| """ |
| best_iou = 0.0 |
| best_team_id = None |
| |
| for idx, (bbox, team_cls_id) in predicted_frame_data.items(): |
| iou = self._calculate_iou(target_box, bbox) |
| if iou > best_iou and iou >= iou_threshold: |
| best_iou = iou |
| best_team_id = team_cls_id |
| |
| return (best_team_id, best_iou) |
|
|
| def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]: |
| batch_size = 16 |
| detection_results = [] |
| n_frames = len(decoded_images) |
| for frame_number in range(0, n_frames, batch_size): |
| batch_images = decoded_images[frame_number: frame_number + batch_size] |
| detections = self.bbox_model(batch_images, verbose=False, save=False) |
| detection_results.extend(detections) |
| |
| return detection_results |
|
|
    def _team_classify(self, detection_results, decoded_images, offset: int) -> Dict[int, List[BoundingBox]]:
        """Assign team ids to player detections and build final per-frame boxes.

        Three passes:
          1. Fit the TeamClassifier on up to 600 player crops (class id 2,
             conf >= 0.5) gathered from the batch; fallback fit at >= 16 crops.
          2. Predict team ids on selected frames (every ``prediction_interval``
             frames), keyed by detection index.
          3. Build BoundingBox lists per frame, remapping player class ids to
             the predicted team ids, dropping staff-overlapping players,
             class '4' boxes, low-confidence class 3 boxes, and all but the
             highest-confidence football (class 0).

        Args:
            detection_results: flat list of per-frame ultralytics results
                (from _detect_objects_batch), indexable by frame position.
            decoded_images: the corresponding BGR frames.
            offset: absolute frame number of the first frame in the batch.

        Returns:
            Dict mapping absolute frame id -> list of BoundingBox.
        """
        # Refit the team classifier for every batch.
        self.team_classifier_fitted = False
        start = time.time()

        # NOTE(review): 600 / 16 duplicate the class constants
        # MAX_SAMPLES_FOR_FIT / MIN_SAMPLES_FOR_FIT — confirm and deduplicate.
        fit_sample_size = 600
        player_crops_for_fit = []

        # Pass 1: collect player crops for fitting.
        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            # Skip sparse frames (fewer than 4 detections).
            if len(detection_box) < 4:
                continue

            if len(player_crops_for_fit) < fit_sample_size:
                frame_image = decoded_images[frame_id]
                for box in detection_box:
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if conf < 0.5:
                        continue
                    mapped_cls_id = str(int(cls_id))

                    # Class '2' = player (per detector label map — TODO confirm).
                    if mapped_cls_id == '2':
                        crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                        if crop.size > 0:
                            player_crops_for_fit.append(crop)

            # Enough samples: fit once and stop collecting.
            if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
                print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
                self.team_classifier.fit(player_crops_for_fit)
                self.team_classifier_fitted = True
                break
        # Fallback: fit on whatever we gathered if at least 16 crops.
        if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16:
            print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
            self.team_classifier.fit(player_crops_for_fit)
            self.team_classifier_fitted = True
        end = time.time()
        print(f"Fitting Kmeans time: {end - start}")

        start = time.time()

        # NOTE(review): with prediction_interval == 1 every frame is predicted
        # directly and the IoU-interpolation branch below is dead code.
        prediction_interval = 1
        iou_threshold = 0.3

        print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")

        # frame_id -> {detection idx -> ((x1,y1,x2,y2), team class id str)}
        predicted_frame_data = {}

        frames_to_predict = []
        for frame_id in range(len(detection_results)):
            if frame_id % prediction_interval == 0:
                frames_to_predict.append(frame_id)

        print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
              f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")

        # Pass 2: batch-predict team ids on the selected frames.
        for frame_id in frames_to_predict:
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]

            frame_player_crops = []
            frame_player_indices = []
            frame_player_boxes = []

            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                # Low-confidence players are ignored entirely.
                if cls_id == 2 and conf < 0.6:
                    continue
                mapped_cls_id = str(int(cls_id))

                if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                    crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        frame_player_crops.append(crop)
                        frame_player_indices.append(idx)
                        frame_player_boxes.append((x1, y1, x2, y2))

            if len(frame_player_crops) > 0:
                team_ids = self.team_classifier.predict(frame_player_crops)
                predicted_frame_data[frame_id] = {}
                for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
                    # Team cluster k maps to output class id 6+k.
                    team_cls_id = str(6 + int(team_id))
                    predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)

        # Pass 3: assemble final BoundingBox lists per frame.
        fallback_count = 0
        interpolated_count = 0
        bboxes: dict[int, list[BoundingBox]] = {}
        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]
            boxes = []

            # detection idx -> team class id string for this frame
            team_predictions = {}

            if frame_id % prediction_interval == 0:
                # Directly-predicted frame: reuse stored predictions.
                if frame_id in predicted_frame_data:
                    for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
                        team_predictions[idx] = team_cls_id
            else:
                # Interpolated frame: match boxes against the nearest
                # predicted frames by IoU (dead when interval == 1).
                prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
                next_predicted_frame = prev_predicted_frame + prediction_interval

                for idx, box in enumerate(detection_box):
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if cls_id == 2 and conf < 0.6:
                        continue
                    mapped_cls_id = str(int(cls_id))

                    if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                        target_box = (x1, y1, x2, y2)

                        best_team_id = None
                        best_iou = 0.0

                        # Try the previous predicted frame first.
                        if prev_predicted_frame in predicted_frame_data:
                            team_id, iou = self._find_best_match(
                                target_box,
                                predicted_frame_data[prev_predicted_frame],
                                iou_threshold
                            )
                            if team_id is not None:
                                best_team_id = team_id
                                best_iou = iou

                        # Then the next predicted frame, kept only if better.
                        if best_team_id is None and next_predicted_frame < len(detection_results):
                            if next_predicted_frame in predicted_frame_data:
                                team_id, iou = self._find_best_match(
                                    target_box,
                                    predicted_frame_data[next_predicted_frame],
                                    iou_threshold
                                )
                                if team_id is not None and iou > best_iou:
                                    best_team_id = team_id
                                    best_iou = iou

                        if best_team_id is not None:
                            interpolated_count += 1
                        else:
                            # No IoU match: predict this single crop directly.
                            crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                            if crop.size > 0:
                                team_id = self.team_classifier.predict([crop])[0]
                                best_team_id = str(6 + int(team_id))
                                fallback_count += 1

                        if best_team_id is not None:
                            team_predictions[idx] = best_team_id

            # Build the final box list for this frame.
            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                if cls_id == 2 and conf < 0.6:
                    continue

                # Drop players almost fully overlapping a staff box (class 4).
                overlap_staff = False
                for idy, boxy in enumerate(detection_box):
                    s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
                    if cls_id == 2 and s_cls_id == 4:
                        staff_iou = self._calculate_iou(box[:4], boxy[:4])
                        if staff_iou >= 0.8:
                            overlap_staff = True
                            break
                if overlap_staff:
                    continue

                mapped_cls_id = str(int(cls_id))

                # Prefer the team-classified id over the raw detector class.
                if idx in team_predictions:
                    mapped_cls_id = team_predictions[idx]
                # Class '4' (staff — presumably) is excluded from output.
                if mapped_cls_id != '4':
                    # Class 3 requires a minimum confidence.
                    if int(mapped_cls_id) == 3 and conf < 0.5:
                        continue
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
                            cls_id=int(mapped_cls_id),
                            conf=float(conf),
                        )
                    )

            # Keep only the single most confident football (class 0).
            footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
            if len(footballs) > 1:
                best_ball = max(footballs, key=lambda b: b.conf)
                boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                boxes.append(best_ball)

            bboxes[offset + frame_id] = boxes
        return bboxes
|
|
|
|
| def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]: |
| print('=' * 10) |
| print(f"Offset: {offset}, Batch size: {len(batch_images)}") |
| print('=' * 10) |
|
|
| start = time.time() |
| detection_results = self._detect_objects_batch(batch_images) |
| end = time.time() |
| print(f"Detection time: {end - start}") |
| |
| |
| start = time.time() |
| bboxes = self._team_classify(detection_results, batch_images, offset) |
| end = time.time() |
| print(f"Team classify time: {end - start}") |
|
|
| |
| keypoints_yolo: Dict[int, List[Tuple[int, int]]] = {} |
|
|
| keypoints_yolo = self._detect_keypoints_batch(batch_images, offset, n_keypoints) |
|
|
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| results: List[TVFrameResult] = [] |
| for frame_number in range(offset, offset + len(batch_images)): |
| frame_boxes = bboxes.get(frame_number, []) |
| result = TVFrameResult( |
| frame_id=frame_number, |
| boxes=frame_boxes, |
| keypoints=keypoints_yolo.get( |
| frame_number, |
| [(0, 0) for _ in range(n_keypoints)], |
| ), |
| ) |
| results.append(result) |
|
|
| start = time.time() |
| if len(batch_images) > 0: |
| h, w = batch_images[0].shape[:2] |
| results = run_keypoints_post_processing_v2( |
| results, w, h, |
| frames=batch_images, |
| template_keypoints=self.template_keypoints, |
| floor_markings_template=self.template_image, |
| offset=offset |
| ) |
| end = time.time() |
| print(f"Keypoint post processing time: {end - start}") |
|
|
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| torch.cuda.synchronize() |
|
|
| return results |
|
|
| def _detect_keypoints_batch(self, batch_images: List[ndarray], |
| offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]: |
| """ |
| Phase 3: Keypoint detection for all frames in batch. |
| |
| Args: |
| batch_images: List of images to process |
| offset: Frame offset for numbering |
| n_keypoints: Number of keypoints expected |
| |
| Returns: |
| Dictionary mapping frame_id to list of keypoint coordinates |
| """ |
| keypoints: Dict[int, List[Tuple[int, int]]] = {} |
| keypoints_model_results = self.keypoints_model_yolo.predict(batch_images) |
| |
| if keypoints_model_results is None: |
| return keypoints |
| |
| for frame_idx_in_batch, detection in enumerate(keypoints_model_results): |
| if not hasattr(detection, "keypoints") or detection.keypoints is None: |
| continue |
| |
| |
| frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] |
| for i, part_points in enumerate(detection.keypoints.data): |
| for k_id, (x, y, _) in enumerate(part_points): |
| confidence = float(detection.keypoints.conf[i][k_id]) |
| frame_keypoints_with_conf.append((int(x), int(y), confidence)) |
| |
| |
| if len(frame_keypoints_with_conf) < n_keypoints: |
| frame_keypoints_with_conf.extend( |
| [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) |
| ) |
| else: |
| frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] |
| |
| |
| filtered_keypoints: List[Tuple[int, int]] = [] |
| for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): |
| if idx in self.CORNER_INDICES: |
| |
| if confidence < 0.3: |
| filtered_keypoints.append((0, 0)) |
| else: |
| filtered_keypoints.append((int(x), int(y))) |
| else: |
| |
| if confidence < 0.5: |
| filtered_keypoints.append((0, 0)) |
| else: |
| filtered_keypoints.append((int(x), int(y))) |
| |
| frame_id = offset + frame_idx_in_batch |
| keypoints[frame_id] = filtered_keypoints |
| |
| return keypoints |
| |