# visiontest / miner1.py
# tarto2's picture
# Upload folder using huggingface_hub
# e4189f9 verified
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import sys
import os
from numpy import ndarray
from pydantic import BaseModel
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from keypoint_helper import run_keypoints_post_processing
from keypoint_helper_v2 import run_keypoints_post_processing as run_keypoints_post_processing_v2
from ultralytics import YOLO
from team_cluster import TeamClassifier
from utils import (
BoundingBox,
Constants,
)
import time
import torch
import gc
import cv2
import numpy as np
from collections import defaultdict
from pitch import process_batch_input, get_cls_net
from keypoint_evaluation import (
evaluate_keypoints_for_frame,
evaluate_keypoints_for_frame_gpu,
load_template_from_file,
evaluate_keypoints_for_frame_opencv_cuda,
evaluate_keypoints_batch_for_frame,
)
import yaml
class BoundingBox(BaseModel):
    """One detected object as an axis-aligned box in integer pixel coordinates.

    This local model shadows the ``BoundingBox`` name that the file also
    imports from ``utils``.
    """

    # Box corners. Elsewhere in this file (x1, y1)-(x2, y2) are used as the
    # column/row bounds when slicing a crop out of a frame.
    x1: int
    y1: int
    x2: int
    y2: int

    # Class id of the object. In this file 0 is handled as the football,
    # 2 as a player, 4 as staff, and 6/7 are the two team labels written by
    # the team classifier.
    cls_id: int

    # Confidence score reported for the detection.
    conf: float
class TVFrameResult(BaseModel):
    """Prediction for a single frame: its boxes and its keypoints."""

    # Frame number (batch offset plus position inside the batch).
    frame_id: int

    # Objects kept for this frame.
    boxes: List[BoundingBox]

    # (x, y) pixel positions, one entry per expected keypoint; (0, 0) is the
    # placeholder this file uses for a keypoint that was not found.
    keypoints: List[Tuple[int, int]]
