| from pathlib import Path
|
| from typing import List, Tuple, Dict, Optional
|
| import sys
|
| import os
|
|
|
| from numpy import ndarray
|
| from pydantic import BaseModel
|
|
|
| sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| from keypoint_helper import run_keypoints_post_processing
|
| from keypoint_helper_v2 import run_keypoints_post_processing as run_keypoints_post_processing_v2
|
|
|
| from ultralytics import YOLO
|
| from team_cluster import TeamClassifier
|
| from utils import (
|
| BoundingBox,
|
| Constants,
|
| )
|
|
|
| import time
|
| import torch
|
| import gc
|
| import cv2
|
| import numpy as np
|
| from collections import defaultdict
|
| from pitch import process_batch_input, get_cls_net
|
| from keypoint_evaluation import (
|
| evaluate_keypoints_for_frame,
|
| evaluate_keypoints_for_frame_gpu,
|
| load_template_from_file,
|
| evaluate_keypoints_for_frame_opencv_cuda,
|
| evaluate_keypoints_batch_for_frame,
|
| )
|
|
|
| import yaml
|
|
|
|
|
class BoundingBox(BaseModel):
    # One detection box in integer pixel coordinates, plus class id and the
    # detector confidence. Frames are cropped as frame[y1:y2, x1:x2], so
    # (x1, y1) is the top-left and (x2, y2) the bottom-right corner.
    # NOTE(review): this definition shadows the ``BoundingBox`` imported from
    # ``utils`` at the top of the file; all later uses in this module refer to
    # this class.

    x1: int  # left edge (pixels)

    y1: int  # top edge (pixels)

    x2: int  # right edge (pixels)

    y2: int  # bottom edge (pixels)

    cls_id: int  # detector class id, or a team label (6 + team index) for players

    conf: float  # detector confidence score
|
|
|
|
|
class TVFrameResult(BaseModel):
    # Per-frame output of Miner.predict_batch.

    frame_id: int  # absolute frame index (batch offset + index within the batch)

    boxes: List[BoundingBox]  # detections after team classification and filtering

    keypoints: List[Tuple[int, int]]  # (x, y) pixel coordinates; (0, 0) marks a missing keypoint
