| from pathlib import Path |
| from typing import List, Tuple, Dict |
| import sys |
| import os |
|
|
| from numpy import ndarray |
| from pydantic import BaseModel |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| from keypoint_helper import run_keypoints_post_processing |
|
|
| from ultralytics import YOLO |
| from team_cluster import TeamClassifier |
| from utils import ( |
| BoundingBox, |
| Constants, |
| ) |
|
|
| import time |
| import torch |
| import gc |
| from pitch import process_batch_input, get_cls_net |
| import yaml |
|
|
|
|
# One detection: integer box corners (x1, y1)-(x2, y2), a class id and a
# confidence score.
# NOTE(review): this definition rebinds the `BoundingBox` name that is also
# imported from `utils` at the top of the file; from this point on the local
# model is the one in effect. Confirm the duplicate is intentional.
class BoundingBox(BaseModel):
    # Box corners; the pipeline fills these with int()-truncated detector
    # coordinates.
    x1: int
    y1: int
    x2: int
    y2: int
    # Class id after any team remapping applied by the pipeline.
    cls_id: int
    # Detector confidence for this box.
    conf: float
|
|
|
|
# Per-frame output record: the frame's id, the boxes kept for it, and its
# keypoints as (x, y) integer pairs.
class TVFrameResult(BaseModel):
    # Frame index (batch offset + position inside the batch).
    frame_id: int
    # Detections retained for this frame.
    boxes: List[BoundingBox]
    # (x, y) pairs; (0, 0) is used by the pipeline for missing keypoints.
    keypoints: List[Tuple[int, int]]