# Bundles the box detector, the team classifier and the keypoint models, and
# exposes them through ``predict_batch``. ``health`` records whether loading
# the models in ``__init__`` succeeded.
class Miner:
# Tunables copied from ``Constants`` (imported from ``utils``) onto the class
# so they can be read as ``self.<NAME>``. Their values and exact meaning are
# defined in ``utils``, not in this file.
SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
# Keypoint positions that get the more lenient confidence threshold in
# ``_detect_keypoints_batch``.
CORNER_INDICES = Constants.CORNER_INDICES
KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
# Sample-count bounds for fitting the team classifier.
MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting
def __init__(self, path_hf_repo: Path) -> None:
    """Load every model and asset the miner needs from ``path_hf_repo``.

    Files read from the directory:
      * ``detection.onnx`` - box detector (``self.bbox_model``)
      * ``osnet_model.pth.tar-100`` - weights handed to ``TeamClassifier``
      * ``keypoint.pt`` - YOLO keypoint model (``self.keypoints_model_yolo``)
      * ``keypoint`` and ``hrnetv2_w48.yaml`` - state dict and config for the
        network built by ``get_cls_net`` (``self.keypoints_model``)
      * ``football_pitch_template.png`` - pitch template image

    Any ``Exception`` raised while loading is caught: ``self.health`` then
    holds the failure message instead of ``"healthy"``, and attributes that
    would have been assigned after the failing step stay unset.

    Args:
        path_hf_repo: Directory that contains the files listed above.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Object detector.
        model_path = path_hf_repo / "detection.onnx"
        self.bbox_model = YOLO(model_path)
        print(f"BBox Model Loaded: class name {self.bbox_model.names}")

        # Team classifier and its per-batch fitting state.
        team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
        self.team_classifier = TeamClassifier(
            device=device,
            batch_size=32,
            model_name=str(team_model_path)
        )
        print("Team Classifier Loaded")
        self.team_classifier_fitted = False
        self.player_crops_for_fit = []

        # YOLO keypoint model.
        self.keypoints_model_yolo = YOLO(path_hf_repo / "keypoint.pt")

        # Second keypoint network: architecture from the YAML config,
        # weights from the ``keypoint`` state dict.
        model_kp_path = path_hf_repo / 'keypoint'
        config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
        # Use a context manager so the config file handle is always closed.
        with open(config_kp_path, 'r') as config_file:
            cfg_kp = yaml.safe_load(config_file)
        loaded_state_kp = torch.load(model_kp_path, map_location=device)
        model = get_cls_net(cfg_kp)
        model.load_state_dict(loaded_state_kp)
        model.to(device)
        model.eval()
        self.keypoints_model = model
        print("Keypoints Model (keypoint.pt) Loaded")

        # Pitch template used by the keypoint post-processing.
        template_image_path = path_hf_repo / "football_pitch_template.png"
        self.template_image, self.template_keypoints = load_template_from_file(str(template_image_path))

        self.kp_threshold = 0.1
        self.pitch_batch_size = 4
        self.health = "healthy"
        print("✅ Keypoints Model Loaded")
    except Exception as e:
        # Deliberately broad: the caller inspects ``health`` instead of
        # handling a constructor exception.
        self.health = "❌ Miner initialization failed: " + str(e)
        print(self.health)
def __repr__(self) -> str:
    """Describe the miner.

    Returns:
        The raw ``health`` message when the miner is not healthy; otherwise a
        three-line summary with the health flag and the type names of the box
        and keypoint models.
    """
    # Unhealthy instances may lack the model attributes, so bail out first.
    if self.health != 'healthy':
        return self.health

    summary = (
        f"health: {self.health}\n"
        f"BBox Model: {type(self.bbox_model).__name__}\n"
        f"Keypoints Model: {type(self.keypoints_model).__name__}"
    )
    return summary
def _calculate_iou(self, box1: Tuple[float, float, float, float],
                   box2: Tuple[float, float, float, float]) -> float:
    """Intersection over Union of two axis-aligned boxes.

    Args:
        box1: ``(x1, y1, x2, y2)`` of the first box.
        box2: ``(x1, y1, x2, y2)`` of the second box.

    Returns:
        ``intersection / union``; ``0.0`` when the boxes do not overlap or
        when the union has zero area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Corners of the overlap rectangle.
    left, top = max(ax1, bx1), max(ay1, by1)
    right, bottom = min(ax2, bx2), min(ay2, by2)
    if right < left or bottom < top:
        return 0.0

    overlap = (right - left) * (bottom - top)
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - overlap
    if union == 0:
        # Both boxes are degenerate; avoid dividing by zero.
        return 0.0
    return overlap / union
def _extract_jersey_region(self, crop: ndarray) -> ndarray:
    """Cut the upper-body (jersey) area out of a player crop.

    A crop counts as a close-up when it is taller than 100 px or covers more
    than 12000 px. Close-ups are reduced to their top 60% of rows and the
    middle 90% of columns; every other crop (including ``None``, empty crops
    and crops smaller than 10 px on a side) is returned unchanged.
    """
    if crop is None or crop.size == 0:
        return crop

    height, width = crop.shape[:2]
    too_small = height < 10 or width < 10
    is_closeup = height > 100 or (height * width) > 12000
    if too_small or not is_closeup:
        return crop

    # Keep the top 60% (drops the shorts) and trim 5% off each side.
    row_end = int(height * 0.60)
    col_start = max(0, int(width * 0.05))
    col_end = min(width, int(width * 0.95))
    return crop[:row_end, col_start:col_end]
