from pathlib import Path
from typing import List, Tuple, Dict, Optional
import copy
import os
import sys
import time

# Put this file's directory on sys.path before the local imports below
# (team_cluster, utils), which are imported by bare module name.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Third-party import order is kept as in the original file.
import onnxruntime as ort
import numpy as np
import cv2
from torchvision.ops import batched_nms
import torch
from ultralytics import YOLO
from numpy import ndarray
from pydantic import BaseModel

from team_cluster import TeamClassifier
from utils import (
    BoundingBox,
    Constants,
    suppress_small_contained_boxes,
    classify_teams_batch,
)


class TVFrameResult(BaseModel):
    """Result for a single frame: its id, detected boxes and keypoints."""

    frame_id: int
    boxes: List[BoundingBox]
    keypoints: List[Tuple[int, int]]


class Miner:
    """
    Football video analysis system for object detection and team classification.

    Wraps three models loaded from one directory: an ONNX detector
    (``detection.onnx``), an Ultralytics keypoint model (``keypoint.pt``) and
    a ``TeamClassifier`` (``osnet_model.pth.tar-100``).
    """

    # Use constants from utils
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 500  # Maximum samples to avoid overfitting

    # Class ids as used by the detector post-processing in this file.
    CLS_BALL = 0
    CLS_GOALKEEPER = 1
    CLS_PLAYER = 2
    # Detections with a class id at or above this value are dropped.
    NUM_KEPT_CLASSES = 4

    DETECTION_INPUT_SIZE = 640  # Side length of the square detector input
    MAX_DETECTION_BATCH = 32  # Upper bound on frames per ONNX inference call
    MAX_GOALKEEPERS = 2  # At most this many goalkeeper boxes per frame
    GOALKEEPER_MIN_CONFIDENCE = 0.3  # Goalkeepers below this become players

    def __init__(self, path_hf_repo: Path) -> None:
        """
        Load the detection, keypoint and team-classification models.

        Args:
            path_hf_repo: Directory containing ``detection.onnx``,
                ``keypoint.pt`` and ``osnet_model.pth.tar-100``.
        """
        # Accept str as well as Path.
        path_hf_repo = Path(path_hf_repo)

        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        model_path = path_hf_repo / "detection.onnx"
        # str(): a plain string path is accepted by every onnxruntime release.
        session = ort.InferenceSession(str(model_path), providers=providers)

        # Warm-up run. A model exported with a fixed batch dimension only
        # accepts exactly that batch size, so match it (dynamic -> 1).
        model_input = session.get_inputs()[0]
        warmup_batch = self._parse_batch_dim(model_input.shape[0]) or 1
        size = self.DETECTION_INPUT_SIZE
        dummy = np.zeros((warmup_batch, 3, size, size), dtype=np.float32)
        session.run(None, {model_input.name: dummy})

        self.bbox_model = session
        print("BBox Model Loaded")

        self.keypoints_model = YOLO(path_hf_repo / "keypoint.pt")
        print("Keypoints Model (keypoint.pt) Loaded")

        # Initialize team classifier with OSNet model
        team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
        # Fall back to CPU so construction does not require a GPU, mirroring
        # the CPU fallback in the ONNX providers list above.
        # NOTE(review): assumes TeamClassifier accepts device="cpu" - confirm.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.team_classifier = TeamClassifier(
            device=device,
            batch_size=32,
            model_name=str(team_model_path),
        )
        print("Team Classifier Loaded")

        # Team classification state
        self.team_classifier_fitted = False
        self.player_crops_for_fit = []  # Collect samples across frames

    def __repr__(self) -> str:
        """Return the type names of the detection and keypoint models."""
        return (
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}"
        )

    @staticmethod
    def _parse_batch_dim(dim) -> Optional[int]:
        """
        Interpret the first entry of an ONNX input shape as a batch size.

        Returns:
            The batch size as a positive int, or None when the dimension is
            dynamic (None, a symbolic name, -1 or any non-positive value).
        """
        try:
            value = int(dim)
        except (TypeError, ValueError):
            return None
        return value if value > 0 else None

    def _handle_multiple_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Handle goalkeeper detection issues:
        1. Relabel low-confidence goalkeepers as players
           (see ``_fix_misplaced_goalkeepers``).
        2. Limit to maximum ``MAX_GOALKEEPERS`` goalkeepers, keeping the
           highest-confidence ones.

        Returns:
            Filtered list of boxes with corrected goalkeepers
        """
        boxes = self._fix_misplaced_goalkeepers(boxes)

        gk_idxs = [
            i for i, bb in enumerate(boxes)
            if int(bb.cls_id) == self.CLS_GOALKEEPER
        ]
        if len(gk_idxs) <= self.MAX_GOALKEEPERS:
            return boxes

        # Indices of the goalkeepers to keep, highest confidence first.
        ranked = sorted(gk_idxs, key=lambda i: boxes[i].conf, reverse=True)
        keep_gk_idxs = set(ranked[: self.MAX_GOALKEEPERS])

        # Non-goalkeepers always survive; goalkeepers only if in the top set.
        return [
            bb for i, bb in enumerate(boxes)
            if int(bb.cls_id) != self.CLS_GOALKEEPER or i in keep_gk_idxs
        ]

    def _fix_misplaced_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Relabel low-confidence goalkeeper detections as regular players.

        A goalkeeper box whose confidence is below
        ``GOALKEEPER_MIN_CONFIDENCE`` gets ``cls_id`` set to the player class.
        Nothing is changed unless the frame has at least one goalkeeper and
        at least two players. Box position is not inspected here.

        The input boxes are not modified: relabelled boxes are copies.

        Returns:
            ``boxes`` itself when nothing can change, otherwise a new list.
        """
        gk_idxs = [
            i for i, bb in enumerate(boxes)
            if int(bb.cls_id) == self.CLS_GOALKEEPER
        ]
        player_count = sum(
            1 for bb in boxes if int(bb.cls_id) == self.CLS_PLAYER
        )
        if not gk_idxs or player_count < 2:
            return boxes

        updated_boxes = list(boxes)
        for gk_idx in gk_idxs:
            if boxes[gk_idx].conf < self.GOALKEEPER_MIN_CONFIDENCE:
                # Copy first: the list copy above is shallow, so assigning
                # on the element would also change the caller's object.
                demoted = copy.copy(boxes[gk_idx])
                demoted.cls_id = self.CLS_PLAYER
                updated_boxes[gk_idx] = demoted
        return updated_boxes

    def _pre_process_img(self, frames: List[np.ndarray], scale: float = 640.0) -> np.ndarray:
        """
        Preprocess images for ONNX inference.

        Args:
            frames: List of BGR frames
            scale: Target side length; every frame is resized to
                ``scale x scale``

        Returns:
            float32 array in BCHW layout with values divided by 255.
            Channel order is left as given (no BGR/RGB swap is done here).
        """
        side = int(scale)
        imgs = np.stack([cv2.resize(frame, (side, side)) for frame in frames])
        imgs = imgs.transpose(0, 3, 1, 2)  # BHWC to BCHW
        return imgs.astype(np.float32) / 255.0  # Normalize to [0, 1]

    def _post_process_output(
        self,
        outputs: np.ndarray,
        x_scale: float,
        y_scale: float,
        conf_thresh: float = 0.6,
        nms_thresh: float = 0.55,
        ball_conf_thresh: float = 0.55,
    ) -> List[List[Tuple]]:
        """
        Post-process ONNX model outputs to get detections.

        ``outputs`` is indexed as (batch, channel, prediction): channels 0-3
        are read as a centre-format box (cx, cy, w, h) and the remaining
        channels as per-class logits.

        Args:
            outputs: Raw ONNX model outputs (3-D array)
            x_scale: X-axis scaling factor
            y_scale: Y-axis scaling factor
            conf_thresh: Confidence threshold
            nms_thresh: NMS IoU threshold
            ball_conf_thresh: Lower threshold applied only to the single
                best ball candidate of each frame

        Returns:
            List of detections for each frame: [(box, conf, class_id), ...]
            with ``box`` as xyxy in scaled coordinates.
        """
        batch_count, _, _ = outputs.shape
        preds = torch.from_numpy(outputs).permute(0, 2, 1)  # B,C,N -> B,N,C

        boxes = preds[..., :4]
        class_scores = torch.sigmoid(preds[..., 4:])
        conf, class_id = class_scores.max(dim=2)
        mask = conf > conf_thresh

        # Special handling for balls - keep best one even with lower confidence
        for i in range(batch_count):
            ball_idx = (class_id[i] == self.CLS_BALL).nonzero(as_tuple=True)[0]
            if ball_idx.numel() == 0:
                continue
            best_ball_idx = ball_idx[conf[i, ball_idx].argmax()]
            if conf[i, best_ball_idx] >= ball_conf_thresh:
                mask[i, best_ball_idx] = True

        batch_idx, pred_idx = mask.nonzero(as_tuple=True)
        if batch_idx.numel() == 0:
            return [[] for _ in range(batch_count)]

        boxes = boxes[batch_idx, pred_idx]
        conf = conf[batch_idx, pred_idx]
        class_id = class_id[batch_idx, pred_idx]

        # Convert from center format to xyxy format
        cx, cy, w, h = boxes.unbind(dim=1)
        half_w, half_h = w / 2, h / 2
        boxes_xyxy = torch.stack(
            [
                (cx - half_w) * x_scale,
                (cy - half_h) * y_scale,
                (cx + half_w) * x_scale,
                (cy + half_h) * y_scale,
            ],
            dim=1,
        )

        # batched_nms never suppresses across different ``idxs`` values, so
        # passing the frame index keeps frames independent. Suppression is
        # class-agnostic within a frame, as the class id is not in ``idxs``.
        keep = batched_nms(boxes_xyxy, conf, batch_idx, nms_thresh)

        boxes_final = boxes_xyxy[keep]
        conf_final = conf[keep]
        class_final = class_id[keep]
        batch_final = batch_idx[keep]

        # Group results by batch
        results: List[List[Tuple]] = [[] for _ in range(batch_count)]
        for b in range(batch_count):
            mask_b = batch_final == b
            if not bool(mask_b.any()):
                continue
            results[b] = list(
                zip(
                    boxes_final[mask_b].numpy(),
                    conf_final[mask_b].numpy(),
                    class_final[mask_b].numpy(),
                )
            )
        return results

    # NOTE(review): _area / _intersect_area were referenced by _ioa and
    # suppress_small_contained but not defined in the visible part of the
    # class; these definitions assume xyxy corner coordinates.
    @staticmethod
    def _area(bb: BoundingBox) -> float:
        """Return the area of ``bb`` (0 for degenerate boxes)."""
        return float(max(0, bb.x2 - bb.x1) * max(0, bb.y2 - bb.y1))

    @staticmethod
    def _intersect_area(a: BoundingBox, b: BoundingBox) -> float:
        """Return the overlap area of ``a`` and ``b`` (0 when disjoint)."""
        overlap_w = min(a.x2, b.x2) - max(a.x1, b.x1)
        overlap_h = min(a.y2, b.y2) - max(a.y1, b.y1)
        if overlap_w <= 0 or overlap_h <= 0:
            return 0.0
        return float(overlap_w * overlap_h)

    def _ioa(self, a: BoundingBox, b: BoundingBox) -> float:
        """Return intersection(a, b) divided by the area of ``a``."""
        area_a = self._area(a)
        if area_a <= 0:
            return 0.0
        return self._intersect_area(a, b) / area_a

    def suppress_small_contained(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Drop small boxes that lie mostly inside a much larger box.

        For each pair, the smaller box is removed when its area is at most
        ``SMALL_RATIO_MAX`` times the larger one's and at least
        ``SMALL_CONTAINED_IOA`` of its area overlaps the larger box.

        Returns:
            The surviving boxes in their original order.
        """
        if len(boxes) <= 1:
            return boxes

        keep = [True] * len(boxes)
        areas = [self._area(bb) for bb in boxes]

        for i, box_i in enumerate(boxes):
            if not keep[i]:
                continue
            for j, box_j in enumerate(boxes):
                if i == j or not keep[j]:
                    continue
                ai, aj = areas[i], areas[j]
                if ai == 0 or aj == 0:
                    continue
                if ai <= aj:
                    if (
                        ai / aj <= self.SMALL_RATIO_MAX
                        and self._ioa(box_i, box_j) >= self.SMALL_CONTAINED_IOA
                    ):
                        keep[i] = False
                        break  # box i is gone; stop comparing it
                elif (
                    aj / ai <= self.SMALL_RATIO_MAX
                    and self._ioa(box_j, box_i) >= self.SMALL_CONTAINED_IOA
                ):
                    keep[j] = False

        return [bb for bb, kept in zip(boxes, keep) if kept]

    def _to_bounding_boxes(self, frame_detections: List[Tuple]) -> List[BoundingBox]:
        """Convert raw (box, conf, cls_id) tuples to BoundingBox objects."""
        boxes: List[BoundingBox] = []
        for box, conf, cls_id in frame_detections:
            if int(cls_id) >= self.NUM_KEPT_CLASSES:
                continue
            x1, y1, x2, y2 = box
            boxes.append(
                BoundingBox(
                    x1=int(x1),
                    y1=int(y1),
                    x2=int(x2),
                    y2=int(y2),
                    cls_id=int(cls_id),
                    conf=float(conf),
                )
            )
        return boxes

    def _keep_best_ball(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """Keep only the highest-confidence ball box when several exist."""
        footballs = [bb for bb in boxes if int(bb.cls_id) == self.CLS_BALL]
        if len(footballs) <= 1:
            return boxes
        best_ball = max(footballs, key=lambda b: b.conf)
        others = [bb for bb in boxes if int(bb.cls_id) != self.CLS_BALL]
        others.append(best_ball)
        return others

    def _detect_objects_batch(self, batch_images: List[ndarray], offset: int) -> Dict[int, List[BoundingBox]]:
        """
        Phase 1: Object detection for all frames in batch.
        Returns detected objects with players still having class_id=2
        (before team classification).

        Batches larger than ``MAX_DETECTION_BATCH``, or larger than the
        model's fixed batch size when it has one, are split into chunks and
        processed recursively. Smaller batches are zero-padded up to a fixed
        model batch size. Inference or post-processing errors are printed
        and yield an empty result for the affected chunk.

        Args:
            batch_images: List of images to process
            offset: Frame offset for numbering

        Returns:
            Dictionary mapping frame_id to list of detected boxes. Frames
            with no detections have no entry.
        """
        bboxes: Dict[int, List[BoundingBox]] = {}
        if len(batch_images) == 0:
            return bboxes

        print(f"Processing batch of {len(batch_images)} images")

        model_input = self.bbox_model.get_inputs()[0]
        model_batch_size = self._parse_batch_dim(model_input.shape[0])
        print(f"Model input shape: {model_input.shape}, batch_size: {model_batch_size}")

        # A fixed-batch model cannot take more frames than its batch size,
        # so it also caps the chunk size.
        chunk_size = self.MAX_DETECTION_BATCH
        if model_batch_size is not None:
            chunk_size = min(chunk_size, model_batch_size)

        if len(batch_images) > chunk_size:
            print(
                f"Large batch detected ({len(batch_images)} images), "
                f"splitting into smaller batches of {chunk_size}"
            )
            for chunk_start in range(0, len(batch_images), chunk_size):
                chunk_end = min(chunk_start + chunk_size, len(batch_images))
                print(
                    f"Processing chunk {chunk_start // chunk_size + 1}: "
                    f"images {chunk_start}-{chunk_end - 1}"
                )
                bboxes.update(
                    self._detect_objects_batch(
                        batch_images[chunk_start:chunk_end],
                        offset + chunk_start,
                    )
                )
            return bboxes

        # Scale factors come from the first frame only.
        # NOTE(review): assumes all frames in the batch share one resolution.
        height, width = batch_images[0].shape[:2]
        scale = float(self.DETECTION_INPUT_SIZE)
        x_scale = width / scale
        y_scale = height / scale

        imgs = self._pre_process_img(batch_images, scale)
        actual_batch_size = len(batch_images)
        print(f"Processed model_batch_size: {model_batch_size}, actual_batch_size: {actual_batch_size}")

        # Pad with blank images up to a fixed model batch size.
        if model_batch_size is not None and actual_batch_size < model_batch_size:
            padding = np.zeros(
                (model_batch_size - actual_batch_size,) + imgs.shape[1:],
                dtype=np.float32,
            )
            imgs = np.concatenate([imgs, padding], axis=0)

        # Best-effort: a failing chunk is reported and yields no boxes.
        try:
            start_time = time.time()
            outputs = self.bbox_model.run(None, {model_input.name: imgs})[0]
            inference_time = time.time() - start_time
            print(f"Inference time: {inference_time:.3f}s for {actual_batch_size} images")

            # Drop results for padded images (no-op when nothing was padded).
            outputs = outputs[:actual_batch_size]
            raw_results = self._post_process_output(np.asarray(outputs), x_scale, y_scale)
        except Exception as e:
            print(f"Error during ONNX inference: {e}")
            return bboxes

        if not raw_results:
            return bboxes

        for frame_idx_in_batch, frame_detections in enumerate(raw_results):
            if not frame_detections:
                continue

            boxes = self._to_bounding_boxes(frame_detections)
            boxes = self._keep_best_ball(boxes)

            # Remove overlapping small boxes
            boxes = suppress_small_contained_boxes(
                boxes, self.SMALL_CONTAINED_IOA, self.SMALL_RATIO_MAX
            )

            # Relabel low-confidence goalkeepers and cap their number;
            # remaining goalkeepers keep class_id = 1.
            boxes = self._handle_multiple_goalkeepers(boxes)

            # Store results (players still have class_id=2, will be classified in phase 2)
            bboxes[offset + frame_idx_in_batch] = boxes

        return bboxes

    def predict_batch(
        self,
        batch_images: List[ndarray],
        offset: int,
        n_keypoints: int,
        task_type: Optional[str] = None,
    ) -> List[TVFrameResult]:
        """
        Run detection, team classification and keypoint detection on a batch.

        Args:
            batch_images: Frames to process
            offset: Frame id of the first image; ids increase by one per frame
            n_keypoints: Number of keypoints to return per frame
            task_type: ``"object"`` runs only detection + team classification,
                ``"keypoint"`` runs only keypoint detection, ``None`` runs
                both. Any other value runs neither.

        Returns:
            One TVFrameResult per input frame, in order. Missing boxes
            default to ``[]`` and missing keypoints to ``n_keypoints`` copies
            of ``(0, 0)``.
        """
        process_objects = task_type is None or task_type == "object"
        process_keypoints = task_type is None or task_type == "keypoint"

        # Phase 1: Object Detection for all frames
        bboxes: Dict[int, List[BoundingBox]] = {}
        if process_objects:
            bboxes = self._detect_objects_batch(batch_images, offset)

        time_start = time.time()
        # Phase 2: Team Classification for all detected players
        if process_objects and bboxes:
            bboxes, self.team_classifier_fitted, self.player_crops_for_fit = classify_teams_batch(
                self.team_classifier,
                self.team_classifier_fitted,
                self.player_crops_for_fit,
                batch_images,
                bboxes,
                offset,
                self.MIN_SAMPLES_FOR_FIT,
                self.MAX_SAMPLES_FOR_FIT,
                self.SINGLE_PLAYER_HUE_PIVOT,
            )
            # NOTE(review): the state returned above is discarded right away,
            # so every batch starts unfitted with no collected crops. Kept
            # as in the original - confirm this per-batch reset is intended.
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []
        print(f"Time Team Classification: {time.time() - time_start} s")

        # Phase 3: Keypoint Detection
        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        if process_keypoints:
            keypoints = self._detect_keypoints_batch(batch_images, offset, n_keypoints)

        # Phase 4: Combine results
        empty_keypoints = [(0, 0)] * n_keypoints
        return [
            TVFrameResult(
                frame_id=frame_number,
                boxes=bboxes.get(frame_number, []),
                keypoints=keypoints.get(frame_number, list(empty_keypoints)),
            )
            for frame_number in range(offset, offset + len(batch_images))
        ]

    def _detect_keypoints_batch(
        self,
        batch_images: List[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> Dict[int, List[Tuple[int, int]]]:
        """
        Phase 3: Keypoint detection for all frames in batch.

        Keypoints of all detected instances in a frame are concatenated in
        order, then padded with ``(0, 0)`` or truncated to ``n_keypoints``.
        A keypoint whose confidence is below its threshold is replaced by
        ``(0, 0)``.

        Args:
            batch_images: List of images to process
            offset: Frame offset for numbering
            n_keypoints: Number of keypoints expected

        Returns:
            Dictionary mapping frame_id to list of keypoint coordinates
        """
        # NOTE(review): literals kept from the original. The class also
        # defines CORNER_CONFIDENCE / KEYPOINTS_CONFIDENCE, which are unused
        # here - confirm whether they should replace these values.
        corner_threshold = 0.3
        regular_threshold = 0.5

        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        keypoints_model_results = self.keypoints_model.predict(batch_images)
        if keypoints_model_results is None:
            return keypoints

        for frame_idx_in_batch, detection in enumerate(keypoints_model_results):
            if getattr(detection, "keypoints", None) is None:
                continue

            # Extract keypoints with confidence.
            # Each keypoint row is unpacked as three values (x, y, <ignored>);
            # the confidence is read separately from ``keypoints.conf``.
            frame_keypoints_with_conf: List[Tuple[int, int, float]] = []
            for i, part_points in enumerate(detection.keypoints.data):
                for k_id, (x, y, _) in enumerate(part_points):
                    confidence = float(detection.keypoints.conf[i][k_id])
                    frame_keypoints_with_conf.append((int(x), int(y), confidence))

            # Pad or truncate to expected number of keypoints
            missing = n_keypoints - len(frame_keypoints_with_conf)
            if missing > 0:
                frame_keypoints_with_conf.extend([(0, 0, 0.0)] * missing)
            else:
                frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints]

            # Filter keypoints based on confidence thresholds; corner
            # keypoints use the lower threshold.
            filtered_keypoints: List[Tuple[int, int]] = []
            for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf):
                threshold = (
                    corner_threshold
                    if idx in self.CORNER_INDICES
                    else regular_threshold
                )
                if confidence < threshold:
                    filtered_keypoints.append((0, 0))
                else:
                    filtered_keypoints.append((int(x), int(y)))

            keypoints[offset + frame_idx_in_batch] = filtered_keypoints

        return keypoints