File size: 3,708 Bytes
b8add4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from pathlib import Path
from typing import List, Tuple, Dict
import sys
import os
from numpy import ndarray
from pydantic import BaseModel
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from ultralytics import YOLO
from team_cluster import TeamClassifier
from utils import (
BoundingBox,
Constants,
)
from inference import predict_batch
import torch
from pitch import get_cls_net
import yaml
# Configure PyTorch's CUDA caching allocator to use expandable segments
# (per the PyTorch CUDA memory-management docs, aimed at reducing fragmentation).
# NOTE(review): this runs after ``torch`` has already been imported above; PyTorch
# reads this variable when the CUDA allocator is first initialised, so it presumably
# still takes effect as long as nothing allocated CUDA memory during the imports — confirm.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})
class BoundingBox(BaseModel):
    """One detected object: box coordinates, class id and confidence score.

    NOTE(review): this definition shadows the ``BoundingBox`` imported from
    ``utils`` at the top of the module, so every later reference in this module
    (e.g. ``TVFrameResult.boxes``) resolves to this class. Confirm the two are
    meant to share one schema.
    """

    x1: int  # presumably (x1, y1) is the top-left corner in pixels — TODO confirm
    y1: int
    x2: int  # presumably (x2, y2) is the bottom-right corner in pixels — TODO confirm
    y2: int
    cls_id: int  # class index from the detector; label mapping not visible here
    conf: float  # detector confidence score
class TVFrameResult(BaseModel):
    """Inference output for a single frame: detected boxes and keypoint coordinates.

    ``Miner.predict_batch`` is annotated to return a list of these. The objects
    themselves are produced by ``inference.predict_batch``; confirm that it builds
    this model.
    """

    frame_id: int  # presumably the absolute frame index (batch offset + position) — TODO confirm
    boxes: List[BoundingBox]  # the BoundingBox class defined in this module
    keypoints: List[Tuple[int, int]]  # integer coordinate pairs, presumably (x, y) pixels — confirm
class Miner:
    """Load the detection, team-classification and keypoint models and run batch inference.

    Construction never raises. Any exception raised while loading models is
    caught, stored in ``self.health`` as a failure message, and printed.
    ``self.health == "healthy"`` means every model loaded successfully.
    """

    # Thresholds re-exported from ``utils.Constants``. They are not read by this
    # class's own methods; presumably other code consumes them — confirm.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 1000  # Maximum samples to avoid overfitting

    def __init__(self, path_hf_repo: Path) -> None:
        """Load all models from *path_hf_repo*, on CUDA when available, otherwise on CPU.

        Args:
            path_hf_repo: Directory holding ``football_object_detection.onnx``,
                ``osnet_model.pth.tar-100``, ``keypoint`` and ``hrnetv2_w48.yaml``.
                A ``str`` path is also accepted and converted to ``Path``.

        Loading failures are not raised; they are recorded in ``self.health``.
        """
        # Accept plain strings too: the ``/`` joins below require a Path.
        path_hf_repo = Path(path_hf_repo)

        # Plain settings are assigned before any model loads, so they exist
        # even when loading fails part-way.
        self.path_hf_repo = path_hf_repo
        self.kp_threshold = 0.1
        self.pitch_batch_size = 4
        # Team classification state (not read by this class's own methods).
        self.team_classifier_fitted = False
        self.player_crops_for_fit = []

        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            self.bbox_model = YOLO(path_hf_repo / "football_object_detection.onnx")
            print("BBox Model Loaded")

            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(path_hf_repo / "osnet_model.pth.tar-100"),
            )
            print("Team Classifier Loaded")

            self.keypoints_model = self._load_keypoints_model(path_hf_repo, device)
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberately best-effort: record the failure instead of raising, so
            # the object can still report its state through ``health`` / ``repr``.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    @staticmethod
    def _load_keypoints_model(path_hf_repo: Path, device: str):
        """Build the keypoint network from ``hrnetv2_w48.yaml``, load the ``keypoint`` weights, and return it in eval mode on *device*."""
        config_kp_path = path_hf_repo / "hrnetv2_w48.yaml"
        # ``with`` guarantees the config file handle is closed.
        with open(config_kp_path, "r", encoding="utf-8") as config_file:
            cfg_kp = yaml.safe_load(config_file)
        # NOTE(review): torch.load unpickles this file. Only load weights from a
        # trusted repo; consider weights_only=True if it is a plain state dict.
        loaded_state_kp = torch.load(path_hf_repo / "keypoint", map_location=device)
        model = get_cls_net(cfg_kp)
        model.load_state_dict(loaded_state_kp)
        model.to(device)
        model.eval()  # inference only
        return model

    def __repr__(self) -> str:
        """Summarise the loaded models, or return the failure message if loading failed."""
        if self.health != "healthy":
            return self.health
        return (
            f"health: {self.health}\n"
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}"
        )

    def predict_batch(
        self,
        batch_images: List[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> List[TVFrameResult]:
        """Run detection, team classification and keypoint inference on *batch_images*.

        This is a thin wrapper around ``inference.predict_batch``. It supplies this
        Miner's models, ``pitch_batch_size``, ``kp_threshold`` and repo path.

        Args:
            batch_images: Frames to process.
            offset: Passed through unchanged; presumably the index of the first
                frame in the batch — confirm in ``inference``.
            n_keypoints: Passed through unchanged to ``inference.predict_batch``.

        Returns:
            The result of ``inference.predict_batch``. It is annotated as a list of
            ``TVFrameResult``, presumably one per frame — confirm.

        Raises:
            RuntimeError: If model loading failed in ``__init__``.
        """
        # Fail with the recorded reason instead of an AttributeError on a missing model.
        if self.health != "healthy":
            raise RuntimeError(f"Miner is not healthy: {self.health}")
        # This name resolves to the module-level ``inference.predict_batch``, not to this method.
        return predict_batch(
            self.bbox_model,
            self.team_classifier,
            self.keypoints_model,
            batch_images,
            offset,
            n_keypoints,
            self.pitch_batch_size,
            self.kp_threshold,
            self.path_hf_repo,
        )