def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]:
    """Summarise the jersey colours of a player crop as an 11-value vector.

    Layout of the returned vector: HSV per-channel mean (3), HSV per-channel
    standard deviation (3), mean of the first two LAB channels (2), standard
    deviation of the first two LAB channels (2), and the dominant hue (1).
    Channel statistics are divided by 255 and the dominant hue by 180.

    Returns:
        The feature vector, or ``None`` when the crop is missing/empty or
        when the colour conversion fails (the error is printed).
    """
    if crop is None or crop.size == 0:
        return None

    jersey_region = self._extract_jersey_region(crop)
    if jersey_region.size == 0:
        return None

    def _channel_stats(converted):
        # Flatten to one row per pixel and return (pixels, mean, std).
        pixels = converted.reshape(-1, 3).astype(np.float32)
        return pixels, np.mean(pixels, axis=0) / 255.0, np.std(pixels, axis=0) / 255.0

    try:
        hsv_pixels, hsv_mean, hsv_std = _channel_stats(
            cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV)
        )
        _, lab_mean, lab_std = _channel_stats(
            cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB)
        )

        # Most frequent hue: 36 bins of width 5 over the 0-180 hue range.
        hue_counts, _ = np.histogram(hsv_pixels[:, 0], bins=36, range=(0, 180))
        dominant_hue = np.argmax(hue_counts) * 5

        return np.concatenate([
            hsv_mean,
            hsv_std,
            lab_mean[:2],  # only the first two LAB channels are kept
            lab_std[:2],
            [dominant_hue / 180.0],
        ])
    except Exception as e:
        print(f"Error extracting color signature: {e}")
        return None
def _get_spatial_position(self, bbox: Tuple[float, float, float, float],
                          frame_width: int, frame_height: int) -> Tuple[float, float]:
    """Return the box centre as a fraction of the frame size.

    Args:
        bbox: ``(x1, y1, x2, y2)`` in pixels.
        frame_width: Frame width in pixels.
        frame_height: Frame height in pixels.

    Returns:
        ``(x, y)`` with the origin at the top-left corner. An axis whose
        frame dimension is not positive falls back to ``0.5``.
    """
    left, top, right, bottom = bbox
    centre_x = (left + right) / 2.0
    centre_y = (top + bottom) / 2.0

    if frame_width > 0:
        x_norm = centre_x / frame_width
    else:
        x_norm = 0.5

    if frame_height > 0:
        y_norm = centre_y / frame_height
    else:
        y_norm = 0.5

    return (x_norm, y_norm)
def _find_best_match(self, target_box: Tuple[float, float, float, float],
                     predicted_frame_data: Dict[int, Tuple[Tuple, str]],
                     iou_threshold: float) -> Tuple[Optional[str], float]:
    """Pick the stored box that overlaps ``target_box`` the most.

    Args:
        target_box: ``(x1, y1, x2, y2)`` of the box to label.
        predicted_frame_data: Maps a box index to ``(bbox, team_cls_id)``.
        iou_threshold: Minimum IoU a candidate needs to be accepted.

    Returns:
        ``(team_cls_id, iou)`` of the best candidate, or ``(None, 0.0)`` when
        no candidate has a positive IoU that reaches the threshold.
    """
    winner_team = None
    winner_iou = 0.0
    # Only the stored (bbox, team) pairs matter; the box indices are unused.
    for candidate_box, candidate_team in predicted_frame_data.values():
        overlap = self._calculate_iou(target_box, candidate_box)
        if overlap >= iou_threshold and overlap > winner_iou:
            winner_iou, winner_team = overlap, candidate_team
    return (winner_team, winner_iou)