|
|
|
|
|
class Miner:
    """Run football-frame inference on batches of decoded frames.

    For every frame this produces object boxes (with players relabelled by
    team) and a list of pitch keypoints. The keypoints are chosen among an
    HRNet prediction, a YOLO keypoint prediction and the previously accepted
    keypoints.

    Construction catches every exception and records it in ``self.health``.
    Callers should check ``health == "healthy"`` before predicting.
    """

    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN

    # Team-classifier fitting: fit as soon as MAX crops are collected;
    # otherwise fall back to the collected crops if there are at least MIN.
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 600

    # Detector class ids referenced by the team-classification logic.
    _BALL_CLS_ID = 0
    _PLAYER_CLS_ID = 2
    _STAFF_CLS_ID = 4
    # Team labels are emitted as cls_id = _TEAM_CLS_ID_BASE + team index.
    _TEAM_CLS_ID_BASE = 6
    # Thresholds used when selecting fit crops and building per-frame boxes.
    _FIT_MIN_CONF = 0.5        # min conf for a player crop used to fit the classifier
    _PLAYER_MIN_CONF = 0.6     # players below this conf are dropped entirely
    _STAFF_OVERLAP_IOU = 0.8   # players overlapping a staff box at least this much are dropped
    _CLS3_MIN_CONF = 0.5       # class-3 boxes below this conf are dropped

    def __init__(self, path_hf_repo: Path) -> None:
        """Load every model and template from ``path_hf_repo``.

        Files read (relative to ``path_hf_repo``):

        - ``detection.onnx``
        - ``osnet_model.pth.tar-100``
        - ``keypoint.pt``
        - ``keypoint``
        - ``hrnetv2_w48.yaml``
        - ``football_pitch_template.png``

        Any exception is caught and ``self.health`` is set to an error message
        instead of ``"healthy"``. Attributes assigned after the failure point
        will then be missing.
        """
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Object detector (ONNX export loaded through ultralytics).
            model_path = path_hf_repo / "detection.onnx"
            self.bbox_model = YOLO(model_path)
            print(f"BBox Model Loaded: class name {self.bbox_model.names}")

            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path),
            )
            print("Team Classifier Loaded")

            # State carried across predict_batch calls.
            self.last_score = 0
            self.last_valid_keypoints = None
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            # YOLO keypoint model: second keypoint candidate.
            self.keypoints_model_yolo = YOLO(path_hf_repo / "keypoint.pt")

            # HRNet keypoint model: YAML config plus a state dict.
            model_kp_path = path_hf_repo / 'keypoint'
            config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            # Context manager so the config file handle is closed promptly.
            with open(config_kp_path, 'r') as config_file:
                cfg_kp = yaml.safe_load(config_file)

            # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()

            self.keypoints_model = model
            print("Keypoints Model (keypoint.pt) Loaded")

            template_image_path = path_hf_repo / "football_pitch_template.png"
            self.template_image, self.template_keypoints = load_template_from_file(str(template_image_path))

            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"

            print("✅ Keypoints Model Loaded")
        except Exception as e:
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def __repr__(self) -> str:
        """Show model types when healthy, otherwise the initialization error."""
        if self.health == 'healthy':
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        else:
            return self.health

    def _calculate_iou(self, box1: Tuple[float, float, float, float],
                       box2: Tuple[float, float, float, float]) -> float:
        """
        Calculate Intersection over Union (IoU) between two bounding boxes.

        Args:
            box1: (x1, y1, x2, y2)
            box2: (x1, y1, x2, y2)

        Returns:
            IoU score. For well-formed boxes this lies in [0, 1]. It is 0.0
            for disjoint boxes or a zero-area union.
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        # Corners of the intersection rectangle.
        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)

        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - intersection_area

        if union_area == 0:
            return 0.0
        return intersection_area / union_area

    def _extract_jersey_region(self, crop: ndarray) -> ndarray:
        """
        Extract the jersey region (upper body) from a player crop.

        For "close-up" crops (height > 100 px or area > 12000 px²) this returns
        the top 60% of rows with 5% trimmed from each side. Smaller crops, or
        crops under 10 px in either dimension, are returned unchanged.
        """
        if crop is None or crop.size == 0:
            return crop

        h, w = crop.shape[:2]
        if h < 10 or w < 10:
            return crop

        is_closeup = h > 100 or (h * w) > 12000
        if is_closeup:
            jersey_top = 0
            jersey_bottom = int(h * 0.60)
            jersey_left = max(0, int(w * 0.05))
            jersey_right = min(w, int(w * 0.95))
            return crop[jersey_top:jersey_bottom, jersey_left:jersey_right]
        return crop

    def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]:
        """
        Build an 11-value colour feature vector for the jersey region of a crop.

        The features are, in order:

        - HSV mean (3 values) and HSV std (3 values);
        - LAB mean and std of the first two channels (2 + 2 values);
        - the dominant hue bin scaled by 1/180.

        All means and stds are divided by 255. OpenCV hue only spans 0-179, so
        the hue entries do not reach 1.0. Assumes a 3-channel BGR crop.

        Returns None for empty input or if the colour conversion fails.
        """
        if crop is None or crop.size == 0:
            return None

        jersey_region = self._extract_jersey_region(crop)
        if jersey_region.size == 0:
            return None

        try:
            hsv = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV)
            lab = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB)

            hsv_flat = hsv.reshape(-1, 3).astype(np.float32)
            lab_flat = lab.reshape(-1, 3).astype(np.float32)

            hsv_mean = np.mean(hsv_flat, axis=0) / 255.0
            hsv_std = np.std(hsv_flat, axis=0) / 255.0
            lab_mean = np.mean(lab_flat, axis=0) / 255.0
            lab_std = np.std(lab_flat, axis=0) / 255.0

            # 36 bins of 5 hue units each; bin index * 5 = hue of the bin start.
            hue_hist, _ = np.histogram(hsv_flat[:, 0], bins=36, range=(0, 180))
            dominant_hue = np.argmax(hue_hist) * 5

            color_features = np.concatenate([
                hsv_mean,
                hsv_std,
                lab_mean[:2],
                lab_std[:2],
                [dominant_hue / 180.0],
            ])
            return color_features
        except Exception as e:
            print(f"Error extracting color signature: {e}")
            return None

    def _get_spatial_position(self, bbox: Tuple[float, float, float, float],
                              frame_width: int, frame_height: int) -> Tuple[float, float]:
        """
        Return the box centre normalized by frame size, as (x_norm, y_norm).

        (0, 0) is the top-left of the frame. A non-positive frame dimension
        yields 0.5 on that axis.
        """
        x1, y1, x2, y2 = bbox
        center_x = (x1 + x2) / 2.0
        center_y = (y1 + y2) / 2.0

        x_norm = center_x / frame_width if frame_width > 0 else 0.5
        y_norm = center_y / frame_height if frame_height > 0 else 0.5
        return (x_norm, y_norm)

    def _find_best_match(self, target_box: Tuple[float, float, float, float],
                         predicted_frame_data: Dict[int, Tuple[Tuple, str]],
                         iou_threshold: float) -> Tuple[Optional[str], float]:
        """
        Find the predicted box with the highest IoU against ``target_box``.

        Args:
            target_box: (x1, y1, x2, y2) to match.
            predicted_frame_data: detection index -> (bbox, team_cls_id) for one frame.
            iou_threshold: minimum IoU for a match to count.

        Returns:
            (team_cls_id, iou) of the best match, or (None, 0.0) if nothing
            reaches the threshold. The first box wins ties.
        """
        best_iou = 0.0
        best_team_id = None

        for bbox, team_cls_id in predicted_frame_data.values():
            iou = self._calculate_iou(target_box, bbox)
            if iou > best_iou and iou >= iou_threshold:
                best_iou = iou
                best_team_id = team_cls_id

        return (best_team_id, best_iou)

    def _detect_objects_batch(self, decoded_images: List[ndarray], batch_size: int = 16) -> list:
        """Run the box detector over ``decoded_images`` in chunks of ``batch_size``.

        Returns the ultralytics result objects (one per image, in input order),
        not ``BoundingBox`` lists. Non-positive ``batch_size`` is treated as 1.
        """
        step = max(1, batch_size)
        detection_results = []
        for frame_number in range(0, len(decoded_images), step):
            batch_images = decoded_images[frame_number: frame_number + step]
            detections = self.bbox_model(batch_images, verbose=False, save=False)
            detection_results.extend(detections)
        return detection_results

    def _fit_team_classifier(self, detection_results, decoded_images) -> None:
        """(Re)fit the team classifier from confident player crops in this batch.

        Frames with fewer than 4 detections are skipped. Fitting happens as
        soon as ``MAX_SAMPLES_FOR_FIT`` crops are collected. Otherwise it falls
        back to the crops collected, provided there are at least
        ``MIN_SAMPLES_FOR_FIT``. Sets ``self.team_classifier_fitted``.
        """
        self.team_classifier_fitted = False
        player_crops_for_fit = []

        for frame_id, result in enumerate(detection_results):
            rows = result.boxes.data.tolist()
            if len(rows) < 4:
                continue

            if len(player_crops_for_fit) < self.MAX_SAMPLES_FOR_FIT:
                frame_image = decoded_images[frame_id]
                for x1, y1, x2, y2, conf, cls_id in rows:
                    if conf < self._FIT_MIN_CONF or int(cls_id) != self._PLAYER_CLS_ID:
                        continue
                    crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        player_crops_for_fit.append(crop)

            if (self.team_classifier and not self.team_classifier_fitted
                    and len(player_crops_for_fit) >= self.MAX_SAMPLES_FOR_FIT):
                print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
                self.team_classifier.fit(player_crops_for_fit)
                self.team_classifier_fitted = True
                break

        if not self.team_classifier_fitted and len(player_crops_for_fit) >= self.MIN_SAMPLES_FOR_FIT:
            print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
            self.team_classifier.fit(player_crops_for_fit)
            self.team_classifier_fitted = True

    def _predict_teams_for_frames(self, detection_results, decoded_images,
                                  frame_ids) -> Dict[int, Dict[int, Tuple[Tuple, str]]]:
        """Classify confident player crops in ``frame_ids``, one classifier call per frame.

        Returns frame_id -> {detection index: ((x1, y1, x2, y2), team_cls_id)}.
        Frames without usable player crops are absent. The result is empty
        when the classifier is not fitted.
        """
        predicted_frame_data: Dict[int, Dict[int, Tuple[Tuple, str]]] = {}
        if not (self.team_classifier and self.team_classifier_fitted):
            return predicted_frame_data

        for frame_id in frame_ids:
            frame_image = decoded_images[frame_id]
            crops, indices, boxes = [], [], []
            for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(detection_results[frame_id].boxes.data.tolist()):
                if int(cls_id) != self._PLAYER_CLS_ID or conf < self._PLAYER_MIN_CONF:
                    continue
                crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                if crop.size > 0:
                    crops.append(crop)
                    indices.append(idx)
                    boxes.append((x1, y1, x2, y2))

            if crops:
                team_ids = self.team_classifier.predict(crops)
                predicted_frame_data[frame_id] = {
                    idx: (bbox, str(self._TEAM_CLS_ID_BASE + int(team_id)))
                    for idx, bbox, team_id in zip(indices, boxes, team_ids)
                }
        return predicted_frame_data

    def _interpolate_team_predictions(self, rows, frame_image, frame_id, prediction_interval,
                                      predicted_frame_data, n_frames, iou_threshold) -> Dict[int, str]:
        """Assign teams on a frame that was not classified directly.

        Each confident player is matched by IoU against the previous predicted
        frame, then the next one. If neither matches, the player's crop is
        classified on its own. Returns detection index -> team_cls_id.
        """
        team_predictions: Dict[int, str] = {}
        if not (self.team_classifier and self.team_classifier_fitted):
            return team_predictions

        prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
        next_predicted_frame = prev_predicted_frame + prediction_interval

        for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(rows):
            if int(cls_id) != self._PLAYER_CLS_ID or conf < self._PLAYER_MIN_CONF:
                continue
            target_box = (x1, y1, x2, y2)
            best_team_id = None
            best_iou = 0.0

            if prev_predicted_frame in predicted_frame_data:
                team_id, iou = self._find_best_match(
                    target_box, predicted_frame_data[prev_predicted_frame], iou_threshold)
                if team_id is not None:
                    best_team_id, best_iou = team_id, iou

            if (best_team_id is None and next_predicted_frame < n_frames
                    and next_predicted_frame in predicted_frame_data):
                team_id, iou = self._find_best_match(
                    target_box, predicted_frame_data[next_predicted_frame], iou_threshold)
                if team_id is not None and iou > best_iou:
                    best_team_id = team_id

            if best_team_id is None:
                # No neighbour match: classify this crop on its own.
                crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                if crop.size > 0:
                    team_id = self.team_classifier.predict([crop])[0]
                    best_team_id = str(self._TEAM_CLS_ID_BASE + int(team_id))

            if best_team_id is not None:
                team_predictions[idx] = best_team_id
        return team_predictions

    def _build_frame_boxes(self, rows, team_predictions: Dict[int, str]) -> List[BoundingBox]:
        """Turn one frame's raw detector rows into filtered, team-labelled boxes.

        The following are dropped:

        - low-confidence players;
        - players overlapping a staff box;
        - staff boxes themselves;
        - low-confidence class-3 boxes.

        If there are several balls, only the most confident one is kept.
        """
        # Staff boxes are gathered once; no confidence filter is applied here.
        staff_boxes = [
            (s_x1, s_y1, s_x2, s_y2)
            for s_x1, s_y1, s_x2, s_y2, _, s_cls_id in rows
            if s_cls_id == self._STAFF_CLS_ID
        ]

        boxes: List[BoundingBox] = []
        for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(rows):
            is_player = cls_id == self._PLAYER_CLS_ID
            if is_player and conf < self._PLAYER_MIN_CONF:
                continue
            if is_player and any(
                self._calculate_iou((x1, y1, x2, y2), staff_box) >= self._STAFF_OVERLAP_IOU
                for staff_box in staff_boxes
            ):
                continue

            mapped_cls_id = team_predictions.get(idx, str(int(cls_id)))
            if mapped_cls_id == str(self._STAFF_CLS_ID):
                continue
            if int(mapped_cls_id) == 3 and conf < self._CLS3_MIN_CONF:
                continue
            boxes.append(
                BoundingBox(
                    x1=int(x1),
                    y1=int(y1),
                    x2=int(x2),
                    y2=int(y2),
                    cls_id=int(mapped_cls_id),
                    conf=float(conf),
                )
            )

        footballs = [bb for bb in boxes if int(bb.cls_id) == self._BALL_CLS_ID]
        if len(footballs) > 1:
            best_ball = max(footballs, key=lambda b: b.conf)
            boxes = [bb for bb in boxes if int(bb.cls_id) != self._BALL_CLS_ID]
            boxes.append(best_ball)
        return boxes

    def _team_classify(self, detection_results, decoded_images, offset,
                       prediction_interval: int = 1, iou_threshold: float = 0.3):
        """Fit the team classifier on this batch and return filtered per-frame boxes.

        Players (class 2) are relabelled to ``6 + team index``. See
        ``_build_frame_boxes`` for which boxes are dropped.

        Args:
            detection_results: ultralytics results from ``_detect_objects_batch``.
            decoded_images: frames aligned index-for-index with ``detection_results``.
            offset: added to each in-batch frame index to form the output key.
            prediction_interval: classify every N-th frame directly. Other
                frames reuse IoU matches from neighbouring classified frames.
                Default 1 means every frame; values below 1 are treated as 1.
            iou_threshold: minimum IoU for such a neighbour match.

        Returns:
            Dict mapping ``offset + frame_index`` to that frame's boxes.
        """
        n_frames = len(detection_results)
        prediction_interval = max(1, int(prediction_interval))

        start = time.time()
        self._fit_team_classifier(detection_results, decoded_images)
        end = time.time()
        print(f"Fitting Kmeans time: {end - start}")

        print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")

        frames_to_predict = list(range(0, n_frames, prediction_interval))
        # Guard the percentage against an empty batch (was a ZeroDivisionError).
        saved_pct = 100 - (len(frames_to_predict) * 100 // n_frames) if n_frames else 0
        print(f"Predicting teams for {len(frames_to_predict)}/{n_frames} frames "
              f"(saving {saved_pct}% compute)")

        predicted_frame_data = self._predict_teams_for_frames(
            detection_results, decoded_images, frames_to_predict)

        bboxes: Dict[int, List[BoundingBox]] = {}
        for frame_id in range(n_frames):
            # Convert the detection tensor to Python floats once per frame.
            rows = detection_results[frame_id].boxes.data.tolist()
            if frame_id % prediction_interval == 0:
                team_predictions = {
                    idx: team_cls_id
                    for idx, (_, team_cls_id) in predicted_frame_data.get(frame_id, {}).items()
                }
            else:
                team_predictions = self._interpolate_team_predictions(
                    rows, decoded_images[frame_id], frame_id, prediction_interval,
                    predicted_frame_data, n_frames, iou_threshold)
            bboxes[offset + frame_id] = self._build_frame_boxes(rows, team_predictions)
        return bboxes

    @staticmethod
    def _build_results(bboxes: Dict[int, List[BoundingBox]],
                       keypoints_by_frame: Dict[int, List[Tuple[int, int]]],
                       offset: int, n_frames: int, n_keypoints: int) -> List[TVFrameResult]:
        """Assemble one TVFrameResult per frame; missing keypoints become all-(0, 0)."""
        return [
            TVFrameResult(
                frame_id=frame_number,
                boxes=bboxes.get(frame_number, []),
                keypoints=keypoints_by_frame.get(
                    frame_number,
                    [(0, 0) for _ in range(n_keypoints)],
                ),
            )
            for frame_number in range(offset, offset + n_frames)
        ]

    @staticmethod
    def _hrnet_frame_keypoints(kp_dict, image: ndarray, n_keypoints: int) -> List[Tuple[int, int]]:
        """Convert one HRNet keypoint dict into exactly ``n_keypoints`` pixel tuples.

        Keypoint ids 1..32 are read. Each ``{"x", "y"}`` entry appears to be
        normalized, since it is scaled by the frame width and height. Missing
        or malformed entries become (0, 0). The list is then padded with (0, 0)
        or truncated to ``n_keypoints``.
        """
        frame_keypoints: List[Tuple[int, int]] = []
        try:
            height, width = image.shape[:2]
            if isinstance(kp_dict, dict):
                for kp_idx in range(1, 33):
                    x, y = 0, 0
                    kp_data = kp_dict.get(kp_idx)
                    if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
                        try:
                            x = int(kp_data["x"] * width)
                            y = int(kp_data["y"] * height)
                        except (KeyError, TypeError, ValueError, OverflowError):
                            # NaN raises ValueError; inf raises OverflowError.
                            pass
                    frame_keypoints.append((x, y))
        except (IndexError, ValueError, AttributeError):
            frame_keypoints = [(0, 0)] * 32
        return (frame_keypoints + [(0, 0)] * n_keypoints)[:n_keypoints]

    def _detect_keypoints_hrnet(self, batch_images: List[ndarray], offset: int,
                                n_keypoints: int, device_str: str) -> Dict[int, List[Tuple[int, int]]]:
        """Run the HRNet keypoint model over the batch.

        Returns ``offset + index`` -> ``n_keypoints`` pixel tuples. If
        ``process_batch_input`` returns nothing, the dict is empty.
        """
        keypoints: Dict[int, List[Tuple[int, int]]] = {}

        # Release cached memory before the HRNet pass.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        keypoints_result = process_batch_input(
            batch_images,
            self.keypoints_model,
            self.kp_threshold,
            device_str,
            batch_size=min(self.pitch_batch_size, len(batch_images)),
        )
        if keypoints_result is None or len(keypoints_result) == 0:
            return keypoints

        for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
            if frame_number_in_batch >= len(batch_images):
                break
            keypoints[offset + frame_number_in_batch] = self._hrnet_frame_keypoints(
                kp_dict, batch_images[frame_number_in_batch], n_keypoints)
        return keypoints

    def _choose_keypoints(self, primary, secondary, frame: ndarray, device_str: str):
        """Return the best-scoring candidate keypoint set and update ``self.last_score``.

        The candidates are ``primary`` (HRNet), ``secondary`` (YOLO) and, when
        one exists, ``self.last_valid_keypoints``. A higher score wins. A
        missing previous set is left out rather than passed as ``None``.
        """
        previous = self.last_valid_keypoints
        candidates = [primary, secondary]
        if previous is not None:
            candidates.append(previous)

        scores = evaluate_keypoints_batch_for_frame(
            template_keypoints=self.template_keypoints,
            frame_keypoints_list=candidates,
            frame=frame,
            floor_markings_template=self.template_image,
            device=device_str,
        )
        score = scores[0]
        score_yolo = scores[1]

        if previous is not None:
            self.last_score = scores[2]
            if self.last_score > score and self.last_score > score_yolo:
                return previous
        if score_yolo > score:
            self.last_score = score_yolo
            return secondary
        self.last_score = score
        return primary

    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Detect objects, classify teams and pick keypoints for a batch of frames.

        Args:
            batch_images: decoded frames of one batch.
            offset: absolute index of the first frame; used for ``frame_id``.
            n_keypoints: length of every returned keypoint list.

        Returns:
            One TVFrameResult per input frame, in order. An empty batch returns [].
        """
        if len(batch_images) == 0:
            return []
        n_frames = len(batch_images)
        # Match the device chosen in __init__ instead of assuming CUDA.
        device_str = "cuda" if torch.cuda.is_available() else "cpu"

        start = time.time()
        detection_results = self._detect_objects_batch(batch_images)
        end = time.time()
        print(f"Detection time: {end - start}")

        start = time.time()
        bboxes = self._team_classify(detection_results, batch_images, offset)
        end = time.time()
        print(f"Team classify time: {end - start}")

        keypoints_yolo = self._detect_keypoints_batch(batch_images, offset, n_keypoints)

        start = time.time()
        keypoints = self._detect_keypoints_hrnet(batch_images, offset, n_keypoints, device_str)
        end = time.time()
        print(f"Keypoint time: {end - start}")

        results = self._build_results(bboxes, keypoints, offset, n_frames, n_keypoints)
        results_yolo = self._build_results(bboxes, keypoints_yolo, offset, n_frames, n_keypoints)

        start = time.time()
        h, w = batch_images[0].shape[:2]
        results = run_keypoints_post_processing_v2(
            results, w, h,
            frames=batch_images,
            template_keypoints=self.template_keypoints,
            floor_markings_template=self.template_image,
            offset=offset,
        )
        results_yolo = run_keypoints_post_processing_v2(
            results_yolo, w, h,
            frames=batch_images,
            template_keypoints=self.template_keypoints,
            floor_markings_template=self.template_image,
            offset=offset,
        )
        end = time.time()
        print(f"Keypoint post processing time: {end - start}")

        final_keypoints: Dict[int, List[Tuple[int, int]]] = {}
        for frame_number_in_batch, (result, result_yolo) in enumerate(zip(results, results_yolo)):
            # Default to HRNet keypoints if scoring fails.
            frame_keypoints = result.keypoints
            try:
                frame_keypoints = self._choose_keypoints(
                    result.keypoints, result_yolo.keypoints,
                    batch_images[frame_number_in_batch], device_str)
            except Exception as e:
                print('Error: ', e)
            self.last_valid_keypoints = frame_keypoints
            final_keypoints[offset + frame_number_in_batch] = frame_keypoints

        final_results = self._build_results(bboxes, final_keypoints, offset, n_frames, n_keypoints)

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        return final_results

    @staticmethod
    def _raw_keypoints_with_conf(detection) -> List[Tuple[int, int, float]]:
        """Flatten an ultralytics keypoint result into (x, y, conf) tuples, instance by instance.

        Raises AttributeError, IndexError or TypeError on a malformed result,
        for example when ``keypoints.conf`` is None.
        """
        points: List[Tuple[int, int, float]] = []
        for i, part_points in enumerate(detection.keypoints.data):
            for k_id, (x, y, _) in enumerate(part_points):
                confidence = float(detection.keypoints.conf[i][k_id])
                points.append((int(x), int(y), confidence))
        return points

    def _filter_keypoints(self, points_with_conf: List[Tuple[int, int, float]], n_keypoints: int,
                          conf_threshold: float, corner_conf_threshold: float) -> List[Tuple[int, int]]:
        """Pad or truncate to ``n_keypoints`` and zero out low-confidence points.

        Indices in ``CORNER_INDICES`` use ``corner_conf_threshold``; all other
        indices use ``conf_threshold``.
        """
        padded = (list(points_with_conf) + [(0, 0, 0.0)] * n_keypoints)[:n_keypoints]
        filtered: List[Tuple[int, int]] = []
        for idx, (x, y, confidence) in enumerate(padded):
            threshold = corner_conf_threshold if idx in self.CORNER_INDICES else conf_threshold
            filtered.append((0, 0) if confidence < threshold else (int(x), int(y)))
        return filtered

    def _detect_keypoints_batch(self, batch_images: List[ndarray],
                                offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]:
        """
        Run the YOLO keypoint model on every frame of the batch.

        Args:
            batch_images: List of images to process.
            offset: Frame offset for numbering.
            n_keypoints: Number of keypoints expected.

        Returns:
            Dictionary mapping ``offset + index`` to keypoint coordinates.
            Corner points need conf >= 0.3 and others conf >= 0.5; otherwise
            they are set to (0, 0). Frames with no or malformed keypoint output
            are omitted, and callers fill them in.
        """
        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        if len(batch_images) == 0:
            return keypoints

        keypoints_model_results = self.keypoints_model_yolo.predict(batch_images, verbose=False)
        if keypoints_model_results is None:
            return keypoints

        for frame_idx_in_batch, detection in enumerate(keypoints_model_results):
            if getattr(detection, "keypoints", None) is None:
                continue
            try:
                points = self._raw_keypoints_with_conf(detection)
            except (AttributeError, IndexError, TypeError):
                continue
            keypoints[offset + frame_idx_in_batch] = self._filter_keypoints(
                points, n_keypoints, conf_threshold=0.5, corner_conf_threshold=0.3)
        return keypoints

    def predict_keypoints(
        self,
        images: List[ndarray],
        n_keypoints: int = 32,
        batch_size: Optional[int] = None,
        conf_threshold: float = 0.5,
        corner_conf_threshold: float = 0.3,
        verbose: bool = False
    ) -> Dict[int, List[Tuple[int, int]]]:
        """
        Standalone keypoint detection on a list of images.

        Args:
            images: List of images (numpy arrays) to process.
            n_keypoints: Number of keypoints expected per frame (default: 32).
            batch_size: Batch size for YOLO prediction. None means all images
                in one batch; values below 1 are treated as 1.
            conf_threshold: Confidence threshold for regular keypoints (default: 0.5).
            corner_conf_threshold: Confidence threshold for corner keypoints (default: 0.3).
            verbose: Whether to print progress information.

        Returns:
            Dictionary mapping frame index (from 0) to ``n_keypoints`` (x, y)
            tuples. Frames without usable output get all-(0, 0).
        """
        if not images:
            return {}

        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        step = len(images) if batch_size is None else max(1, batch_size)

        for batch_start in range(0, len(images), step):
            batch_end = min(batch_start + step, len(images))
            batch_images = images[batch_start:batch_end]

            if verbose:
                print(f"Processing keypoints batch {batch_start}-{batch_end-1} ({len(batch_images)} images)")

            keypoints_model_results = self.keypoints_model_yolo.predict(
                batch_images,
                verbose=False,
                save=False,
                conf=0.1,
            )

            if keypoints_model_results is None:
                for frame_idx in range(batch_start, batch_end):
                    keypoints[frame_idx] = [(0, 0)] * n_keypoints
                continue

            for batch_idx, detection in enumerate(keypoints_model_results):
                frame_idx = batch_start + batch_idx
                if getattr(detection, "keypoints", None) is None:
                    keypoints[frame_idx] = [(0, 0)] * n_keypoints
                    continue
                try:
                    points = self._raw_keypoints_with_conf(detection)
                except (AttributeError, IndexError, TypeError):
                    keypoints[frame_idx] = [(0, 0)] * n_keypoints
                    continue
                keypoints[frame_idx] = self._filter_keypoints(
                    points, n_keypoints, conf_threshold, corner_conf_threshold)

        return keypoints

    def predict_objects(
        self,
        images: List[ndarray],
        batch_size: Optional[int] = 16,
        conf_threshold: float = 0.5,
        iou_threshold: float = 0.45,
        classes: Optional[List[int]] = None,
        verbose: bool = False,
    ) -> Dict[int, List[BoundingBox]]:
        """
        Standalone high-throughput object detection.

        Runs the YOLO detector directly on raw images and skips the
        team-classification and keypoint stages.

        Args:
            images: List of frames (BGR numpy arrays).
            batch_size: Number of frames per inference pass. None processes
                all frames at once; values below 1 are treated as 1.
            conf_threshold: Detection confidence threshold.
            iou_threshold: IoU threshold for NMS within YOLO.
            classes: Optional list of class IDs to keep (None = all classes).
            verbose: Whether YOLO prints per-batch progress.

        Returns:
            Dict mapping frame index to a list of BoundingBox predictions.
        """
        if not images:
            return {}

        detections: Dict[int, List[BoundingBox]] = {}
        effective_batch = len(images) if batch_size is None else max(1, batch_size)

        for batch_start in range(0, len(images), effective_batch):
            batch_end = min(batch_start + effective_batch, len(images))
            batch_images = images[batch_start:batch_end]

            start = time.time()
            yolo_results = self.bbox_model(
                batch_images,
                conf=conf_threshold,
                iou=iou_threshold,
                classes=classes,
                verbose=verbose,
                save=False,
            )
            end = time.time()
            print(f"YOLO time: {end - start}")

            for local_idx, result in enumerate(yolo_results):
                frame_idx = batch_start + local_idx
                frame_boxes: List[BoundingBox] = []
                detections[frame_idx] = frame_boxes

                boxes = getattr(result, "boxes", None)
                boxes_tensor = None if boxes is None else boxes.data
                if boxes_tensor is None:
                    continue

                for box in boxes_tensor:
                    try:
                        x1, y1, x2, y2, conf, cls_id = box.tolist()
                        frame_boxes.append(
                            BoundingBox(
                                x1=int(x1),
                                y1=int(y1),
                                x2=int(x2),
                                y2=int(y2),
                                cls_id=int(cls_id),
                                conf=float(conf),
                            )
                        )
                    except (ValueError, TypeError):
                        continue

        return detections
|
| |