|
|
|
|
class Miner:
    """Frame-batch pipeline: object detection, team assignment, pitch keypoints.

    ``__init__`` loads three models from a repository directory (a YOLO
    detector, a ``TeamClassifier`` and a keypoint network).  ``predict_batch``
    runs them over a list of images and returns one ``TVFrameResult`` per
    frame.

    Class ids as this code treats them (inferred from how they are used here,
    confirm against the detector's label map): 0 = football, 2 = player,
    4 = staff.  Players that receive a team prediction are re-labelled
    ``6 + <cluster index>``.
    """

    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    # Fewest / most player crops used to fit the team classifier.
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 700

    # Keypoint ids 1..32 are read from each per-frame keypoint dict.
    _N_RAW_KEYPOINTS = 32

    def __init__(self, path_hf_repo: Path) -> None:
        """Load all models found under ``path_hf_repo``.

        Any exception raised while loading is caught; the message is stored
        in ``self.health`` and printed instead of propagating.  On success
        ``self.health`` is ``"healthy"``.

        Args:
            path_hf_repo: Directory containing
                ``football_object_detection.onnx``,
                ``osnet_model.pth.tar-100``, ``keypoint`` and
                ``hrnetv2_w48.yaml``.
        """
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Remembered so keypoint inference later targets the same device
            # the keypoint model is moved to below.
            self.device = device

            model_path = path_hf_repo / "football_object_detection.onnx"
            self.bbox_model = YOLO(model_path)
            print("BBox Model Loaded")

            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path),
            )
            print("Team Classifier Loaded")

            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            model_kp_path = path_hf_repo / 'keypoint'
            config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            # Context manager so the config file handle is closed promptly.
            with open(config_kp_path, 'r') as config_file:
                cfg_kp = yaml.safe_load(config_file)

            # NOTE(security): torch.load unpickles the checkpoint; only load
            # weights from a trusted repository.
            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()

            self.keypoints_model = model
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberate catch-all: construction never raises, the failure is
            # reported through ``self.health`` instead.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def __repr__(self) -> str:
        """Return the health status, plus model type names when healthy."""
        if self.health != 'healthy':
            return self.health
        return (
            f"health: {self.health}\n"
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}"
        )

    def _calculate_iou(self, box1: Tuple[float, float, float, float],
                       box2: Tuple[float, float, float, float]) -> float:
        """Calculate Intersection over Union (IoU) between two boxes.

        Args:
            box1: (x1, y1, x2, y2)
            box2: (x1, y1, x2, y2)

        Returns:
            IoU score (0-1); 0.0 when the boxes do not overlap or the union
            area is zero.
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        # Corners of the intersection rectangle.
        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)

        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - intersection_area

        if union_area == 0:
            return 0.0

        return intersection_area / union_area

    def _detect_objects_batch(self, decoded_images: List[ndarray]) -> List:
        """Run the detector over ``decoded_images`` in chunks of 16.

        Returns:
            A list with one detector result object per input image, in input
            order (the raw objects returned by ``self.bbox_model``).
        """
        batch_size = 16
        detection_results = []
        for first in range(0, len(decoded_images), batch_size):
            chunk = decoded_images[first: first + batch_size]
            detection_results.extend(
                self.bbox_model(chunk, verbose=False, save=False)
            )
        return detection_results

    def _fit_team_classifier(self, detection_results, decoded_images) -> None:
        """Collect player crops from the batch and fit the team classifier.

        Frames with fewer than 4 detections are skipped.  Crops come from
        class-2 boxes with confidence >= 0.5.  Fitting happens as soon as
        ``MAX_SAMPLES_FOR_FIT`` crops are available; otherwise a fallback fit
        runs after the scan when at least ``MIN_SAMPLES_FOR_FIT`` crops were
        found.  ``self.team_classifier_fitted`` records the outcome.
        """
        self.team_classifier_fitted = False
        player_crops = []

        for frame_id, result in enumerate(detection_results):
            detection_box = result.boxes.data
            if len(detection_box) < 4:
                continue

            if len(player_crops) < self.MAX_SAMPLES_FOR_FIT:
                frame_image = decoded_images[frame_id]
                for x1, y1, x2, y2, conf, cls_id in detection_box.tolist():
                    if conf < 0.5 or int(cls_id) != 2:
                        continue
                    crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        player_crops.append(crop)

            if self.team_classifier and len(player_crops) >= self.MAX_SAMPLES_FOR_FIT:
                print(f"Fitting TeamClassifier with {len(player_crops)} player crops")
                self.team_classifier.fit(player_crops)
                self.team_classifier_fitted = True
                break

        if not self.team_classifier_fitted and len(player_crops) >= self.MIN_SAMPLES_FOR_FIT:
            print(f"Fallback: Fitting TeamClassifier with {len(player_crops)} player crops")
            self.team_classifier.fit(player_crops)
            self.team_classifier_fitted = True

    def _predict_frame_teams(self, detection_box, frame_image):
        """Predict a team class for every confident player box of one frame.

        Returns:
            ``None`` when the classifier is unavailable/unfitted or the frame
            has no usable player crop; otherwise a dict mapping detection
            index -> ``((x1, y1, x2, y2), team_cls_id)`` where
            ``team_cls_id`` is the string ``str(6 + cluster_index)``.
        """
        if not (self.team_classifier and self.team_classifier_fitted):
            return None

        crops, indices, player_boxes = [], [], []
        for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(detection_box.tolist()):
            if int(cls_id) != 2 or conf < 0.6:
                continue
            crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
            if crop.size > 0:
                crops.append(crop)
                indices.append(idx)
                player_boxes.append((x1, y1, x2, y2))

        if not crops:
            return None

        team_ids = self.team_classifier.predict(crops)
        return {
            idx: (bbox, str(6 + int(team_id)))
            for idx, bbox, team_id in zip(indices, player_boxes, team_ids)
        }

    def _find_best_match(self, target_box, frame_predictions, iou_threshold):
        """Find the predicted box that overlaps ``target_box`` the most.

        Args:
            target_box: (x1, y1, x2, y2) of the box to label.
            frame_predictions: dict of index -> ``(bbox, team_cls_id)`` as
                produced by ``_predict_frame_teams``.
            iou_threshold: Minimum IoU for a match to count.

        Returns:
            ``(team_cls_id, iou)`` of the best match, or ``(None, 0.0)`` when
            no prediction reaches the threshold.
        """
        best_team_id, best_iou = None, 0.0
        for bbox, team_cls_id in frame_predictions.values():
            iou = self._calculate_iou(target_box, bbox)
            if iou >= iou_threshold and iou > best_iou:
                best_team_id, best_iou = team_cls_id, iou
        return best_team_id, best_iou

    def _interpolate_frame_teams(self, frame_id, detection_box, frame_image,
                                 predicted_frame_data, n_frames,
                                 prediction_interval, iou_threshold):
        """Assign teams on a frame that was not classified directly.

        Only used when ``prediction_interval`` is greater than 1.  Each
        confident player box borrows the team of the best-overlapping box in
        the previous predicted frame, then the next one; if neither matches,
        the crop is classified on its own.

        Returns:
            dict mapping detection index -> team class id string.
        """
        team_predictions = {}
        if not (self.team_classifier and self.team_classifier_fitted):
            return team_predictions

        prev_frame = (frame_id // prediction_interval) * prediction_interval
        next_frame = prev_frame + prediction_interval

        for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(detection_box.tolist()):
            if int(cls_id) != 2 or conf < 0.6:
                continue
            target_box = (x1, y1, x2, y2)

            team_id = None
            if prev_frame in predicted_frame_data:
                team_id, _ = self._find_best_match(
                    target_box, predicted_frame_data[prev_frame], iou_threshold
                )

            if (team_id is None and next_frame < n_frames
                    and next_frame in predicted_frame_data):
                candidate, iou = self._find_best_match(
                    target_box, predicted_frame_data[next_frame], iou_threshold
                )
                if candidate is not None and iou > 0.0:
                    team_id = candidate

            if team_id is None:
                # No neighbour matched: classify this crop directly.
                crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                if crop.size > 0:
                    predicted = self.team_classifier.predict([crop])[0]
                    team_id = str(6 + int(predicted))

            if team_id is not None:
                team_predictions[idx] = team_id

        return team_predictions

    def _build_frame_boxes(self, detection_box, team_predictions):
        """Turn one frame's raw detections into the final ``BoundingBox`` list.

        Filtering, in order: player boxes (class 2) under 0.6 confidence are
        dropped; player boxes whose IoU with any staff box (class 4) is
        >= 0.8 are dropped; team predictions replace the class id; class 4
        is never emitted; class 3 under 0.5 confidence is dropped.  When
        several footballs (class 0) remain, only the most confident one is
        kept and it is moved to the end of the list.
        """
        rows = detection_box.tolist()
        # Located once per frame instead of once per box pair.
        staff_indices = [i for i, row in enumerate(rows) if row[5] == 4]

        boxes = []
        for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(rows):
            if cls_id == 2:
                if conf < 0.6:
                    continue
                # Row slices of the raw detection data are passed, exactly as
                # before, so the IoU arithmetic is unchanged.
                overlaps_staff = any(
                    self._calculate_iou(detection_box[idx][:4], detection_box[s][:4]) >= 0.8
                    for s in staff_indices
                )
                if overlaps_staff:
                    continue

            mapped_cls_id = team_predictions.get(idx, str(int(cls_id)))
            if mapped_cls_id == '4':
                continue
            if int(mapped_cls_id) == 3 and conf < 0.5:
                continue

            boxes.append(
                BoundingBox(
                    x1=int(x1),
                    y1=int(y1),
                    x2=int(x2),
                    y2=int(y2),
                    cls_id=int(mapped_cls_id),
                    conf=float(conf),
                )
            )

        footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
        if len(footballs) > 1:
            best_ball = max(footballs, key=lambda b: b.conf)
            boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
            boxes.append(best_ball)
        return boxes

    def _team_classify(self, detection_results, decoded_images, offset):
        """Fit the team classifier on this batch and label every frame.

        Args:
            detection_results: Per-frame detector results; each exposes
                ``.boxes.data`` with rows ``(x1, y1, x2, y2, conf, cls_id)``.
            decoded_images: Frames aligned with ``detection_results``.
            offset: Added to the in-batch index to form the output key.

        Returns:
            dict mapping ``offset + frame_index`` -> list of ``BoundingBox``.
            Empty when ``detection_results`` is empty.
        """
        start = time.time()
        self._fit_team_classifier(detection_results, decoded_images)
        end = time.time()
        print(f"Fitting Kmeans time: {end - start}")

        prediction_interval = 1
        iou_threshold = 0.3
        print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")

        n_frames = len(detection_results)
        frames_to_predict = [
            frame_id for frame_id in range(n_frames)
            if frame_id % prediction_interval == 0
        ]

        # Guarded: an empty batch used to raise ZeroDivisionError here.
        saved_pct = 100 - (len(frames_to_predict) * 100 // n_frames) if n_frames else 0
        print(f"Predicting teams for {len(frames_to_predict)}/{n_frames} frames "
              f"(saving {saved_pct}% compute)")

        predicted_frame_data = {}
        for frame_id in frames_to_predict:
            frame_predictions = self._predict_frame_teams(
                detection_results[frame_id].boxes.data,
                decoded_images[frame_id],
            )
            if frame_predictions is not None:
                predicted_frame_data[frame_id] = frame_predictions

        bboxes = {}
        for frame_id in range(n_frames):
            detection_box = detection_results[frame_id].boxes.data

            if frame_id % prediction_interval == 0:
                team_predictions = {
                    idx: team_cls_id
                    for idx, (_bbox, team_cls_id)
                    in predicted_frame_data.get(frame_id, {}).items()
                }
            else:
                team_predictions = self._interpolate_frame_teams(
                    frame_id,
                    detection_box,
                    decoded_images[frame_id],
                    predicted_frame_data,
                    n_frames,
                    prediction_interval,
                    iou_threshold,
                )

            bboxes[offset + frame_id] = self._build_frame_boxes(
                detection_box, team_predictions
            )
        return bboxes

    @staticmethod
    def _release_memory() -> None:
        """Run the garbage collector and, when CUDA is present, empty its cache."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def _frame_keypoints(self, kp_dict, image, n_keypoints: int) -> List[Tuple[int, int]]:
        """Convert one frame's keypoint dict into ``n_keypoints`` pixel pairs.

        ``kp_dict`` is expected to map ids 1..32 to ``{"x": ..., "y": ...}``;
        the values are multiplied by the image width/height (presumably
        normalised coordinates, confirm against ``process_batch_input``).
        Missing or malformed entries become ``(0, 0)``.  The list is then
        zero-padded or truncated to ``n_keypoints``.
        """
        frame_keypoints: List[Tuple[int, int]] = []
        try:
            height, width = image.shape[:2]
            if isinstance(kp_dict, dict):
                for kp_idx in range(1, self._N_RAW_KEYPOINTS + 1):
                    x, y = 0, 0
                    if kp_idx in kp_dict:
                        try:
                            kp_data = kp_dict[kp_idx]
                            if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
                                x = int(kp_data["x"] * width)
                                y = int(kp_data["y"] * height)
                        except (KeyError, TypeError, ValueError):
                            # Best effort: a malformed keypoint keeps
                            # whatever was assigned before the failure.
                            pass
                    frame_keypoints.append((x, y))
        except (IndexError, ValueError, AttributeError):
            frame_keypoints = [(0, 0)] * self._N_RAW_KEYPOINTS

        missing = n_keypoints - len(frame_keypoints)
        if missing > 0:
            frame_keypoints.extend([(0, 0)] * missing)
        else:
            frame_keypoints = frame_keypoints[:n_keypoints]
        return frame_keypoints

    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Run detection, team assignment and keypoint extraction on a batch.

        Args:
            batch_images: Decoded frames.
            offset: Frame id of the first image; results are numbered
                ``offset .. offset + len(batch_images) - 1``.
            n_keypoints: Number of keypoints each result must carry.

        Returns:
            One ``TVFrameResult`` per input frame, after
            ``run_keypoints_post_processing``.  An empty batch yields ``[]``.
        """
        if len(batch_images) == 0:
            # Nothing to run; previously this path crashed in team classification.
            return []

        start = time.time()
        detection_results = self._detect_objects_batch(batch_images)
        end = time.time()
        print(f"Detection time: {end - start}")

        start = time.time()
        bboxes = self._team_classify(detection_results, batch_images, offset)
        end = time.time()
        print(f"Team classify time: {end - start}")

        pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
        keypoints: Dict[int, List[Tuple[int, int]]] = {}

        start = time.time()
        self._release_memory()
        # Use the device chosen in __init__ (where the keypoint model was
        # moved) rather than assuming CUDA.
        keypoints_result = process_batch_input(
            batch_images,
            self.keypoints_model,
            self.kp_threshold,
            self.device,
            batch_size=pitch_batch_size,
        )
        if keypoints_result is not None and len(keypoints_result) > 0:
            for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
                if frame_number_in_batch >= len(batch_images):
                    break
                keypoints[offset + frame_number_in_batch] = self._frame_keypoints(
                    kp_dict, batch_images[frame_number_in_batch], n_keypoints
                )
        end = time.time()
        print(f"Keypoint time: {end - start}")

        results: List[TVFrameResult] = [
            TVFrameResult(
                frame_id=frame_number,
                boxes=bboxes.get(frame_number, []),
                # Frames without a keypoint entry fall back to all zeros.
                keypoints=keypoints.get(
                    frame_number, [(0, 0) for _ in range(n_keypoints)]
                ),
            )
            for frame_number in range(offset, offset + len(batch_images))
        ]

        h, w = batch_images[0].shape[:2]
        results = run_keypoints_post_processing(results, w, h)

        self._release_memory()
        return results