| from pathlib import Path |
| from typing import List, Tuple, Dict, Optional |
| import sys, os |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| import onnxruntime as ort |
| import numpy as np |
| import cv2 |
| from torchvision.ops import batched_nms |
| import torch |
| from ultralytics import YOLO |
| from numpy import ndarray |
| from pydantic import BaseModel |
| from team_cluster import TeamClassifier |
| from utils import ( |
| BoundingBox, |
| Constants, |
| suppress_small_contained_boxes, |
| classify_teams_batch, |
| ) |
|
|
|
|
class TVFrameResult(BaseModel):
    """Per-frame analysis result: detected boxes plus pitch keypoints."""

    # Absolute frame index within the processed video.
    frame_id: int
    # Detections for this frame (cls_id 0 = ball, 1 = goalkeeper, 2 = player
    # per the detection code; class 3 semantics not visible here).
    boxes: List[BoundingBox]
    # (x, y) pixel coordinates; (0, 0) marks a missing/low-confidence keypoint.
    keypoints: List[Tuple[int, int]]
|
|
|
|
class Miner:
    """
    Football video analysis system for object detection and team classification.
    """

    # Thresholds for suppressing small boxes mostly contained inside larger ones.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    # Passed to classify_teams_batch — presumably a hue pivot for
    # single-player team disambiguation; semantics live in utils.
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    # Keypoint-model indices that correspond to pitch corners.
    CORNER_INDICES = Constants.CORNER_INDICES
    # NOTE(review): _detect_keypoints_batch hard-codes 0.5/0.3 cutoffs instead
    # of using these two constants — confirm which is authoritative.
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    # NOTE(review): declared but unused in the visible code — possibly meant
    # for a positional check in _fix_misplaced_goalkeepers.
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    # Bounds on how many player crops are collected before fitting the
    # team classifier (consumed by classify_teams_batch).
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 500
|
| def __init__(self, path_hf_repo: Path) -> None: |
| providers = [ |
| 'CUDAExecutionProvider', |
| 'CPUExecutionProvider' |
| ] |
| model_path = path_hf_repo / "detection.onnx" |
| session = ort.InferenceSession(model_path, providers=providers) |
|
|
| input_name = session.get_inputs()[0].name |
| height = width = 640 |
| dummy = np.zeros((1, 3, height, width), dtype=np.float32) |
| session.run(None, {input_name: dummy}) |
| model = session |
| self.bbox_model = model |
| |
| print("BBox Model Loaded") |
| self.keypoints_model = YOLO(path_hf_repo / "keypoint.pt") |
| print("Keypoints Model (keypoint.pt) Loaded") |
| |
| team_model_path = path_hf_repo / "osnet_model.pth.tar-100" |
| device = 'cuda' |
| self.team_classifier = TeamClassifier( |
| device=device, |
| batch_size=32, |
| model_name=str(team_model_path) |
| ) |
| print("Team Classifier Loaded") |
| |
| |
| self.team_classifier_fitted = False |
| self.player_crops_for_fit = [] |
|
|
| def __repr__(self) -> str: |
| return ( |
| f"BBox Model: {type(self.bbox_model).__name__}\n" |
| f"Keypoints Model: {type(self.keypoints_model).__name__}" |
| ) |
|
|
|
|
|
|
| def _handle_multiple_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]: |
| """ |
| Handle goalkeeper detection issues: |
| 1. Fix misplaced goalkeepers (standing in middle of field) |
| 2. Limit to maximum 2 goalkeepers (one from each team) |
| |
| Returns: |
| Filtered list of boxes with corrected goalkeepers |
| """ |
| |
| |
| boxes = self._fix_misplaced_goalkeepers(boxes) |
| |
| |
| gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1] |
| if len(gk_idxs) <= 2: |
| return boxes |
| |
| |
| gk_idxs_sorted = sorted(gk_idxs, key=lambda i: boxes[i].conf, reverse=True) |
| keep_gk_idxs = set(gk_idxs_sorted[:2]) |
| |
| |
| filtered_boxes = [] |
| for i, box in enumerate(boxes): |
| if int(box.cls_id) == 1: |
| |
| if i in keep_gk_idxs: |
| filtered_boxes.append(box) |
| |
| else: |
| |
| filtered_boxes.append(box) |
| |
| return filtered_boxes |
|
|
| def _fix_misplaced_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]: |
| """ |
| """ |
| gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1] |
| player_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 2] |
| |
| if len(gk_idxs) == 0 or len(player_idxs) < 2: |
| return boxes |
| |
| updated_boxes = boxes.copy() |
| |
| for gk_idx in gk_idxs: |
| if boxes[gk_idx].conf < 0.3: |
| updated_boxes[gk_idx].cls_id = 2 |
| |
| return updated_boxes |
|
|
|
|
| def _pre_process_img(self, frames: List[np.ndarray], scale: float = 640.0) -> np.ndarray: |
| """ |
| Preprocess images for ONNX inference. |
| |
| Args: |
| frames: List of BGR frames |
| scale: Target scale for resizing |
| |
| Returns: |
| Preprocessed numpy array ready for ONNX inference |
| """ |
| imgs = np.stack([cv2.resize(frame, (int(scale), int(scale))) for frame in frames]) |
| imgs = imgs.transpose(0, 3, 1, 2) |
| imgs = imgs.astype(np.float32) / 255.0 |
| return imgs |
|
|
| def _post_process_output(self, outputs: np.ndarray, x_scale: float, y_scale: float, |
| conf_thresh: float = 0.6, nms_thresh: float = 0.55) -> List[List[Tuple]]: |
| """ |
| Post-process ONNX model outputs to get detections. |
| |
| Args: |
| outputs: Raw ONNX model outputs |
| x_scale: X-axis scaling factor |
| y_scale: Y-axis scaling factor |
| conf_thresh: Confidence threshold |
| nms_thresh: NMS threshold |
| |
| Returns: |
| List of detections for each frame: [(box, conf, class_id), ...] |
| """ |
| B, C, N = outputs.shape |
| outputs = torch.from_numpy(outputs) |
| outputs = outputs.permute(0, 2, 1) |
| |
| boxes = outputs[..., :4] |
| class_scores = 1 / (1 + torch.exp(-outputs[..., 4:])) |
| conf, class_id = class_scores.max(dim=2) |
|
|
| mask = conf > conf_thresh |
| |
| |
| for i in range(class_id.shape[0]): |
| |
| ball_mask = class_id[i] == 0 |
| ball_idx = ball_mask.nonzero(as_tuple=True)[0] |
| if ball_idx.numel() > 0: |
| |
| best_ball_idx = ball_idx[conf[i, ball_idx].argmax()] |
| if conf[i, best_ball_idx] >= 0.55: |
| mask[i, best_ball_idx] = True |
| |
| batch_idx, pred_idx = mask.nonzero(as_tuple=True) |
|
|
| if len(batch_idx) == 0: |
| return [[] for _ in range(B)] |
| |
| boxes = boxes[batch_idx, pred_idx] |
| conf = conf[batch_idx, pred_idx] |
| class_id = class_id[batch_idx, pred_idx] |
|
|
| |
| x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] |
| x1 = (x - w / 2) * x_scale |
| y1 = (y - h / 2) * y_scale |
| x2 = (x + w / 2) * x_scale |
| y2 = (y + h / 2) * y_scale |
| boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1) |
|
|
| |
| max_coord = 1e4 |
| offset = batch_idx.to(boxes_xyxy) * max_coord |
| boxes_for_nms = boxes_xyxy + offset[:, None] |
|
|
| keep = batched_nms(boxes_for_nms, conf, batch_idx, nms_thresh) |
|
|
| boxes_final = boxes_xyxy[keep] |
| conf_final = conf[keep] |
| class_final = class_id[keep] |
| batch_final = batch_idx[keep] |
|
|
| |
| results = [[] for _ in range(B)] |
| for b in range(B): |
| mask_b = batch_final == b |
| if mask_b.sum() == 0: |
| continue |
| results[b] = list(zip(boxes_final[mask_b].numpy(), |
| conf_final[mask_b].numpy(), |
| class_final[mask_b].numpy())) |
| return results |
|
|
| def _ioa(self, a: BoundingBox, b: BoundingBox) -> float: |
| inter = self._intersect_area(a, b) |
| aa = self._area(a) |
| if aa <= 0: |
| return 0.0 |
| return inter / aa |
|
|
    def suppress_small_contained(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Drop boxes that are much smaller than, and mostly contained inside,
        another box (likely duplicate detections of the same object).

        A box is suppressed when its area ratio to the bigger box is at most
        SMALL_RATIO_MAX and at least SMALL_CONTAINED_IOA of its own area lies
        inside the bigger box.  The result is order-dependent: once a box is
        suppressed it no longer suppresses others.

        NOTE(review): _detect_objects_batch calls the utils helper
        suppress_small_contained_boxes instead of this method — confirm
        whether this in-class duplicate is still needed.
        """
        if len(boxes) <= 1:
            return boxes
        keep = [True] * len(boxes)
        areas = [self._area(bb) for bb in boxes]
        for i in range(len(boxes)):
            if not keep[i]:
                continue
            for j in range(len(boxes)):
                if i == j or not keep[j]:
                    continue
                ai, aj = areas[i], areas[j]
                # Degenerate (zero-area) boxes never participate.
                if ai == 0 or aj == 0:
                    continue
                if ai <= aj:
                    # i is the smaller box: maybe suppress i inside j.
                    ratio = ai / aj
                    if ratio <= self.SMALL_RATIO_MAX:
                        ioa_i_in_j = self._ioa(boxes[i], boxes[j])
                        if ioa_i_in_j >= self.SMALL_CONTAINED_IOA:
                            keep[i] = False
                            break  # i is gone; stop comparing it against others
                else:
                    # j is the smaller box: maybe suppress j inside i.
                    ratio = aj / ai
                    if ratio <= self.SMALL_RATIO_MAX:
                        ioa_j_in_i = self._ioa(boxes[j], boxes[i])
                        if ioa_j_in_i >= self.SMALL_CONTAINED_IOA:
                            keep[j] = False
        return [bb for bb, k in zip(boxes, keep) if k]
|
|
    def _detect_objects_batch(self, batch_images: List[ndarray], offset: int) -> Dict[int, List[BoundingBox]]:
        """
        Phase 1: Object detection for all frames in batch.
        Returns detected objects with players still having class_id=2 (before team classification).

        Args:
            batch_images: List of images to process
            offset: Frame offset for numbering

        Returns:
            Dictionary mapping frame_id to list of detected boxes
        """
        bboxes: Dict[int, List[BoundingBox]] = {}

        if len(batch_images) == 0:
            return bboxes

        print(f"Processing batch of {len(batch_images)} images")

        # All frames are assumed to share the first frame's resolution; the
        # scale factors map 640x640 model coordinates back to image pixels.
        height, width = batch_images[0].shape[:2]
        scale = 640.0
        x_scale = width / scale
        y_scale = height / scale

        # Oversized batches recurse in chunks of 32 to bound memory use;
        # each chunk carries its own frame-id offset.
        max_batch_size = 32
        if len(batch_images) > max_batch_size:
            print(f"Large batch detected ({len(batch_images)} images), splitting into smaller batches of {max_batch_size}")
            all_bboxes = {}
            for chunk_start in range(0, len(batch_images), max_batch_size):
                chunk_end = min(chunk_start + max_batch_size, len(batch_images))
                chunk_images = batch_images[chunk_start:chunk_end]
                chunk_offset = offset + chunk_start
                print(f"Processing chunk {chunk_start//max_batch_size + 1}: images {chunk_start}-{chunk_end-1}")
                chunk_bboxes = self._detect_objects_batch(chunk_images, chunk_offset)
                all_bboxes.update(chunk_bboxes)
            return all_bboxes

        imgs = self._pre_process_img(batch_images, scale)
        actual_batch_size = len(batch_images)

        # Probe the ONNX model's declared batch dimension; it may be a fixed
        # int, or a dynamic marker (None / -1 / a symbolic name string).
        model_batch_size = self.bbox_model.get_inputs()[0].shape[0]
        print(f"Model input shape: {self.bbox_model.get_inputs()[0].shape}, batch_size: {model_batch_size}")

        if model_batch_size is not None:
            try:
                # Dynamic-batch markers mean no input padding is required.
                if str(model_batch_size) in ['None', '-1'] or model_batch_size == -1:
                    model_batch_size = None
                else:
                    model_batch_size = int(model_batch_size)
            except (ValueError, TypeError):
                # Symbolic dimension names (e.g. 'batch') land here.
                model_batch_size = None

        print(f"Processed model_batch_size: {model_batch_size}, actual_batch_size: {actual_batch_size}")

        # Fixed-batch models: pad the input with zero frames up to the
        # model's required batch size.
        if model_batch_size and actual_batch_size < model_batch_size:
            padding_size = model_batch_size - actual_batch_size
            dummy_img = np.zeros((1, 3, int(scale), int(scale)), dtype=np.float32)
            padding = np.repeat(dummy_img, padding_size, axis=0)
            imgs = np.vstack([imgs, padding])

        try:
            input_name = self.bbox_model.get_inputs()[0].name
            import time
            start_time = time.time()
            outputs = self.bbox_model.run(None, {input_name: imgs})[0]
            inference_time = time.time() - start_time
            print(f"Inference time: {inference_time:.3f}s for {actual_batch_size} images")

            # Drop outputs produced for the zero-padding frames.
            if model_batch_size and isinstance(model_batch_size, int) and actual_batch_size < model_batch_size:
                outputs = outputs[:actual_batch_size]

            raw_results = self._post_process_output(np.array(outputs), x_scale, y_scale)

        except Exception as e:
            # Best-effort: a failed inference yields no detections for the batch.
            print(f"Error during ONNX inference: {e}")
            return bboxes

        if not raw_results:
            return bboxes

        for frame_idx_in_batch, frame_detections in enumerate(raw_results):
            if not frame_detections:
                continue

            # Keep only classes 0-3 and convert raw tuples into BoundingBox.
            boxes: List[BoundingBox] = []
            for box, conf, cls_id in frame_detections:
                x1, y1, x2, y2 = box
                if int(cls_id) < 4:
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
                            cls_id=int(cls_id),
                            conf=float(conf),
                        )
                    )

            # At most one ball per frame: keep the most confident class-0 box.
            footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
            if len(footballs) > 1:
                best_ball = max(footballs, key=lambda b: b.conf)
                boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                boxes.append(best_ball)

            # Remove small boxes swallowed by bigger ones (duplicate hits).
            boxes = suppress_small_contained_boxes(boxes, self.SMALL_CONTAINED_IOA, self.SMALL_RATIO_MAX)

            # Demote dubious goalkeepers and cap goalkeepers at two.
            boxes = self._handle_multiple_goalkeepers(boxes)

            frame_id = offset + frame_idx_in_batch
            bboxes[frame_id] = boxes

        return bboxes
|
|
|
|
    def predict_batch(
        self,
        batch_images: List[ndarray],
        offset: int,
        n_keypoints: int,
        task_type: Optional[str] = None,
    ) -> List[TVFrameResult]:
        """
        Run the pipeline (detection + team classification + keypoints) on a
        batch of frames.

        Args:
            batch_images: Frames to analyse.
            offset: Frame id of batch_images[0]; results are numbered from here.
            n_keypoints: Expected keypoint count per frame.
            task_type: None runs both tasks; "object" or "keypoint" runs one.

        Returns:
            One TVFrameResult per input frame, in frame order; frames with no
            data fall back to empty boxes and all-(0, 0) keypoints.
        """
        process_objects = task_type is None or task_type == "object"
        process_keypoints = task_type is None or task_type == "keypoint"

        bboxes: Dict[int, List[BoundingBox]] = {}
        if process_objects:
            bboxes = self._detect_objects_batch(batch_images, offset)

        import time
        time_start = time.time()

        if process_objects and bboxes:
            # classify_teams_batch splits class-2 players into team classes
            # and returns updated fit state for the team classifier.
            bboxes, self.team_classifier_fitted, self.player_crops_for_fit = classify_teams_batch(
                self.team_classifier,
                self.team_classifier_fitted,
                self.player_crops_for_fit,
                batch_images,
                bboxes,
                offset,
                self.MIN_SAMPLES_FOR_FIT,
                self.MAX_SAMPLES_FOR_FIT,
                self.SINGLE_PLAYER_HUE_PIVOT
            )
            # NOTE(review): the fit state assigned just above is immediately
            # discarded here, forcing a re-fit on every batch — confirm this
            # is intentional and not a leftover debug reset.
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []
        print(f"Time Team Classification: {time.time() - time_start} s")

        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        if process_keypoints:
            keypoints = self._detect_keypoints_batch(batch_images, offset, n_keypoints)

        # Assemble per-frame results with safe defaults for missing frames.
        results: List[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            results.append(
                TVFrameResult(
                    frame_id=frame_number,
                    boxes=bboxes.get(frame_number, []),
                    keypoints=keypoints.get(
                        frame_number,
                        [(0, 0) for _ in range(n_keypoints)],
                    ),
                )
            )
        return results
|
|
| def _detect_keypoints_batch(self, batch_images: List[ndarray], |
| offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]: |
| """ |
| Phase 3: Keypoint detection for all frames in batch. |
| |
| Args: |
| batch_images: List of images to process |
| offset: Frame offset for numbering |
| n_keypoints: Number of keypoints expected |
| |
| Returns: |
| Dictionary mapping frame_id to list of keypoint coordinates |
| """ |
| keypoints: Dict[int, List[Tuple[int, int]]] = {} |
| keypoints_model_results = self.keypoints_model.predict(batch_images) |
| |
| if keypoints_model_results is None: |
| return keypoints |
| |
| for frame_idx_in_batch, detection in enumerate(keypoints_model_results): |
| if not hasattr(detection, "keypoints") or detection.keypoints is None: |
| continue |
| |
| |
| frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] |
| for i, part_points in enumerate(detection.keypoints.data): |
| for k_id, (x, y, _) in enumerate(part_points): |
| confidence = float(detection.keypoints.conf[i][k_id]) |
| frame_keypoints_with_conf.append((int(x), int(y), confidence)) |
| |
| |
| if len(frame_keypoints_with_conf) < n_keypoints: |
| frame_keypoints_with_conf.extend( |
| [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) |
| ) |
| else: |
| frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] |
| |
| |
| filtered_keypoints: List[Tuple[int, int]] = [] |
| for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): |
| if idx in self.CORNER_INDICES: |
| |
| if confidence < 0.3: |
| filtered_keypoints.append((0, 0)) |
| else: |
| filtered_keypoints.append((int(x), int(y))) |
| else: |
| |
| if confidence < 0.5: |
| filtered_keypoints.append((0, 0)) |
| else: |
| filtered_keypoints.append((int(x), int(y))) |
| |
| frame_id = offset + frame_idx_in_batch |
| keypoints[frame_id] = filtered_keypoints |
| |
| return keypoints |