def _detect_objects_batch(self, decoded_images: List[ndarray], batch_size: int = 16) -> list:
    """Run the box detector over all frames in fixed-size chunks.

    Args:
        decoded_images: Frames to run detection on, in order.
        batch_size: Number of frames handed to the detector per call.

    Returns:
        A flat list holding the detector's per-frame results in the same
        order as ``decoded_images`` (a list, not a dict keyed by frame id).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would silently skip every frame.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    detection_results = []
    for first in range(0, len(decoded_images), batch_size):
        chunk = decoded_images[first:first + batch_size]
        detection_results.extend(self.bbox_model(chunk, verbose=False, save=False))
    return detection_results
def _team_classify(self, detection_results, decoded_images, offset):
    """Label players with a team and build the final box list for each frame.

    The work is done in three passes:

    1. Fit: player crops (class 2, confidence >= 0.5) are gathered from
       frames that have at least 4 detections and ``self.team_classifier``
       is fitted on them. Gathering stops at ``MAX_SAMPLES_FOR_FIT`` crops;
       with fewer crops the fit still happens if at least
       ``MIN_SAMPLES_FOR_FIT`` were found. The classifier is re-fitted on
       every call.
    2. Predict: on every ``prediction_interval``-th frame the classifier
       labels each player (class 2, confidence >= 0.6). Frames in between
       reuse labels from the neighbouring predicted frames via IoU matching
       and fall back to a single-crop prediction when nothing matches.
    3. Assemble: each kept detection becomes a ``BoundingBox``. Players get
       class 6 or 7 (team 0 or 1); players overlapping a staff box (class 4)
       with IoU >= 0.8 are dropped, as are staff boxes themselves and
       class-3 boxes below 0.5 confidence. When several footballs (class 0)
       survive, only the most confident one is kept.

    Args:
        detection_results: Per-frame detector outputs. Each item must expose
            ``.boxes.data`` whose rows unpack to
            ``(x1, y1, x2, y2, conf, cls_id)``.
        decoded_images: Frames aligned with ``detection_results``; they are
            sliced as ``image[y1:y2, x1:x2]`` to obtain crops.
        offset: Added to the in-batch frame index to form the result keys.

    Returns:
        Dict mapping ``offset + frame index`` to that frame's
        ``BoundingBox`` list. Empty when ``detection_results`` is empty.
    """
    # The classifier is fitted from scratch for every batch.
    self.team_classifier_fitted = False

    n_frames = len(detection_results)
    if n_frames == 0:
        # Nothing to do; this also keeps the percentage in the progress
        # message below from dividing by zero.
        return {}

    # Convert every detection row to plain Python numbers once instead of
    # calling ``tolist()`` again in each pass.
    frame_rows = [
        [box.tolist() for box in detection.boxes.data]
        for detection in detection_results
    ]

    # ---- Pass 1: fit the team classifier on player crops ----
    start = time.time()
    max_fit_samples = self.MAX_SAMPLES_FOR_FIT
    player_crops_for_fit = []
    for frame_id, rows in enumerate(frame_rows):
        if len(rows) < 4:
            continue
        if len(player_crops_for_fit) < max_fit_samples:
            frame_image = decoded_images[frame_id]
            for x1, y1, x2, y2, conf, cls_id in rows:
                # Only confident player boxes (class 2) are used for fitting.
                if conf < 0.5 or int(cls_id) != 2:
                    continue
                crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                if crop.size > 0:
                    player_crops_for_fit.append(crop)
        if self.team_classifier and len(player_crops_for_fit) >= max_fit_samples:
            print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
            self.team_classifier.fit(player_crops_for_fit)
            self.team_classifier_fitted = True
            break

    # Fewer crops than the target: fit anyway if there are enough to be useful.
    if (
        not self.team_classifier_fitted
        and len(player_crops_for_fit) >= self.MIN_SAMPLES_FOR_FIT
        and self.team_classifier
    ):
        print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
        self.team_classifier.fit(player_crops_for_fit)
        self.team_classifier_fitted = True
    end = time.time()
    print(f"Fitting Kmeans time: {end - start}")

    # ---- Pass 2: predict teams on the selected frames ----
    prediction_interval = 1  # 1 = run the classifier on every frame (no skipping)
    iou_threshold = 0.3
    print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")
    classifier_ready = self.team_classifier_fitted

    frames_to_predict = [
        frame_id for frame_id in range(n_frames)
        if frame_id % prediction_interval == 0
    ]
    print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
          f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")

    # {frame_id: {box_idx: (bbox, team_cls_id)}}
    predicted_frame_data = {}
    for frame_id in frames_to_predict:
        frame_image = decoded_images[frame_id]
        frame_player_crops = []
        frame_player_indices = []
        frame_player_boxes = []
        for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(frame_rows[frame_id]):
            if cls_id == 2 and conf < 0.6:
                continue
            if classifier_ready and int(cls_id) == 2:
                crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                if crop.size > 0:
                    frame_player_crops.append(crop)
                    frame_player_indices.append(idx)
                    frame_player_boxes.append((x1, y1, x2, y2))
        if frame_player_crops:
            team_ids = self.team_classifier.predict(frame_player_crops)
            # Team 0/1 is stored as class id 6/7.
            predicted_frame_data[frame_id] = {
                idx: (bbox, str(6 + int(team_id)))
                for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids)
            }

    # ---- Pass 3: build the output boxes for every frame ----
    bboxes: dict[int, list[BoundingBox]] = {}
    for frame_id, rows in enumerate(frame_rows):
        detection_box = detection_results[frame_id].boxes.data
        frame_image = decoded_images[frame_id]
        team_predictions = {}

        if frame_id % prediction_interval == 0:
            # Predicted frame: reuse the labels computed in pass 2.
            for idx, (_, team_cls_id) in predicted_frame_data.get(frame_id, {}).items():
                team_predictions[idx] = team_cls_id
        else:
            # Skipped frame (only reachable when prediction_interval > 1):
            # borrow labels from the surrounding predicted frames.
            prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
            next_predicted_frame = prev_predicted_frame + prediction_interval
            for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(rows):
                if cls_id == 2 and conf < 0.6:
                    continue
                if not (classifier_ready and int(cls_id) == 2):
                    continue
                target_box = (x1, y1, x2, y2)
                best_team_id = None
                # Previous predicted frame first, then the next one.
                for neighbour in (prev_predicted_frame, next_predicted_frame):
                    if best_team_id is None and neighbour in predicted_frame_data:
                        best_team_id, _ = self._find_best_match(
                            target_box,
                            predicted_frame_data[neighbour],
                            iou_threshold
                        )
                if best_team_id is None:
                    # No neighbour matched: classify this crop on its own.
                    crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        team_id = self.team_classifier.predict([crop])[0]
                        best_team_id = str(6 + int(team_id))
                if best_team_id is not None:
                    team_predictions[idx] = best_team_id

        # Staff boxes (class 4) are looked up once per frame.
        staff_indices = [i for i, row in enumerate(rows) if row[5] == 4]

        boxes = []
        for idx, (x1, y1, x2, y2, conf, cls_id) in enumerate(rows):
            if cls_id == 2 and conf < 0.6:
                continue
            # A player box that almost coincides with a staff box is dropped.
            if cls_id == 2 and any(
                self._calculate_iou(detection_box[idx][:4], detection_box[s][:4]) >= 0.8
                for s in staff_indices
            ):
                continue
            # Players with a team prediction take the team class id.
            mapped_cls_id = team_predictions.get(idx, str(int(cls_id)))
            if mapped_cls_id == '4':
                continue
            if int(mapped_cls_id) == 3 and conf < 0.5:
                continue
            boxes.append(
                BoundingBox(
                    x1=int(x1),
                    y1=int(y1),
                    x2=int(x2),
                    y2=int(y2),
                    cls_id=int(mapped_cls_id),
                    conf=float(conf),
                )
            )

        # Keep only the most confident football.
        footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
        if len(footballs) > 1:
            best_ball = max(footballs, key=lambda b: b.conf)
            boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
            boxes.append(best_ball)

        bboxes[offset + frame_id] = boxes
    return bboxes
def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> "List[TVFrameResult]":
    """Run the full pipeline on one batch of frames.

    Steps: box detection, team classification, YOLO keypoint detection,
    assembly of one ``TVFrameResult`` per frame, keypoint post-processing
    (``run_keypoints_post_processing_v2``) and finally a memory clean-up.

    Args:
        batch_images: Frames of the batch, in order.
        offset: Frame number of the first image; frame ids are
            ``offset .. offset + len(batch_images) - 1``.
        n_keypoints: Number of keypoints each result must carry.

    Returns:
        One ``TVFrameResult`` per input frame. An empty batch yields an
        empty list.
    """
    print('=' * 10)
    print(f"Offset: {offset}, Batch size: {len(batch_images)}")
    print('=' * 10)

    if len(batch_images) == 0:
        # Nothing to process; the later stages assume at least one frame.
        return []

    # Phase 1: object detection.
    start = time.time()
    detection_results = self._detect_objects_batch(batch_images)
    end = time.time()
    print(f"Detection time: {end - start}")

    # Phase 2: team classification and box assembly.
    start = time.time()
    bboxes = self._team_classify(detection_results, batch_images, offset)
    end = time.time()
    print(f"Team classify time: {end - start}")

    # Phase 3: keypoint detection with the YOLO keypoint model.
    # NOTE: an alternative path that ran ``self.keypoints_model`` through
    # ``process_batch_input`` and compared both keypoint sets used to live
    # here as commented-out code; it was disabled and has been removed.
    keypoints_yolo = self._detect_keypoints_batch(batch_images, offset, n_keypoints)

    results = []
    for frame_number in range(offset, offset + len(batch_images)):
        results.append(
            TVFrameResult(
                frame_id=frame_number,
                boxes=bboxes.get(frame_number, []),
                # Frames without keypoints get all-(0, 0) placeholders.
                keypoints=keypoints_yolo.get(frame_number, [(0, 0)] * n_keypoints),
            )
        )

    # Phase 4: keypoint post-processing; frame size is taken from the first image.
    start = time.time()
    h, w = batch_images[0].shape[:2]
    results = run_keypoints_post_processing_v2(
        results, w, h,
        frames=batch_images,
        template_keypoints=self.template_keypoints,
        floor_markings_template=self.template_image,
        offset=offset
    )
    end = time.time()
    print(f"Keypoint post processing time: {end - start}")

    # Release cached memory between batches.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    return results
def _detect_keypoints_batch(self, batch_images: List[ndarray],
                            offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]:
    """Detect pitch keypoints for every frame of the batch.

    For each frame the keypoints of all detected instances are flattened in
    order, truncated or zero-padded to ``n_keypoints`` entries, and every
    entry whose confidence is below its threshold is replaced by ``(0, 0)``.
    Positions listed in ``self.CORNER_INDICES`` use a threshold of 0.3, all
    others 0.5.

    NOTE(review): the class also carries ``CORNER_CONFIDENCE`` and
    ``KEYPOINTS_CONFIDENCE``; this method uses the literals above - confirm
    whether they are meant to agree.

    Args:
        batch_images: Frames to process.
        offset: Frame number of the first image in the batch.
        n_keypoints: Number of keypoints to return per frame.

    Returns:
        Dict mapping ``offset + index in batch`` to a list of ``(x, y)``
        integer points. Frames whose result carries no keypoints are left
        out; an empty batch yields an empty dict without calling the model.
    """
    keypoints: Dict[int, List[Tuple[int, int]]] = {}
    if len(batch_images) == 0:
        return keypoints

    keypoints_model_results = self.keypoints_model_yolo.predict(batch_images)
    if keypoints_model_results is None:
        return keypoints

    for frame_idx_in_batch, detection in enumerate(keypoints_model_results):
        detected = getattr(detection, "keypoints", None)
        if detected is None:
            continue

        # Flatten every instance's points into (x, y, confidence) triples.
        points_with_conf: List[Tuple[int, int, float]] = []
        for i, part_points in enumerate(detected.data):
            for k_id, (x, y, _) in enumerate(part_points):
                confidence = float(detected.conf[i][k_id])
                points_with_conf.append((int(x), int(y), confidence))

        # Truncate, then pad with zero-confidence placeholders if short.
        points_with_conf = points_with_conf[:n_keypoints]
        points_with_conf.extend([(0, 0, 0.0)] * (n_keypoints - len(points_with_conf)))

        filtered_keypoints: List[Tuple[int, int]] = []
        for idx, (x, y, confidence) in enumerate(points_with_conf):
            # Corner keypoints are accepted at a lower confidence.
            threshold = 0.3 if idx in self.CORNER_INDICES else 0.5
            filtered_keypoints.append((0, 0) if confidence < threshold else (x, y))

        keypoints[offset + frame_idx_in_batch] = filtered_keypoints
    return keypoints