# SecretVision / miner.py
# Hugging Face repository file (uploader: tarto2, commit acf7a04, message: "update")
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import sys, os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import onnxruntime as ort
import numpy as np
import cv2
from torchvision.ops import batched_nms
import torch
from ultralytics import YOLO
from numpy import ndarray
from pydantic import BaseModel
from team_cluster import TeamClassifier
from utils import (
BoundingBox,
Constants,
suppress_small_contained_boxes,
classify_teams_batch,
)
class TVFrameResult(BaseModel):
    """Analysis result for a single video frame: detections plus pitch keypoints."""
    frame_id: int  # absolute frame index (batch offset + position within batch)
    boxes: List[BoundingBox]  # detected objects; cls_id 0=ball, 1=goalkeeper, 2=player (ids < 4 kept; 3 presumably referee — TODO confirm)
    keypoints: List[Tuple[int, int]]  # pitch keypoints in pixel coords; (0, 0) marks a missing/low-confidence point
class Miner:
    """
    Football video analysis system for object detection and team classification.

    Per batch of frames the pipeline is:
      1. Object detection with an ONNX model (``detection.onnx``).
      2. Team classification of detected players (``TeamClassifier`` / OSNet).
      3. Pitch keypoint detection with a YOLO pose model (``keypoint.pt``).

    Class ids used throughout: 0 = ball, 1 = goalkeeper, 2 = player
    (ids < 4 are kept; id 3 is presumably the referee — TODO confirm).
    """
    # Use constants from utils
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 500  # Maximum samples to avoid overfitting

    def __init__(self, path_hf_repo: Path) -> None:
        """
        Load all models from a local Hugging Face repo directory.

        Args:
            path_hf_repo: Directory containing ``detection.onnx``,
                ``keypoint.pt`` and the OSNet checkpoint
                ``osnet_model.pth.tar-100``.
        """
        # Prefer CUDA; onnxruntime silently falls back to CPU if unavailable.
        providers = [
            'CUDAExecutionProvider',
            'CPUExecutionProvider'
        ]
        model_path = path_hf_repo / "detection.onnx"
        # str() for compatibility with onnxruntime versions that do not accept
        # os.PathLike session paths.
        session = ort.InferenceSession(str(model_path), providers=providers)
        input_name = session.get_inputs()[0].name
        # Warm-up run so the first real inference does not pay lazy
        # initialization costs (kernel selection / memory allocation).
        height = width = 640
        dummy = np.zeros((1, 3, height, width), dtype=np.float32)
        session.run(None, {input_name: dummy})
        self.bbox_model = session
        print("BBox Model Loaded")
        self.keypoints_model = YOLO(path_hf_repo / "keypoint.pt")
        print("Keypoints Model (keypoint.pt) Loaded")
        # Initialize team classifier with OSNet model
        team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
        device = 'cuda'
        self.team_classifier = TeamClassifier(
            device=device,
            batch_size=32,
            model_name=str(team_model_path)
        )
        print("Team Classifier Loaded")
        # Team classification state
        self.team_classifier_fitted = False
        self.player_crops_for_fit = []  # Collect samples across frames

    def __repr__(self) -> str:
        return (
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}"
        )

    def _handle_multiple_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Handle goalkeeper detection issues:
        1. Fix misplaced goalkeepers (standing in middle of field)
        2. Limit to maximum 2 goalkeepers (one from each team)
        Returns:
            Filtered list of boxes with corrected goalkeepers
        """
        # Step 1: Fix misplaced goalkeepers first
        # Convert goalkeepers in middle of field to regular players
        boxes = self._fix_misplaced_goalkeepers(boxes)
        # Step 2: Handle multiple goalkeepers (after fixing misplaced ones)
        gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
        if len(gk_idxs) <= 2:
            return boxes
        # Sort goalkeepers by confidence (highest first)
        gk_idxs_sorted = sorted(gk_idxs, key=lambda i: boxes[i].conf, reverse=True)
        keep_gk_idxs = set(gk_idxs_sorted[:2])  # Keep top 2 goalkeepers
        # Create new list keeping only top 2 goalkeepers
        filtered_boxes = []
        for i, box in enumerate(boxes):
            if int(box.cls_id) == 1:
                # Only keep the top 2 goalkeepers by confidence
                if i in keep_gk_idxs:
                    filtered_boxes.append(box)
                # Skip extra goalkeepers
            else:
                # Keep all non-goalkeeper boxes
                filtered_boxes.append(box)
        return filtered_boxes

    def _fix_misplaced_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Demote low-confidence goalkeeper detections (conf < 0.3) to regular
        players (cls_id 2).

        NOTE(review): despite the name, no positional ("middle of field")
        check is performed here — player indices are gathered only to gate
        the demotion on having at least two players in the frame.
        """
        gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
        player_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 2]
        if len(gk_idxs) == 0 or len(player_idxs) < 2:
            return boxes
        # Shallow copy: the BoundingBox objects are shared with the caller's
        # list, so the cls_id mutation below is visible there too.
        updated_boxes = boxes.copy()
        for gk_idx in gk_idxs:
            if boxes[gk_idx].conf < 0.3:
                updated_boxes[gk_idx].cls_id = 2
        return updated_boxes

    def _pre_process_img(self, frames: List[np.ndarray], scale: float = 640.0) -> np.ndarray:
        """
        Preprocess images for ONNX inference.
        Args:
            frames: List of BGR frames
            scale: Target scale for resizing (network input is scale x scale)
        Returns:
            Float32 array of shape (B, 3, scale, scale), values in [0, 1]
        """
        imgs = np.stack([cv2.resize(frame, (int(scale), int(scale))) for frame in frames])
        imgs = imgs.transpose(0, 3, 1, 2)  # BHWC to BCHW
        imgs = imgs.astype(np.float32) / 255.0  # Normalize to [0, 1]
        return imgs

    def _post_process_output(self, outputs: np.ndarray, x_scale: float, y_scale: float,
                             conf_thresh: float = 0.6, nms_thresh: float = 0.55) -> List[List[Tuple]]:
        """
        Post-process ONNX model outputs to get detections.
        Args:
            outputs: Raw ONNX model outputs, shape (B, C, N) with C = 4 box
                coords (cx, cy, w, h in model space) followed by class logits
            x_scale: X-axis scaling factor (original width / model width)
            y_scale: Y-axis scaling factor (original height / model height)
            conf_thresh: Confidence threshold
            nms_thresh: NMS IoU threshold
        Returns:
            List of detections for each frame: [(box, conf, class_id), ...]
        """
        B, C, N = outputs.shape
        outputs = torch.from_numpy(outputs)
        outputs = outputs.permute(0, 2, 1)  # B,C,N -> B,N,C
        boxes = outputs[..., :4]
        class_scores = 1 / (1 + torch.exp(-outputs[..., 4:]))  # Sigmoid activation
        conf, class_id = class_scores.max(dim=2)
        mask = conf > conf_thresh
        # Special handling for balls - keep best one even with lower confidence
        for i in range(class_id.shape[0]):  # loop over batch
            # Find detections that are balls
            ball_mask = class_id[i] == 0
            ball_idx = ball_mask.nonzero(as_tuple=True)[0]
            if ball_idx.numel() > 0:
                # Pick the one with the highest confidence
                best_ball_idx = ball_idx[conf[i, ball_idx].argmax()]
                if conf[i, best_ball_idx] >= 0.55:  # apply confidence threshold
                    mask[i, best_ball_idx] = True
        batch_idx, pred_idx = mask.nonzero(as_tuple=True)
        if len(batch_idx) == 0:
            return [[] for _ in range(B)]
        boxes = boxes[batch_idx, pred_idx]
        conf = conf[batch_idx, pred_idx]
        class_id = class_id[batch_idx, pred_idx]
        # Convert from center format to xyxy format (scaled back to original image)
        x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        x1 = (x - w / 2) * x_scale
        y1 = (y - h / 2) * y_scale
        x2 = (x + w / 2) * x_scale
        y2 = (y + h / 2) * y_scale
        boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1)
        # Apply batched NMS. The coordinate offset separates frames spatially;
        # batched_nms also receives batch_idx as grouping ids, so the offset is
        # belt-and-braces. Note: NMS is class-agnostic within a frame (boxes of
        # different classes can suppress each other).
        max_coord = 1e4
        offset = batch_idx.to(boxes_xyxy) * max_coord
        boxes_for_nms = boxes_xyxy + offset[:, None]
        keep = batched_nms(boxes_for_nms, conf, batch_idx, nms_thresh)
        boxes_final = boxes_xyxy[keep]
        conf_final = conf[keep]
        class_final = class_id[keep]
        batch_final = batch_idx[keep]
        # Group results by batch
        results = [[] for _ in range(B)]
        for b in range(B):
            mask_b = batch_final == b
            if mask_b.sum() == 0:
                continue
            results[b] = list(zip(boxes_final[mask_b].numpy(),
                                  conf_final[mask_b].numpy(),
                                  class_final[mask_b].numpy()))
        return results

    @staticmethod
    def _area(bb: BoundingBox) -> float:
        """Pixel area of a box; 0.0 for degenerate (inverted/empty) boxes."""
        # Previously missing: _ioa and suppress_small_contained called
        # self._area / self._intersect_area, which were defined nowhere.
        return max(0.0, float(bb.x2 - bb.x1)) * max(0.0, float(bb.y2 - bb.y1))

    @staticmethod
    def _intersect_area(a: BoundingBox, b: BoundingBox) -> float:
        """Area of the intersection rectangle of two boxes; 0.0 if disjoint."""
        iw = min(a.x2, b.x2) - max(a.x1, b.x1)
        ih = min(a.y2, b.y2) - max(a.y1, b.y1)
        if iw <= 0 or ih <= 0:
            return 0.0
        return float(iw) * float(ih)

    def _ioa(self, a: BoundingBox, b: BoundingBox) -> float:
        """Intersection-over-area: fraction of box ``a`` covered by box ``b``."""
        inter = self._intersect_area(a, b)
        aa = self._area(a)
        if aa <= 0:
            return 0.0
        return inter / aa

    def suppress_small_contained(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """
        Drop a box when it is much smaller than another box (area ratio
        <= SMALL_RATIO_MAX) and mostly contained in it (IoA >=
        SMALL_CONTAINED_IOA).

        NOTE(review): _detect_objects_batch uses the equivalent utils helper
        suppress_small_contained_boxes; this method mirrors it on the class.
        """
        if len(boxes) <= 1:
            return boxes
        keep = [True] * len(boxes)
        areas = [self._area(bb) for bb in boxes]
        for i in range(len(boxes)):
            if not keep[i]:
                continue
            for j in range(len(boxes)):
                if i == j or not keep[j]:
                    continue
                ai, aj = areas[i], areas[j]
                if ai == 0 or aj == 0:
                    continue
                if ai <= aj:
                    ratio = ai / aj
                    if ratio <= self.SMALL_RATIO_MAX:
                        ioa_i_in_j = self._ioa(boxes[i], boxes[j])
                        if ioa_i_in_j >= self.SMALL_CONTAINED_IOA:
                            keep[i] = False
                            break
                else:
                    ratio = aj / ai
                    if ratio <= self.SMALL_RATIO_MAX:
                        ioa_j_in_i = self._ioa(boxes[j], boxes[i])
                        if ioa_j_in_i >= self.SMALL_CONTAINED_IOA:
                            keep[j] = False
        return [bb for bb, k in zip(boxes, keep) if k]

    def _detect_objects_batch(self, batch_images: List[ndarray], offset: int) -> Dict[int, List[BoundingBox]]:
        """
        Phase 1: Object detection for all frames in batch.
        Returns detected objects with players still having class_id=2 (before team classification).
        Args:
            batch_images: List of images to process
            offset: Frame offset for numbering
        Returns:
            Dictionary mapping frame_id to list of detected boxes
        """
        bboxes: Dict[int, List[BoundingBox]] = {}
        if len(batch_images) == 0:
            return bboxes
        print(f"Processing batch of {len(batch_images)} images")
        # Get original image dimensions for scaling (assumes all frames in the
        # batch share the first frame's size — TODO confirm against callers)
        height, width = batch_images[0].shape[:2]
        scale = 640.0
        x_scale = width / scale
        y_scale = height / scale
        # Memory optimization: Process smaller batches if needed
        max_batch_size = 32  # Reduce batch size further to prevent memory issues
        if len(batch_images) > max_batch_size:
            print(f"Large batch detected ({len(batch_images)} images), splitting into smaller batches of {max_batch_size}")
            # Process in smaller chunks (recursive call handles each chunk)
            all_bboxes = {}
            for chunk_start in range(0, len(batch_images), max_batch_size):
                chunk_end = min(chunk_start + max_batch_size, len(batch_images))
                chunk_images = batch_images[chunk_start:chunk_end]
                chunk_offset = offset + chunk_start
                print(f"Processing chunk {chunk_start//max_batch_size + 1}: images {chunk_start}-{chunk_end-1}")
                chunk_bboxes = self._detect_objects_batch(chunk_images, chunk_offset)
                all_bboxes.update(chunk_bboxes)
            return all_bboxes
        # Preprocess images for ONNX inference
        imgs = self._pre_process_img(batch_images, scale)
        actual_batch_size = len(batch_images)
        # Handle batch size mismatch - pad if needed (models exported with a
        # fixed batch dimension reject smaller inputs)
        model_batch_size = self.bbox_model.get_inputs()[0].shape[0]
        print(f"Model input shape: {self.bbox_model.get_inputs()[0].shape}, batch_size: {model_batch_size}")
        if model_batch_size is not None:
            try:
                # Handle dynamic batch size (None, -1, 'None')
                if str(model_batch_size) in ['None', '-1'] or model_batch_size == -1:
                    model_batch_size = None
                else:
                    model_batch_size = int(model_batch_size)
            except (ValueError, TypeError):
                # Symbolic dimension names (e.g. 'batch') also mean dynamic
                model_batch_size = None
        print(f"Processed model_batch_size: {model_batch_size}, actual_batch_size: {actual_batch_size}")
        if model_batch_size and actual_batch_size < model_batch_size:
            padding_size = model_batch_size - actual_batch_size
            dummy_img = np.zeros((1, 3, int(scale), int(scale)), dtype=np.float32)
            padding = np.repeat(dummy_img, padding_size, axis=0)
            imgs = np.vstack([imgs, padding])
        # ONNX inference with error handling
        try:
            input_name = self.bbox_model.get_inputs()[0].name
            import time
            start_time = time.time()
            outputs = self.bbox_model.run(None, {input_name: imgs})[0]
            inference_time = time.time() - start_time
            print(f"Inference time: {inference_time:.3f}s for {actual_batch_size} images")
            # Remove padded results if we added padding
            if model_batch_size and isinstance(model_batch_size, int) and actual_batch_size < model_batch_size:
                outputs = outputs[:actual_batch_size]
            # Post-process outputs to get detections
            raw_results = self._post_process_output(np.array(outputs), x_scale, y_scale)
        except Exception as e:
            # Best-effort: log and return what we have (nothing) rather than
            # aborting the whole pipeline on a single bad batch.
            print(f"Error during ONNX inference: {e}")
            return bboxes
        if not raw_results:
            return bboxes
        # Convert raw results to BoundingBox objects and apply processing
        for frame_idx_in_batch, frame_detections in enumerate(raw_results):
            if not frame_detections:
                continue
            # Convert to BoundingBox objects
            boxes: List[BoundingBox] = []
            for box, conf, cls_id in frame_detections:
                x1, y1, x2, y2 = box
                if int(cls_id) < 4:
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
                            cls_id=int(cls_id),
                            conf=float(conf),
                        )
                    )
            # Handle footballs - keep only the best one
            footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
            if len(footballs) > 1:
                best_ball = max(footballs, key=lambda b: b.conf)
                boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                boxes.append(best_ball)
            # Remove overlapping small boxes
            boxes = suppress_small_contained_boxes(boxes, self.SMALL_CONTAINED_IOA, self.SMALL_RATIO_MAX)
            # Handle goalkeeper detection issues:
            # 1. Fix misplaced goalkeepers (convert to players if standing in middle)
            # 2. Allow up to 2 goalkeepers maximum (one from each team)
            # Goalkeepers remain class_id = 1 (no team assignment)
            boxes = self._handle_multiple_goalkeepers(boxes)
            # Store results (players still have class_id=2, will be classified in phase 2)
            frame_id = offset + frame_idx_in_batch
            bboxes[frame_id] = boxes
        return bboxes

    def predict_batch(
        self,
        batch_images: List[ndarray],
        offset: int,
        n_keypoints: int,
        task_type: Optional[str] = None,
    ) -> List[TVFrameResult]:
        """
        Run the full pipeline on a batch of frames.

        Args:
            batch_images: BGR frames to analyze
            offset: Absolute frame index of the first image in the batch
            n_keypoints: Number of keypoints each result must contain
            task_type: None = both phases, "object" = detection only,
                "keypoint" = keypoints only
        Returns:
            One TVFrameResult per input frame (missing data padded with
            empty boxes / (0, 0) keypoints)
        """
        process_objects = task_type is None or task_type == "object"
        process_keypoints = task_type is None or task_type == "keypoint"
        # Phase 1: Object Detection for all frames
        bboxes: Dict[int, List[BoundingBox]] = {}
        if process_objects:
            bboxes = self._detect_objects_batch(batch_images, offset)
        import time
        time_start = time.time()
        # Phase 2: Team Classification for all detected players
        if process_objects and bboxes:
            bboxes, self.team_classifier_fitted, self.player_crops_for_fit = classify_teams_batch(
                self.team_classifier,
                self.team_classifier_fitted,
                self.player_crops_for_fit,
                batch_images,
                bboxes,
                offset,
                self.MIN_SAMPLES_FOR_FIT,
                self.MAX_SAMPLES_FOR_FIT,
                self.SINGLE_PLAYER_HUE_PIVOT
            )
            # NOTE(review): this immediately discards the fitted flag and crop
            # samples returned above, forcing a re-fit on every batch — confirm
            # this per-batch reset is intentional.
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []
        print(f"Time Team Classification: {time.time() - time_start} s")
        # Phase 3: Keypoint Detection
        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        if process_keypoints:
            keypoints = self._detect_keypoints_batch(batch_images, offset, n_keypoints)
        # Phase 4: Combine results
        results: List[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            results.append(
                TVFrameResult(
                    frame_id=frame_number,
                    boxes=bboxes.get(frame_number, []),
                    keypoints=keypoints.get(
                        frame_number,
                        [(0, 0) for _ in range(n_keypoints)],
                    ),
                )
            )
        return results

    def _detect_keypoints_batch(self, batch_images: List[ndarray],
                                offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]:
        """
        Phase 3: Keypoint detection for all frames in batch.
        Args:
            batch_images: List of images to process
            offset: Frame offset for numbering
            n_keypoints: Number of keypoints expected
        Returns:
            Dictionary mapping frame_id to list of keypoint coordinates
        """
        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        keypoints_model_results = self.keypoints_model.predict(batch_images)
        if keypoints_model_results is None:
            return keypoints
        for frame_idx_in_batch, detection in enumerate(keypoints_model_results):
            if not hasattr(detection, "keypoints") or detection.keypoints is None:
                continue
            # Extract keypoints with confidence (flattened across all detected
            # instances in the frame, in model output order)
            frame_keypoints_with_conf: List[Tuple[int, int, float]] = []
            for i, part_points in enumerate(detection.keypoints.data):
                for k_id, (x, y, _) in enumerate(part_points):
                    confidence = float(detection.keypoints.conf[i][k_id])
                    frame_keypoints_with_conf.append((int(x), int(y), confidence))
            # Pad or truncate to expected number of keypoints
            if len(frame_keypoints_with_conf) < n_keypoints:
                frame_keypoints_with_conf.extend(
                    [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf))
                )
            else:
                frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints]
            # Filter keypoints based on confidence thresholds; low-confidence
            # points are zeroed out rather than dropped so indices stay stable
            filtered_keypoints: List[Tuple[int, int]] = []
            for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf):
                if idx in self.CORNER_INDICES:
                    # Corner keypoints have lower confidence threshold
                    # (hard-coded 0.3 — presumably duplicates CORNER_CONFIDENCE;
                    # TODO confirm and unify)
                    if confidence < 0.3:
                        filtered_keypoints.append((0, 0))
                    else:
                        filtered_keypoints.append((int(x), int(y)))
                else:
                    # Regular keypoints (hard-coded 0.5 — presumably duplicates
                    # KEYPOINTS_CONFIDENCE; TODO confirm and unify)
                    if confidence < 0.5:
                        filtered_keypoints.append((0, 0))
                    else:
                        filtered_keypoints.append((int(x), int(y)))
            frame_id = offset + frame_idx_in_batch
            keypoints[frame_id] = filtered_keypoints
        return keypoints