|
|
from pathlib import Path
|
|
|
from typing import List, Tuple, Dict
|
|
|
import sys
|
|
|
import os
|
|
|
|
|
|
from numpy import ndarray
|
|
|
from pydantic import BaseModel
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
|
|
from ultralytics import YOLO
|
|
|
from team_cluster import TeamClassifier
|
|
|
from utils import (
|
|
|
BoundingBox,
|
|
|
Constants,
|
|
|
)
|
|
|
from inference import predict_batch
|
|
|
import torch
|
|
|
from pitch import get_cls_net
|
|
|
import yaml
|
|
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
|
|
|
|
|
class BoundingBox(BaseModel):
    """Pydantic record describing one detection box.

    Declares four integer coordinates, an integer class id and a float
    confidence value.

    NOTE(review): this definition rebinds the ``BoundingBox`` name that is
    imported from ``utils`` earlier in the module — confirm the shadowing is
    intentional.
    """

    # Presumably two opposite corners of the box, (x1, y1) and (x2, y2),
    # in pixel units — TODO confirm against the producer of these values.
    x1: int
    y1: int
    x2: int
    y2: int

    # Presumably the detector's class index; the id-to-label mapping is not
    # defined in this class.
    cls_id: int

    # Confidence value for the box; no range validation is declared here.
    conf: float
|
|
|
|
|
|
|
|
|
class TVFrameResult(BaseModel):
    """Pydantic record bundling the per-frame outputs: boxes and keypoints."""

    # Integer identifier of the frame this record belongs to.
    frame_id: int

    # Boxes reported for the frame.
    boxes: List[BoundingBox]

    # Pairs of integers — presumably (x, y) keypoint positions in pixel
    # coordinates; TODO confirm against the code that fills this field.
    keypoints: List[Tuple[int, int]]
|
|
|
|
|
|
|
|
|
class Miner:
    """Load the detection, team-classification and keypoint models and run them.

    Construction is best-effort: any exception raised while loading is caught,
    stored as a message in ``self.health`` and printed.  In that case the model
    attributes are never assigned, so callers should check ``health`` (or
    ``repr(miner)``) before calling :meth:`predict_batch`.
    """

    # Tunables mirrored from ``utils.Constants`` so they are reachable as
    # class attributes of ``Miner``.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN

    # Not referenced inside this class; presumably bounds on the number of
    # player crops used to fit the team classifier — TODO confirm.
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 1000

    def __init__(self, path_hf_repo: Path) -> None:
        """Load all models from ``path_hf_repo``.

        Args:
            path_hf_repo: Directory containing
                ``football_object_detection.onnx``,
                ``osnet_model.pth.tar-100``, ``keypoint`` and
                ``hrnetv2_w48.yaml``.

        Does not raise: every ``Exception`` is caught and recorded in
        ``self.health`` instead.
        """
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Object detector.
            model_path = path_hf_repo / "football_object_detection.onnx"
            self.bbox_model = YOLO(model_path)
            print("BBox Model Loaded")

            # Team classifier.
            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path),
            )
            print("Team Classifier Loaded")

            # Bookkeeping for fitting the team classifier; only initialised
            # here, not read anywhere else in this class.
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            # Keypoint network: architecture from the YAML config, weights
            # from the ``keypoint`` checkpoint.
            model_kp_path = path_hf_repo / 'keypoint'
            config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            # Use a context manager so the config file handle is closed
            # deterministically instead of being left to the garbage collector.
            with open(config_kp_path, 'r') as config_file:
                cfg_kp = yaml.safe_load(config_file)

            # NOTE(review): torch.load may unpickle arbitrary objects unless
            # weights-only loading is in effect — only point this at trusted
            # repository snapshots.
            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()

            self.keypoints_model = model
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"
            self.path_hf_repo = path_hf_repo
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberate best-effort: keep the object alive and expose the
            # failure through ``health`` rather than propagating.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def __repr__(self) -> str:
        """Return a model summary when healthy, otherwise the failure message."""
        if self.health != 'healthy':
            return self.health
        return (
            f"health: {self.health}\n"
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}"
        )

    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Run the module-level ``inference.predict_batch`` with this miner's models.

        Args:
            batch_images: Images to process, forwarded unchanged.
            offset: Forwarded unchanged; presumably the frame index of the
                first image in the batch — TODO confirm.
            n_keypoints: Forwarded unchanged.

        Returns:
            Whatever ``inference.predict_batch`` returns (annotated as a list
            of ``TVFrameResult``).
        """
        # Inside a method body the bare name ``predict_batch`` resolves to the
        # module-level function imported from ``inference``, not this method.
        return predict_batch(
            self.bbox_model,
            self.team_classifier,
            self.keypoints_model,
            batch_images,
            offset,
            n_keypoints,
            self.pitch_batch_size,
            self.kp_threshold,
            self.path_hf_repo,
        )