|
|
import os |
|
|
from collections import deque, defaultdict |
|
|
from typing import List, Tuple, Dict, Optional, Union |
|
|
from io import BytesIO |
|
|
import base64 |
|
|
import json |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
from scipy.ndimage import gaussian_filter |
|
|
|
|
|
import supervision as sv |
|
|
from sports.common.team import TeamClassifier |
|
|
from sports.common.view import ViewTransformer |
|
|
from sports.annotators.soccer import draw_pitch, draw_points_on_pitch, draw_paths_on_pitch |
|
|
from sports.configs.soccer import SoccerPitchConfiguration |
|
|
|
|
|
import gradio as gr |
|
|
import plotly.graph_objects as go |
|
|
from plotly.subplots import make_subplots |
|
|
from transformers import AutoProcessor, SiglipVisionModel |
|
|
from more_itertools import chunked |
|
|
from sklearn.cluster import KMeans |
|
|
import umap |
|
|
|
|
|
from inference_sdk import InferenceHTTPClient |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
HF_TOKEN = os.environ.get("HF_TOKEN")
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")

# Both credentials are required: HF_TOKEN authenticates the SigLIP download and
# ROBOFLOW_API_KEY authenticates the hosted detection models. Fail fast at import
# time rather than on the first request.
if not HF_TOKEN or not ROBOFLOW_API_KEY:
    raise ValueError("❌ HF_TOKEN and ROBOFLOW_API_KEY must be set as environment variables.")

# Torch device used for the embedding model; falls back to CPU when CUDA is absent.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {DEVICE}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared client for Roboflow's hosted inference API (used by infer_with_confidence).
CLIENT = InferenceHTTPClient(api_url="https://detect.roboflow.com", api_key=ROBOFLOW_API_KEY)

# Roboflow model identifiers (presumably "<project>/<version>").
# Player model: detects ball, goalkeepers, players and referees.
PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
# Field model: returns pitch keypoints used to build the image -> pitch homography.
FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
|
|
|
|
|
|
|
|
def infer_with_confidence(model_id: str, frame: np.ndarray, confidence_threshold: float = 0.3):
    """Run a hosted Roboflow model on ``frame`` and keep only confident detections.

    Args:
        model_id: Roboflow model identifier forwarded to ``CLIENT.infer``.
        frame: Image to run inference on (an OpenCV frame in this file).
        confidence_threshold: Detections must score strictly above this value.

    Returns:
        ``(raw_result, detections)``: the unmodified inference response (callers
        need it to extract keypoints) and the ``sv.Detections`` parsed from it
        after confidence filtering.
    """
    raw_result = CLIENT.infer(frame, model_id=model_id)
    parsed = sv.Detections.from_inference(raw_result)
    if len(parsed):
        keep = parsed.confidence > confidence_threshold
        parsed = parsed[keep]
    return raw_result, parsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SigLIP vision backbone and its processor, loaded once at import time.
# NOTE(review): neither is referenced in the visible part of this file
# (TeamClassifier is constructed separately) — confirm they are still needed.
SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"

EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH, token=HF_TOKEN)
EMBEDDINGS_MODEL = EMBEDDINGS_MODEL.to(DEVICE)
EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH, token=HF_TOKEN)

# Pitch geometry shared by the homography, the radar and every analytics helper.
CONFIG = SoccerPitchConfiguration()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Column headers for the per-player statistics table.
PLAYER_STATS_HEADERS = [
    "Player ID", "Team",
    "Distance (m)", "Avg Speed (km/h)", "Max Speed (km/h)",
    "Frames Visible",
    "Time Def 1/3 (frames)", "Time Mid 1/3 (frames)", "Time Att 1/3 (frames)",
    "Possession (s)", "Possession (%)",
]

# Column headers for the detected-events table (passes, tackles, ...).
EVENT_HEADERS = [
    "Time (s)", "Type", "Team",
    "From Player", "To Player",
    "Ball Speed (km/h)", "Ball Distance (m)", "Player Distance (m)",
    "Description",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_outliers_based_on_distance(
    positions: List[np.ndarray],
    distance_threshold: float,
) -> List[np.ndarray]:
    """Blank out positions that jump too far from the last accepted one.

    Empty entries pass through unchanged. The first non-empty entry is always
    accepted and becomes the reference ("anchor"). Each later non-empty entry is
    kept only if its distance to the anchor is at most ``distance_threshold``;
    kept entries move the anchor, rejected ones are replaced by an empty float64
    array and leave the anchor where it was. Consequently, after a genuine large
    jump, later positions keep being rejected until one lands back in range.

    Args:
        positions: Per-frame position arrays (possibly empty).
        distance_threshold: Maximum allowed distance from the anchor, in the
            same units as ``positions``.

    Returns:
        A new list, the same length as ``positions``.
    """
    anchor: Optional[np.ndarray] = None
    filtered: List[np.ndarray] = []

    for pos in positions:
        if len(pos) == 0:
            filtered.append(pos)
            continue
        if anchor is not None and np.linalg.norm(pos - anchor) > distance_threshold:
            filtered.append(np.array([], dtype=np.float64))
            continue
        filtered.append(pos)
        anchor = pos

    return filtered
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pitch_distance_m(p1: np.ndarray, p2: np.ndarray) -> float:
    """Return the Euclidean distance between two pitch points, in meters.

    The unit of the pitch configuration is inferred from ``CONFIG.length``: a
    length above 200 is taken to mean centimeters (the raw distance is divided
    by 100); anything else is treated as meters already.

    Args:
        p1: First point; anything ``np.asarray`` can turn into floats.
        p2: Second point, shape-compatible with ``p1``.

    Returns:
        The distance as a Python ``float``.
    """
    start = np.asarray(p1, dtype=float)
    end = np.asarray(p2, dtype=float)
    span = float(np.linalg.norm(end - start))
    units_per_meter = 100.0 if CONFIG.length > 200 else 1.0
    return span / units_per_meter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PlayerPerformanceTracker:
    """Track individual player performance metrics and generate heatmaps.

    Positions are stored in pitch coordinates (the same units as
    ``pitch_config.length`` / ``pitch_config.width``) together with the frame
    index at which each was observed. Distances are converted to meters via
    ``pitch_distance_m``; speeds are reported in km/h.
    """

    def __init__(self, pitch_config, fps: float = 30.0):
        """
        Args:
            pitch_config: Object exposing the pitch ``length`` and ``width``.
            fps: Video frame rate, used to turn frame gaps into elapsed time.
        """
        self.config = pitch_config
        self.fps = fps
        # tracker_id -> list of (x, y, frame) samples in pitch coordinates.
        self.player_positions = defaultdict(list)
        # tracker_id -> list of per-step speeds in km/h.
        self.player_velocities = defaultdict(list)
        # tracker_id -> accumulated distance in meters.
        self.player_distances = defaultdict(float)
        # tracker_id -> most recently reported team id.
        self.player_team = {}
        self.player_stats = defaultdict(
            lambda: {
                "frames_visible": 0,
                "avg_velocity": 0.0,
                "max_velocity": 0.0,
                "time_in_attacking_third": 0,
                "time_in_defensive_third": 0,
                "time_in_middle_third": 0,
            }
        )

    @staticmethod
    def _speed_kmh(distance_m: float, frames_elapsed: int, fps: float) -> float:
        """Convert a displacement covered over ``frames_elapsed`` frames to km/h.

        A gap below one frame (duplicate or out-of-order frame index) is treated
        as a single frame.
        """
        frames_elapsed = max(frames_elapsed, 1)
        return distance_m * fps / frames_elapsed * 3.6

    def update(self, tracker_id: int, position: np.ndarray, team_id: int, frame: int):
        """Record one observation of a player and update its running metrics.

        Args:
            tracker_id: Tracker id of the player.
            position: ``(x, y)`` in pitch coordinates; anything whose length is
                not 2 is ignored.
            team_id: Team the player is currently assigned to.
            frame: Frame index of this observation.
        """
        if len(position) != 2:
            return

        self.player_team[tracker_id] = team_id
        history = self.player_positions[tracker_id]
        history.append((position[0], position[1], frame))
        stats = self.player_stats[tracker_id]
        stats["frames_visible"] += 1

        if len(history) > 1:
            prev_x, prev_y, prev_frame = history[-2]
            prev_pos = np.array([prev_x, prev_y], dtype=float)
            curr_pos = np.array(position, dtype=float)

            distance_m = pitch_distance_m(prev_pos, curr_pos)
            self.player_distances[tracker_id] += distance_m

            # The player may have been undetected for several frames, so divide
            # by the real elapsed time instead of assuming a one-frame step.
            speed_kmh = self._speed_kmh(distance_m, frame - prev_frame, self.fps)
            self.player_velocities[tracker_id].append(speed_kmh)
            if speed_kmh > stats["max_velocity"]:
                stats["max_velocity"] = speed_kmh

        # Zones are thirds of the pitch x-axis measured from x = 0; the team's
        # attacking direction is not taken into account.
        pitch_length = self.config.length
        if position[0] < pitch_length / 3:
            stats["time_in_defensive_third"] += 1
        elif position[0] < 2 * pitch_length / 3:
            stats["time_in_middle_third"] += 1
        else:
            stats["time_in_attacking_third"] += 1

    def get_player_stats(self, tracker_id: int) -> dict:
        """Return a copy of the player's stats enriched with derived metrics.

        Adds ``avg_velocity`` (mean of recorded speeds, km/h),
        ``total_distance_meters`` and ``team_id`` (-1 when unknown). Unknown ids
        yield the default (all-zero) stats without being inserted into the tracker.
        """
        base = self.player_stats.get(tracker_id)
        stats = base.copy() if base is not None else self.player_stats.default_factory()

        velocities = self.player_velocities.get(tracker_id, [])
        if velocities:
            stats["avg_velocity"] = float(np.mean(velocities))

        stats["total_distance_meters"] = float(self.player_distances.get(tracker_id, 0.0))
        stats["team_id"] = int(self.player_team.get(tracker_id, -1))
        return stats

    def generate_heatmap(self, tracker_id: int, resolution: int = 100) -> np.ndarray:
        """Return a smoothed ``(resolution, resolution)`` occupancy heatmap.

        Positions are binned over ``[0, length] x [0, width]`` (samples outside
        that range are not counted by ``np.histogram2d``), blurred with a Gaussian
        of sigma 3 bins, and transposed so rows index the y-axis. Unknown or
        empty players give an all-zero map.
        """
        if tracker_id not in self.player_positions or len(self.player_positions[tracker_id]) == 0:
            return np.zeros((resolution, resolution))

        positions = np.array([(x, y) for x, y, _ in self.player_positions[tracker_id]])

        heatmap, _, _ = np.histogram2d(
            positions[:, 0],
            positions[:, 1],
            bins=[resolution, resolution],
            range=[[0, self.config.length], [0, self.config.width]],
        )
        heatmap = gaussian_filter(heatmap, sigma=3)
        return heatmap.T

    def get_all_players_by_team(self) -> Dict[int, List[int]]:
        """Group every tracked player id by its most recently reported team."""
        teams = defaultdict(list)
        for tracker_id, team_id in self.player_team.items():
            teams[team_id].append(tracker_id)
        return teams
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PlayerTrackingManager:
    """Smooth per-tracker team predictions with a bounded voting window.

    Each tracker id keeps its most recent ``max_history`` team predictions. Once
    at least three votes exist, the most frequent team wins; ties resolve to the
    smallest team id (``np.argmax`` over ``np.bincount``).
    """

    def __init__(self, max_history=10):
        # tracker_id -> recent team votes, oldest first.
        self.tracker_team_history: Dict[int, List[int]] = defaultdict(list)
        self.max_history = max_history
        # Tracker ids that received a vote since the last reset_frame().
        self.active_trackers = set()

    def update_team_assignment(self, tracker_id: int, team_id: int):
        """Append one team vote for ``tracker_id`` and mark it active this frame."""
        votes = self.tracker_team_history[tracker_id]
        votes.append(team_id)
        if len(votes) > self.max_history:
            del votes[0]  # drop the oldest vote
        self.active_trackers.add(tracker_id)

    def get_stable_team_id(self, tracker_id: int, current_team_id: int) -> int:
        """Return the majority team from history, or ``current_team_id`` if fewer than 3 votes."""
        votes = self.tracker_team_history.get(tracker_id)
        if votes is None or len(votes) < 3:
            return current_team_id
        return int(np.argmax(np.bincount(votes)))

    def get_player_count_by_team(self) -> Dict[int, int]:
        """Count the currently active trackers per stable team id."""
        counts = defaultdict(int)
        for tid in self.active_trackers:
            votes = self.tracker_team_history.get(tid)
            if votes:
                counts[self.get_stable_team_id(tid, votes[-1])] += 1
        return counts

    def reset_frame(self):
        """Start a new frame: forget which trackers were active (history is kept)."""
        self.active_trackers = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_player_heatmap_visualization(
    performance_tracker: PlayerPerformanceTracker,
    tracker_id: int,
) -> np.ndarray:
    """Render one player's positional heatmap blended onto a pitch drawing.

    The heatmap is normalised to ``[0, 1]``, colour-mapped with JET and pasted
    inside a 50 px border of the pitch image. It is then alpha-blended 60/40 with
    the plain pitch, and the player's summary stats are written top-left.

    Args:
        performance_tracker: Tracker holding the player's positions and stats.
        tracker_id: Player to render.

    Returns:
        The annotated pitch image, the same size as ``draw_pitch(CONFIG)``.
    """
    base = draw_pitch(CONFIG)
    density = performance_tracker.generate_heatmap(tracker_id, resolution=150)
    peak = density.max()
    if peak > 0:
        density = density / peak

    # NOTE(review): the 50 px border is assumed to match draw_pitch's padding.
    pad = 50
    rows, cols = base.shape[:2]
    scaled = cv2.resize(density, (cols - 2 * pad, rows - 2 * pad))  # cv2 takes (width, height)
    colored = cv2.applyColorMap((scaled * 255).astype(np.uint8), cv2.COLORMAP_JET)

    layered = base.copy()
    layered[pad : rows - pad, pad : cols - pad] = colored
    canvas = cv2.addWeighted(base, 0.6, layered, 0.4, 0)

    info = performance_tracker.get_player_stats(tracker_id)
    team_label = "Blue" if info["team_id"] == 0 else "Pink"
    captions = (
        f"Player #{tracker_id} ({team_label} Team)",
        f"Distance: {info['total_distance_meters']:.1f} m",
        f"Avg Speed: {info['avg_velocity']:.2f} km/h",
        f"Max Speed: {info['max_velocity']:.2f} km/h",
        f"Frames: {info['frames_visible']}",
    )
    for row_idx, caption in enumerate(captions):
        cv2.putText(
            canvas,
            caption,
            (10, 30 + 25 * row_idx),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
    return canvas
|
|
|
|
|
|
|
|
def create_team_comparison_plot(performance_tracker: PlayerPerformanceTracker) -> go.Figure:
    """Build a 2x2 grouped-bar figure comparing per-player metrics of both teams.

    Panels: distance covered (m), average speed (km/h), max speed (km/h), and
    "Activity by Zone". The last panel plots only frames spent in the attacking
    third. Players whose team id is not 0 or 1 are skipped. Only the first panel
    of each team contributes a legend entry.
    """
    teams = performance_tracker.get_all_players_by_team()

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=("Distance Covered", "Average Speed", "Max Speed", "Activity by Zone"),
        specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "bar"}, {"type": "bar"}]],
    )

    colors = {0: "#00BFFF", 1: "#FF1493"}
    team_names = {0: "Team 0 (Blue)", 1: "Team 1 (Pink)"}

    # (stats key, subplot row, subplot col), in the order traces are added.
    panels = (
        ("total_distance_meters", 1, 1),
        ("avg_velocity", 1, 2),
        ("max_velocity", 2, 1),
        ("time_in_attacking_third", 2, 2),
    )

    for team_id, player_ids in teams.items():
        if team_id not in (0, 1):
            continue
        per_player = [performance_tracker.get_player_stats(pid) for pid in player_ids]
        labels = [f"#{pid}" for pid in player_ids]
        for panel_idx, (key, row, col) in enumerate(panels):
            fig.add_trace(
                go.Bar(
                    x=labels,
                    y=[player_stats[key] for player_stats in per_player],
                    name=team_names[team_id],
                    marker_color=colors[team_id],
                    showlegend=panel_idx == 0,
                ),
                row=row,
                col=col,
            )

    y_titles = {
        (1, 1): "Distance (m)",
        (1, 2): "Speed (km/h)",
        (2, 1): "Speed (km/h)",
        (2, 2): "Frames in Zone",
    }
    for row, col in y_titles:
        fig.update_xaxes(title_text="Players", row=row, col=col)
    for (row, col), title in y_titles.items():
        fig.update_yaxes(title_text=title, row=row, col=col)

    fig.update_layout(height=800, title_text="Team Performance Comparison", barmode="group")
    return fig
|
|
|
|
|
|
|
|
def create_combined_heatmaps(performance_tracker: PlayerPerformanceTracker) -> np.ndarray:
    """Render one summed heatmap per team and place them side by side.

    For each team (0, then 1) that has tracked players, the per-player 150x150
    heatmaps are summed and normalised. The result is colour-mapped (JET for team
    0, HOT for team 1), blended 50/50 onto a fresh pitch drawing and titled.

    Returns:
        Both panels stacked horizontally, the single panel if only one team has
        players, or a plain pitch if neither does.
    """
    teams = performance_tracker.get_all_players_by_team()
    team_colormaps = {0: cv2.COLORMAP_JET, 1: cv2.COLORMAP_HOT}
    team_titles = {0: "Team 0 (Blue)", 1: "Team 1 (Pink)"}
    panels = []

    for team_id in (0, 1):
        if team_id not in teams:
            continue

        density = np.zeros((150, 150))
        for pid in teams[team_id]:
            density += performance_tracker.generate_heatmap(pid, resolution=150)
        peak = density.max()
        if peak > 0:
            density = density / peak

        base = draw_pitch(CONFIG)
        pad = 50  # NOTE(review): assumed to match draw_pitch's padding
        rows, cols = base.shape[:2]
        scaled = cv2.resize(density, (cols - 2 * pad, rows - 2 * pad))
        colored = cv2.applyColorMap((scaled * 255).astype(np.uint8), team_colormaps[team_id])

        layered = base.copy()
        layered[pad : rows - pad, pad : cols - pad] = colored
        panel = cv2.addWeighted(base, 0.5, layered, 0.5, 0)

        cv2.putText(
            panel,
            team_titles[team_id],
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
        panels.append(panel)

    if not panels:
        return draw_pitch(CONFIG)
    if len(panels) == 1:
        return panels[0]
    return np.hstack(panels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
    """Assign each goalkeeper to the team whose player centroid is nearest.

    Distances are measured in image coordinates between bottom-centre anchors.
    The result always has one entry per goalkeeper, so it can be assigned
    directly to ``goalkeepers.class_id``:

    * no goalkeepers -> empty int array;
    * both teams present -> nearest centroid (an exact tie goes to team 1);
    * only one team present -> that team;
    * no players at all -> team 0. This is an arbitrary default that only keeps
      the output aligned with ``goalkeepers``.

    Args:
        players: Player detections whose ``class_id`` holds team ids (0 or 1).
        goalkeepers: Goalkeeper detections to classify.

    Returns:
        Integer array of team ids, ``len(goalkeepers)`` long.
    """
    n_goalkeepers = len(goalkeepers)
    if n_goalkeepers == 0:
        return np.empty(0, dtype=int)
    if len(players) == 0:
        return np.zeros(n_goalkeepers, dtype=int)

    goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
    players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)

    team_0_xy = players_xy[players.class_id == 0]
    team_1_xy = players_xy[players.class_id == 1]

    # Avoid NaN centroids (mean of an empty slice) when a team is not visible.
    if len(team_0_xy) == 0 and len(team_1_xy) == 0:
        return np.zeros(n_goalkeepers, dtype=int)
    if len(team_1_xy) == 0:
        return np.zeros(n_goalkeepers, dtype=int)
    if len(team_0_xy) == 0:
        return np.ones(n_goalkeepers, dtype=int)

    dist_to_0 = np.linalg.norm(goalkeepers_xy - team_0_xy.mean(axis=0), axis=1)
    dist_to_1 = np.linalg.norm(goalkeepers_xy - team_1_xy.mean(axis=0), axis=1)
    return np.where(dist_to_0 < dist_to_1, 0, 1)
|
|
|
|
|
|
|
|
def create_game_style_radar(
    pitch_ball_xy,
    pitch_players_xy,
    players_class_id,
    pitch_referees_xy,
    ball_path=None,
):
    """Render a top-down radar of the pitch with ball trail, players and referees.

    Layers are drawn in this order (later on top):

    1. A fading ball trail.
    2. The current ball.
    3. Players: team 0 in ``#00BFFF``, team 1 in ``#FF1493``.
    4. Referees in ``#FFD700``.

    Every positional argument may be ``None`` or empty; that layer is then skipped.

    Args:
        pitch_ball_xy: Ball position(s) in pitch coordinates.
        pitch_players_xy: Player positions in pitch coordinates, one row per player.
        players_class_id: Team id per row of ``pitch_players_xy``; only ids 0 and 1
            are drawn. Lists are accepted as well as arrays.
        pitch_referees_xy: Referee positions in pitch coordinates.
        ball_path: Optional sequence of per-frame ball positions. Empty entries are
            dropped; the last 20 non-empty ones form the trail, which is drawn
            only when at least two such positions exist.

    Returns:
        The annotated pitch image.
    """
    annotated_frame = draw_pitch(CONFIG)

    if ball_path is not None and len(ball_path) > 0:
        valid_path = [coords for coords in ball_path if len(coords) > 0]
        if len(valid_path) > 1:
            trail = valid_path[-20:]
            for i, coords in enumerate(trail):
                # Older points are darker and smaller; the newest is white.
                alpha = (i + 1) / len(trail)
                shade = int(255 * alpha)
                annotated_frame = draw_points_on_pitch(
                    CONFIG,
                    coords,
                    face_color=sv.Color(shade, shade, shade),
                    edge_color=sv.Color.BLACK,
                    radius=int(6 + alpha * 4),
                    pitch=annotated_frame,
                )

    if pitch_ball_xy is not None and len(pitch_ball_xy) > 0:
        annotated_frame = draw_points_on_pitch(
            CONFIG,
            pitch_ball_xy,
            face_color=sv.Color.WHITE,
            edge_color=sv.Color.BLACK,
            radius=10,
            pitch=annotated_frame,
        )

    if pitch_players_xy is not None and players_class_id is not None and len(pitch_players_xy) > 0:
        players_xy = np.asarray(pitch_players_xy)
        class_ids = np.asarray(players_class_id)
        for team_id, color_hex in ((0, "00BFFF"), (1, "FF1493")):
            mask = class_ids == team_id
            if np.any(mask):
                annotated_frame = draw_points_on_pitch(
                    CONFIG,
                    players_xy[mask],
                    face_color=sv.Color.from_hex(color_hex),
                    edge_color=sv.Color.BLACK,
                    radius=16,
                    pitch=annotated_frame,
                )

    if pitch_referees_xy is not None and len(pitch_referees_xy) > 0:
        annotated_frame = draw_points_on_pitch(
            CONFIG,
            pitch_referees_xy,
            face_color=sv.Color.from_hex("FFD700"),
            edge_color=sv.Color.BLACK,
            radius=16,
            pitch=annotated_frame,
        )

    return annotated_frame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple: |
|
|
""" |
|
|
Complete football analysis pipeline: |
|
|
- Player & ball detection (Roboflow) |
|
|
- Team classification (SigLIP-based) |
|
|
- Tracking (ByteTrack) with stable team assignments |
|
|
- Field homography -> pitch coordinates |
|
|
- Ball trajectory cleaning |
|
|
- Performance analytics |
|
|
- Simple events + possession + per-player stats |
|
|
""" |
|
|
if not video_path: |
|
|
return ( |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
"โ Please upload a video file.", |
|
|
[], |
|
|
[], |
|
|
None, |
|
|
) |
|
|
|
|
|
try: |
|
|
progress(0, desc="๐ง Initializing...") |
|
|
|
|
|
|
|
|
BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3 |
|
|
STRIDE = 30 |
|
|
MAXLEN = 5 |
|
|
MAX_DISTANCE_THRESHOLD = 500 |
|
|
|
|
|
|
|
|
cap = cv2.VideoCapture(video_path) |
|
|
if not cap.isOpened(): |
|
|
return ( |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
f"โ Failed to open video: {video_path}", |
|
|
[], |
|
|
[], |
|
|
None, |
|
|
) |
|
|
|
|
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
|
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
|
|
fps = cap.get(cv2.CAP_PROP_FPS) |
|
|
if fps <= 0: |
|
|
fps = 30.0 |
|
|
dt = 1.0 / fps |
|
|
|
|
|
print(f"๐น Video: {width}x{height}, {fps}fps, {total_frames} frames") |
|
|
|
|
|
fourcc = cv2.VideoWriter_fourcc(*"mp4v") |
|
|
output_path = "/tmp/annotated_football.mp4" |
|
|
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) |
|
|
|
|
|
|
|
|
tracking_manager = PlayerTrackingManager(max_history=10) |
|
|
performance_tracker = PlayerPerformanceTracker(CONFIG, fps=fps) |
|
|
|
|
|
|
|
|
distance_covered_m = defaultdict(float) |
|
|
possession_time_player = defaultdict(float) |
|
|
possession_time_team = defaultdict(float) |
|
|
team_of_player = {} |
|
|
events: List[Dict] = [] |
|
|
|
|
|
prev_owner_tid: Optional[int] = None |
|
|
prev_ball_pos_pitch: Optional[np.ndarray] = None |
|
|
|
|
|
|
|
|
ellipse_annotator = sv.EllipseAnnotator( |
|
|
color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]), |
|
|
thickness=2, |
|
|
) |
|
|
label_annotator = sv.LabelAnnotator( |
|
|
color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]), |
|
|
text_color=sv.Color.from_hex("#FFFFFF"), |
|
|
text_thickness=2, |
|
|
text_position=sv.Position.BOTTOM_CENTER, |
|
|
) |
|
|
triangle_annotator = sv.TriangleAnnotator( |
|
|
color=sv.Color.from_hex("#FFD700"), |
|
|
base=20, |
|
|
height=17, |
|
|
) |
|
|
|
|
|
|
|
|
tracker = sv.ByteTrack( |
|
|
track_activation_threshold=0.4, |
|
|
lost_track_buffer=60, |
|
|
minimum_matching_threshold=0.85, |
|
|
frame_rate=fps, |
|
|
) |
|
|
tracker.reset() |
|
|
|
|
|
|
|
|
M = deque(maxlen=MAXLEN) |
|
|
ball_path_raw = [] |
|
|
|
|
|
|
|
|
last_pitch_players_xy = None |
|
|
last_players_class_id = None |
|
|
last_pitch_referees_xy = None |
|
|
last_pitch_pos_by_tid: Dict[int, np.ndarray] = {} |
|
|
|
|
|
|
|
|
goal_centers = { |
|
|
0: np.array([0.0, CONFIG.width / 2.0]), |
|
|
1: np.array([CONFIG.length, CONFIG.width / 2.0]), |
|
|
} |
|
|
|
|
|
|
|
|
current_event_text = "" |
|
|
event_text_frames_left = 0 |
|
|
EVENT_TEXT_DURATION_S = 2.0 |
|
|
EVENT_TEXT_DURATION_FRAMES = int(EVENT_TEXT_DURATION_S * fps) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
progress(0.05, desc="๐ Collecting player samples (Step 1/6)...") |
|
|
player_crops = [] |
|
|
frame_count = 0 |
|
|
|
|
|
while frame_count < min(total_frames, 300): |
|
|
ret, frame = cap.read() |
|
|
if not ret: |
|
|
break |
|
|
|
|
|
if frame_count % STRIDE == 0: |
|
|
_, detections = infer_with_confidence(PLAYER_DETECTION_MODEL_ID, frame, 0.3) |
|
|
detections = detections.with_nms(threshold=0.5, class_agnostic=True) |
|
|
players_detections = detections[detections.class_id == PLAYER_ID] |
|
|
|
|
|
if len(players_detections.xyxy) > 0: |
|
|
crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy] |
|
|
player_crops.extend(crops) |
|
|
|
|
|
frame_count += 1 |
|
|
|
|
|
if len(player_crops) == 0: |
|
|
cap.release() |
|
|
out.release() |
|
|
return ( |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
"โ No player crops collected.", |
|
|
[], |
|
|
[], |
|
|
None, |
|
|
) |
|
|
|
|
|
print(f"โ
Collected {len(player_crops)} player samples") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
progress(0.15, desc="๐ฏ Training team classifier (Step 2/6)...") |
|
|
team_classifier = TeamClassifier(device=DEVICE) |
|
|
team_classifier.fit(player_crops) |
|
|
print("โ
Team classifier trained") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cap.set(cv2.CAP_PROP_POS_FRAMES, 0) |
|
|
frame_count = 0 |
|
|
|
|
|
progress(0.2, desc="๐ฌ Processing video frames (Step 3/6)...") |
|
|
|
|
|
frame_idx = 0 |
|
|
while True: |
|
|
ret, frame = cap.read() |
|
|
if not ret: |
|
|
break |
|
|
|
|
|
frame_idx += 1 |
|
|
t = frame_idx * dt |
|
|
frame_count += 1 |
|
|
tracking_manager.reset_frame() |
|
|
|
|
|
if frame_count % 30 == 0: |
|
|
progress( |
|
|
0.2 + 0.4 * (frame_count / max(total_frames, 1)), |
|
|
desc=f"๐ฌ Processing frame {frame_count}/{total_frames}", |
|
|
) |
|
|
|
|
|
|
|
|
_, detections = infer_with_confidence(PLAYER_DETECTION_MODEL_ID, frame, 0.3) |
|
|
|
|
|
if len(detections.xyxy) == 0: |
|
|
out.write(frame) |
|
|
ball_path_raw.append(np.empty((0, 2))) |
|
|
continue |
|
|
|
|
|
|
|
|
ball_detections = detections[detections.class_id == BALL_ID] |
|
|
ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10) |
|
|
|
|
|
all_detections = detections[detections.class_id != BALL_ID] |
|
|
all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True) |
|
|
|
|
|
|
|
|
all_detections = tracker.update_with_detections(detections=all_detections) |
|
|
|
|
|
|
|
|
goalkeepers_detections = all_detections[all_detections.class_id == GOALKEEPER_ID] |
|
|
players_detections = all_detections[all_detections.class_id == PLAYER_ID] |
|
|
referees_detections = all_detections[all_detections.class_id == REFEREE_ID] |
|
|
|
|
|
|
|
|
if len(players_detections.xyxy) > 0: |
|
|
crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy] |
|
|
predicted_teams = team_classifier.predict(crops) |
|
|
|
|
|
|
|
|
for idx, tracker_id in enumerate(players_detections.tracker_id): |
|
|
tracking_manager.update_team_assignment(int(tracker_id), int(predicted_teams[idx])) |
|
|
predicted_teams[idx] = tracking_manager.get_stable_team_id( |
|
|
int(tracker_id), |
|
|
int(predicted_teams[idx]), |
|
|
) |
|
|
|
|
|
players_detections.class_id = predicted_teams |
|
|
|
|
|
|
|
|
goalkeepers_detections.class_id = resolve_goalkeepers_team_id( |
|
|
players_detections, |
|
|
goalkeepers_detections, |
|
|
) |
|
|
|
|
|
|
|
|
referees_detections.class_id -= 1 |
|
|
|
|
|
|
|
|
all_detections = sv.Detections.merge( |
|
|
[players_detections, goalkeepers_detections, referees_detections] |
|
|
) |
|
|
|
|
|
all_detections.class_id = all_detections.class_id.astype(int) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pitch_players_xy = None |
|
|
pitch_referees_xy = None |
|
|
pitch_ball_xy = np.empty((0, 2), dtype=np.float32) |
|
|
frame_ball_pos_pitch = None |
|
|
|
|
|
try: |
|
|
result_field, _ = infer_with_confidence(FIELD_DETECTION_MODEL_ID, frame, 0.3) |
|
|
key_points = sv.KeyPoints.from_inference(result_field) |
|
|
|
|
|
|
|
|
filter_mask = key_points.confidence[0] > 0.5 |
|
|
frame_ref_pts = key_points.xy[0][filter_mask] |
|
|
pitch_ref_pts = np.array(CONFIG.vertices)[filter_mask] |
|
|
|
|
|
if len(frame_ref_pts) >= 4: |
|
|
transformer = ViewTransformer(source=frame_ref_pts, target=pitch_ref_pts) |
|
|
M.append(transformer.m) |
|
|
transformer.m = np.mean(np.array(M), axis=0) |
|
|
|
|
|
|
|
|
frame_ball_xy = ball_detections.get_anchors_coordinates( |
|
|
sv.Position.BOTTOM_CENTER |
|
|
) |
|
|
pitch_ball_xy = ( |
|
|
transformer.transform_points(frame_ball_xy) |
|
|
if len(frame_ball_xy) > 0 |
|
|
else np.empty((0, 2)) |
|
|
) |
|
|
if len(pitch_ball_xy) > 0: |
|
|
frame_ball_pos_pitch = pitch_ball_xy[0] |
|
|
ball_path_raw.append(pitch_ball_xy) |
|
|
|
|
|
|
|
|
all_players = sv.Detections.merge([players_detections, goalkeepers_detections]) |
|
|
players_xy = all_players.get_anchors_coordinates( |
|
|
sv.Position.BOTTOM_CENTER |
|
|
) |
|
|
pitch_players_xy = ( |
|
|
transformer.transform_points(players_xy) |
|
|
if len(players_xy) > 0 |
|
|
else np.empty((0, 2)) |
|
|
) |
|
|
|
|
|
|
|
|
referees_xy = referees_detections.get_anchors_coordinates( |
|
|
sv.Position.BOTTOM_CENTER |
|
|
) |
|
|
pitch_referees_xy = ( |
|
|
transformer.transform_points(referees_xy) |
|
|
if len(referees_xy) > 0 |
|
|
else np.empty((0, 2)) |
|
|
) |
|
|
|
|
|
|
|
|
last_pitch_players_xy = pitch_players_xy |
|
|
last_players_class_id = all_players.class_id |
|
|
last_pitch_referees_xy = pitch_referees_xy |
|
|
|
|
|
|
|
|
for idx, tracker_id in enumerate(all_players.tracker_id): |
|
|
tid_int = int(tracker_id) |
|
|
if idx < len(pitch_players_xy): |
|
|
pos_pitch = pitch_players_xy[idx] |
|
|
performance_tracker.update( |
|
|
tid_int, |
|
|
pos_pitch, |
|
|
int(all_players.class_id[idx]), |
|
|
frame_count, |
|
|
) |
|
|
team_of_player[tid_int] = int(all_players.class_id[idx]) |
|
|
|
|
|
prev_pos = last_pitch_pos_by_tid.get(tid_int) |
|
|
if prev_pos is not None: |
|
|
dist_m = pitch_distance_m(prev_pos, pos_pitch) |
|
|
distance_covered_m[tid_int] += dist_m |
|
|
last_pitch_pos_by_tid[tid_int] = pos_pitch |
|
|
else: |
|
|
ball_path_raw.append(np.empty((0, 2))) |
|
|
except Exception: |
|
|
ball_path_raw.append(np.empty((0, 2))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
owner_tid: Optional[int] = None |
|
|
POSSESSION_RADIUS_M = 5.0 |
|
|
|
|
|
if frame_ball_pos_pitch is not None and pitch_players_xy is not None and len(pitch_players_xy) > 0: |
|
|
dists = np.linalg.norm(pitch_players_xy - frame_ball_pos_pitch, axis=1) |
|
|
j = int(np.argmin(dists)) |
|
|
nearest_dist_m = pitch_distance_m(pitch_players_xy[j], frame_ball_pos_pitch) |
|
|
if nearest_dist_m < POSSESSION_RADIUS_M: |
|
|
owner_tid = int(all_players.tracker_id[j]) |
|
|
|
|
|
|
|
|
if owner_tid is not None: |
|
|
possession_time_player[owner_tid] += dt |
|
|
owner_team = team_of_player.get(owner_tid) |
|
|
if owner_team is not None: |
|
|
possession_time_team[owner_team] += dt |
|
|
|
|
|
def register_event(ev: Dict, text: str): |
|
|
nonlocal current_event_text, event_text_frames_left |
|
|
events.append(ev) |
|
|
if text: |
|
|
current_event_text = text |
|
|
event_text_frames_left = EVENT_TEXT_DURATION_FRAMES |
|
|
|
|
|
|
|
|
if owner_tid != prev_owner_tid: |
|
|
if owner_tid is not None and prev_owner_tid is not None: |
|
|
prev_team = team_of_player.get(prev_owner_tid) |
|
|
cur_team = team_of_player.get(owner_tid) |
|
|
|
|
|
travel_m = 0.0 |
|
|
if prev_ball_pos_pitch is not None and frame_ball_pos_pitch is not None: |
|
|
travel_m = pitch_distance_m(prev_ball_pos_pitch, frame_ball_pos_pitch) |
|
|
|
|
|
MIN_PASS_TRAVEL_M = 3.0 |
|
|
|
|
|
if prev_team is not None and cur_team is not None: |
|
|
if prev_team == cur_team and travel_m > MIN_PASS_TRAVEL_M: |
|
|
|
|
|
register_event( |
|
|
{ |
|
|
"type": "pass", |
|
|
"t": float(t), |
|
|
"from_tid": int(prev_owner_tid), |
|
|
"to_tid": int(owner_tid), |
|
|
"team_id": int(cur_team), |
|
|
"extra": {"distance_m": travel_m}, |
|
|
}, |
|
|
f"Pass: #{prev_owner_tid} โ #{owner_tid} (Team {cur_team})", |
|
|
) |
|
|
elif prev_team != cur_team: |
|
|
|
|
|
d_pp = 999.0 |
|
|
if pitch_players_xy is not None: |
|
|
pos_prev = last_pitch_pos_by_tid.get(int(prev_owner_tid)) |
|
|
pos_cur = last_pitch_pos_by_tid.get(int(owner_tid)) |
|
|
if pos_prev is not None and pos_cur is not None: |
|
|
d_pp = pitch_distance_m(pos_prev, pos_cur) |
|
|
ev_type = "tackle" if d_pp < 3.0 else "interception" |
|
|
label = "Tackle" if ev_type == "tackle" else "Interception" |
|
|
register_event( |
|
|
{ |
|
|
"type": ev_type, |
|
|
"t": float(t), |
|
|
"from_tid": int(prev_owner_tid), |
|
|
"to_tid": int(owner_tid), |
|
|
"team_id": int(cur_team), |
|
|
"extra": { |
|
|
"player_distance_m": d_pp, |
|
|
"ball_travel_m": travel_m, |
|
|
}, |
|
|
}, |
|
|
f"{label}: #{owner_tid} wins ball from #{prev_owner_tid}", |
|
|
) |
|
|
|
|
|
|
|
|
if owner_tid is not None: |
|
|
team_id = team_of_player.get(owner_tid) |
|
|
register_event( |
|
|
{ |
|
|
"type": "possession_change", |
|
|
"t": float(t), |
|
|
"from_tid": int(prev_owner_tid) |
|
|
if prev_owner_tid is not None |
|
|
else None, |
|
|
"to_tid": int(owner_tid), |
|
|
"team_id": int(team_id) if team_id is not None else None, |
|
|
"extra": {}, |
|
|
}, |
|
|
"", |
|
|
) |
|
|
|
|
|
|
|
|
if ( |
|
|
prev_ball_pos_pitch is not None |
|
|
and frame_ball_pos_pitch is not None |
|
|
and owner_tid is not None |
|
|
): |
|
|
v_vec = frame_ball_pos_pitch - prev_ball_pos_pitch |
|
|
|
|
|
dist_m = pitch_distance_m(prev_ball_pos_pitch, frame_ball_pos_pitch) |
|
|
speed_mps = dist_m / dt |
|
|
speed_kmh = speed_mps * 3.6 |
|
|
HIGH_SPEED_KMH = 18.0 |
|
|
|
|
|
if speed_kmh > HIGH_SPEED_KMH: |
|
|
shooter_team = team_of_player.get(owner_tid) |
|
|
if shooter_team is not None: |
|
|
target_goal = goal_centers[1 - shooter_team] |
|
|
direction = target_goal - frame_ball_pos_pitch |
|
|
v_norm = np.linalg.norm(v_vec) |
|
|
d_norm = np.linalg.norm(direction) |
|
|
cos_angle = 0.0 |
|
|
if v_norm > 1e-6 and d_norm > 1e-6: |
|
|
cos_angle = float(np.dot(v_vec, direction) / (v_norm * d_norm)) |
|
|
|
|
|
if cos_angle > 0.8: |
|
|
register_event( |
|
|
{ |
|
|
"type": "shot", |
|
|
"t": float(t), |
|
|
"from_tid": int(owner_tid), |
|
|
"to_tid": None, |
|
|
"team_id": int(shooter_team), |
|
|
"extra": {"speed_kmh": speed_kmh}, |
|
|
}, |
|
|
f"Shot by #{owner_tid} (Team {shooter_team}) โ {speed_kmh:.1f} km/h", |
|
|
) |
|
|
else: |
|
|
register_event( |
|
|
{ |
|
|
"type": "clearance", |
|
|
"t": float(t), |
|
|
"from_tid": int(owner_tid), |
|
|
"to_tid": None, |
|
|
"team_id": int(shooter_team), |
|
|
"extra": {"speed_kmh": speed_kmh}, |
|
|
}, |
|
|
f"Clearance by #{owner_tid} (Team {shooter_team})", |
|
|
) |
|
|
|
|
|
prev_owner_tid = owner_tid |
|
|
prev_ball_pos_pitch = frame_ball_pos_pitch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
annotated_frame = frame.copy() |
|
|
|
|
|
|
|
|
labels = [] |
|
|
for tid, cid in zip(all_detections.tracker_id, all_detections.class_id): |
|
|
labels.append(f"#{int(tid)} T{int(cid)}") |
|
|
|
|
|
annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections) |
|
|
annotated_frame = label_annotator.annotate( |
|
|
annotated_frame, |
|
|
all_detections, |
|
|
labels=labels, |
|
|
) |
|
|
annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections) |
|
|
|
|
|
|
|
|
total_poss = sum(possession_time_team.values()) + 1e-6 |
|
|
team0_pct = 100.0 * possession_time_team.get(0, 0.0) / total_poss |
|
|
team1_pct = 100.0 * possession_time_team.get(1, 0.0) / total_poss |
|
|
|
|
|
hud_text = ( |
|
|
f"Team 0 Ball Control: {team0_pct:5.2f}% " |
|
|
f"Team 1 Ball Control: {team1_pct:5.2f}%" |
|
|
) |
|
|
cv2.rectangle( |
|
|
annotated_frame, |
|
|
(20, annotated_frame.shape[0] - 60), |
|
|
(annotated_frame.shape[1] - 20, annotated_frame.shape[0] - 20), |
|
|
(255, 255, 255), |
|
|
-1, |
|
|
) |
|
|
cv2.putText( |
|
|
annotated_frame, |
|
|
hud_text, |
|
|
(30, annotated_frame.shape[0] - 30), |
|
|
cv2.FONT_HERSHEY_SIMPLEX, |
|
|
0.8, |
|
|
(0, 0, 0), |
|
|
2, |
|
|
cv2.LINE_AA, |
|
|
) |
|
|
|
|
|
|
|
|
if event_text_frames_left > 0 and current_event_text: |
|
|
cv2.rectangle( |
|
|
annotated_frame, |
|
|
(20, 20), |
|
|
(annotated_frame.shape[1] - 20, 90), |
|
|
(255, 255, 255), |
|
|
-1, |
|
|
) |
|
|
cv2.putText( |
|
|
annotated_frame, |
|
|
current_event_text, |
|
|
(30, 70), |
|
|
cv2.FONT_HERSHEY_SIMPLEX, |
|
|
1.0, |
|
|
(0, 0, 0), |
|
|
2, |
|
|
cv2.LINE_AA, |
|
|
) |
|
|
event_text_frames_left -= 1 |
|
|
|
|
|
out.write(annotated_frame) |
|
|
|
|
|
cap.release() |
|
|
out.release() |
|
|
print(f"โ
Processed {frame_count} frames") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
progress(0.65, desc="๐งน Cleaning ball trajectory (Step 4/6)...") |
|
|
|
|
|
|
|
|
path_for_cleaning = [] |
|
|
for coords in ball_path_raw: |
|
|
if len(coords) == 0: |
|
|
path_for_cleaning.append(np.empty((0, 2), dtype=np.float32)) |
|
|
elif coords.shape[0] >= 2: |
|
|
|
|
|
path_for_cleaning.append(np.empty((0, 2), dtype=np.float32)) |
|
|
else: |
|
|
path_for_cleaning.append(coords) |
|
|
|
|
|
|
|
|
cleaned_path = replace_outliers_based_on_distance( |
|
|
[ |
|
|
np.array(p).reshape(-1, 2) if len(p) > 0 else np.empty((0, 2)) |
|
|
for p in path_for_cleaning |
|
|
], |
|
|
MAX_DISTANCE_THRESHOLD, |
|
|
) |
|
|
|
|
|
print( |
|
|
f"โ
Ball path cleaned: " |
|
|
f"{len([p for p in cleaned_path if len(p) > 0])} valid points" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
progress(0.75, desc="๐ Generating performance analytics (Step 5/6)...") |
|
|
|
|
|
|
|
|
comparison_fig = create_team_comparison_plot(performance_tracker) |
|
|
|
|
|
|
|
|
team_heatmaps_path = "/tmp/team_heatmaps.png" |
|
|
team_heatmaps = create_combined_heatmaps(performance_tracker) |
|
|
cv2.imwrite(team_heatmaps_path, team_heatmaps) |
|
|
|
|
|
|
|
|
progress(0.85, desc="๐บ๏ธ Creating individual heatmaps...") |
|
|
teams = performance_tracker.get_all_players_by_team() |
|
|
top_players = [] |
|
|
|
|
|
for team_id in [0, 1]: |
|
|
if team_id in teams: |
|
|
team_players = teams[team_id] |
|
|
player_distances = [ |
|
|
(pid, performance_tracker.get_player_stats(pid)["total_distance_meters"]) |
|
|
for pid in team_players |
|
|
] |
|
|
player_distances.sort(key=lambda x: x[1], reverse=True) |
|
|
top_players.extend([pid for pid, _ in player_distances[:3]]) |
|
|
|
|
|
individual_heatmaps = [] |
|
|
for pid in top_players[:6]: |
|
|
heatmap = create_player_heatmap_visualization(performance_tracker, pid) |
|
|
individual_heatmaps.append(heatmap) |
|
|
|
|
|
|
|
|
if len(individual_heatmaps) > 0: |
|
|
rows = [] |
|
|
for i in range(0, len(individual_heatmaps), 3): |
|
|
row_maps = individual_heatmaps[i : i + 3] |
|
|
if len(row_maps) == 3: |
|
|
rows.append(np.hstack(row_maps)) |
|
|
elif len(row_maps) == 2: |
|
|
rows.append(np.hstack([row_maps[0], row_maps[1]])) |
|
|
else: |
|
|
rows.append(row_maps[0]) |
|
|
|
|
|
individual_grid = np.vstack(rows) if len(rows) > 1 else rows[0] |
|
|
individual_heatmaps_path = "/tmp/individual_heatmaps.png" |
|
|
cv2.imwrite(individual_heatmaps_path, individual_grid) |
|
|
else: |
|
|
individual_heatmaps_path = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
progress(0.9, desc="๐บ๏ธ Creating game-style radar view (Step 6/6)...") |
|
|
radar_path = "/tmp/radar_view_enhanced.png" |
|
|
try: |
|
|
if last_pitch_players_xy is not None: |
|
|
radar_frame = create_game_style_radar( |
|
|
pitch_ball_xy=cleaned_path[-1] |
|
|
if cleaned_path |
|
|
else np.empty((0, 2)), |
|
|
pitch_players_xy=last_pitch_players_xy, |
|
|
players_class_id=last_players_class_id, |
|
|
pitch_referees_xy=last_pitch_referees_xy, |
|
|
ball_path=cleaned_path, |
|
|
) |
|
|
cv2.imwrite(radar_path, radar_frame) |
|
|
else: |
|
|
radar_path = None |
|
|
except Exception as e: |
|
|
print(f"โ ๏ธ Radar view creation failed: {e}") |
|
|
radar_path = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
total_poss = sum(possession_time_team.values()) + 1e-6 |
|
|
|
|
|
player_stats_table = [] |
|
|
for team_id, player_ids in teams.items(): |
|
|
for pid in player_ids: |
|
|
stats = performance_tracker.get_player_stats(pid) |
|
|
poss_s = float(possession_time_player.get(pid, 0.0)) |
|
|
poss_pct = 100.0 * poss_s / total_poss if total_poss > 0 else 0.0 |
|
|
|
|
|
row = [ |
|
|
int(pid), |
|
|
int(stats["team_id"]), |
|
|
float(stats["total_distance_meters"]), |
|
|
float(stats["avg_velocity"]), |
|
|
float(stats["max_velocity"]), |
|
|
int(stats["frames_visible"]), |
|
|
int(stats["time_in_defensive_third"]), |
|
|
int(stats["time_in_middle_third"]), |
|
|
int(stats["time_in_attacking_third"]), |
|
|
poss_s, |
|
|
poss_pct, |
|
|
] |
|
|
player_stats_table.append(row) |
|
|
|
|
|
events_table = [] |
|
|
for ev in events: |
|
|
ev_type = ev.get("type", "") |
|
|
t_ev = float(ev.get("t", 0.0)) |
|
|
team_id = ev.get("team_id", None) |
|
|
from_tid = ev.get("from_tid", None) |
|
|
to_tid = ev.get("to_tid", None) |
|
|
extra = ev.get("extra", {}) or {} |
|
|
|
|
|
speed_kmh = float(extra.get("speed_kmh", 0.0)) |
|
|
ball_dist_m = float(extra.get("distance_m", extra.get("ball_travel_m", 0.0))) |
|
|
player_dist_m = float(extra.get("player_distance_m", 0.0)) |
|
|
|
|
|
if ev_type == "pass": |
|
|
desc = f"Pass #{from_tid} โ #{to_tid} (Team {team_id})" |
|
|
elif ev_type == "tackle": |
|
|
desc = ( |
|
|
f"Tackle: #{to_tid} wins ball from #{from_tid} " |
|
|
f"(Team {team_id})" |
|
|
) |
|
|
elif ev_type == "interception": |
|
|
desc = ( |
|
|
f"Interception: #{to_tid} intercepts #{from_tid} " |
|
|
f"(Team {team_id})" |
|
|
) |
|
|
elif ev_type == "shot": |
|
|
desc = ( |
|
|
f"Shot by #{from_tid} (Team {team_id}) at {speed_kmh:.1f} km/h" |
|
|
) |
|
|
elif ev_type == "clearance": |
|
|
desc = f"Clearance by #{from_tid} (Team {team_id})" |
|
|
else: |
|
|
desc = ev_type |
|
|
|
|
|
row = [ |
|
|
t_ev, |
|
|
ev_type, |
|
|
team_id, |
|
|
from_tid, |
|
|
to_tid, |
|
|
speed_kmh, |
|
|
ball_dist_m, |
|
|
player_dist_m, |
|
|
desc, |
|
|
] |
|
|
events_table.append(row) |
|
|
|
|
|
events_json_path = "/tmp/events.json" |
|
|
with open(events_json_path, "w", encoding="utf-8") as f: |
|
|
json.dump(events, f, indent=2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
progress(0.95, desc="๐ Generating summary report...") |
|
|
|
|
|
summary_lines = ["โ
**Analysis Complete!**\n"] |
|
|
summary_lines.append("**Video Statistics:**") |
|
|
summary_lines.append(f"- Total Frames Processed: {frame_count}") |
|
|
summary_lines.append(f"- Video Resolution: {width}x{height}") |
|
|
summary_lines.append(f"- Frame Rate: {fps:.2f} fps") |
|
|
summary_lines.append( |
|
|
f"- Ball Trajectory Points: " |
|
|
f"{len([p for p in cleaned_path if len(p) > 0])}\n" |
|
|
) |
|
|
|
|
|
for team_id in [0, 1]: |
|
|
if team_id not in teams: |
|
|
continue |
|
|
|
|
|
team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)" |
|
|
summary_lines.append(f"\n**{team_name}:**") |
|
|
summary_lines.append(f"- Players Tracked: {len(teams[team_id])}") |
|
|
|
|
|
total_dist = sum( |
|
|
performance_tracker.get_player_stats(pid)["total_distance_meters"] |
|
|
for pid in teams[team_id] |
|
|
) |
|
|
avg_dist = total_dist / len(teams[team_id]) if len(teams[team_id]) > 0 else 0 |
|
|
summary_lines.append(f"- Team Total Distance: {total_dist:.1f} m") |
|
|
summary_lines.append( |
|
|
f"- Average Distance per Player: {avg_dist:.1f} m" |
|
|
) |
|
|
|
|
|
|
|
|
player_distances = [ |
|
|
(pid, performance_tracker.get_player_stats(pid)["total_distance_meters"]) |
|
|
for pid in teams[team_id] |
|
|
] |
|
|
player_distances.sort(key=lambda x: x[1], reverse=True) |
|
|
|
|
|
summary_lines.append("\n **Top 3 Performers:**") |
|
|
for i, (pid, dist) in enumerate(player_distances[:3], 1): |
|
|
stats = performance_tracker.get_player_stats(pid) |
|
|
summary_lines.append( |
|
|
f" {i}. Player #{pid}: {dist:.1f} m, " |
|
|
f"Avg: {stats['avg_velocity']:.2f} km/h, " |
|
|
f"Max: {stats['max_velocity']:.2f} km/h" |
|
|
) |
|
|
|
|
|
|
|
|
summary_lines.append("\n**Team Possession:**") |
|
|
for team_id in sorted(possession_time_team.keys()): |
|
|
t_sec = possession_time_team[team_id] |
|
|
pct = 100.0 * t_sec / total_poss if total_poss > 0 else 0.0 |
|
|
summary_lines.append(f"- Team {team_id}: {t_sec:.1f} s ({pct:.1f}%)") |
|
|
|
|
|
summary_lines.append("\n**Pipeline Steps Completed:**") |
|
|
summary_lines.append("โ
1. Player crop collection") |
|
|
summary_lines.append("โ
2. Team classifier training") |
|
|
summary_lines.append("โ
3. Video processing with tracking & events") |
|
|
summary_lines.append("โ
4. Ball trajectory cleaning") |
|
|
summary_lines.append("โ
5. Performance analytics generation") |
|
|
summary_lines.append("โ
6. Visualization creation") |
|
|
|
|
|
summary_msg = "\n".join(summary_lines) |
|
|
|
|
|
progress(1.0, desc="โ
Analysis Complete!") |
|
|
|
|
|
|
|
|
return ( |
|
|
output_path, |
|
|
comparison_fig, |
|
|
team_heatmaps_path, |
|
|
individual_heatmaps_path, |
|
|
radar_path, |
|
|
summary_msg, |
|
|
player_stats_table, |
|
|
events_table, |
|
|
events_json_path, |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"โ Error: {str(e)}" |
|
|
print(error_msg) |
|
|
import traceback |
|
|
|
|
|
traceback.print_exc() |
|
|
|
|
|
return ( |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
error_msg, |
|
|
[], |
|
|
[], |
|
|
None, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_pipeline(video) -> Tuple: |
|
|
""" |
|
|
Gradio wrapper: accept the raw video object from gr.Video and |
|
|
convert it to a filesystem path for analyze_football_video(). |
|
|
""" |
|
|
if video is None: |
|
|
return ( |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
"โ Please upload a video file.", |
|
|
[], |
|
|
[], |
|
|
None, |
|
|
) |
|
|
|
|
|
|
|
|
if isinstance(video, dict): |
|
|
video_path = ( |
|
|
video.get("path") |
|
|
or video.get("name") |
|
|
or video.get("filename") |
|
|
) |
|
|
else: |
|
|
|
|
|
video_path = str(video) |
|
|
|
|
|
if not video_path: |
|
|
return ( |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
None, |
|
|
"โ Could not resolve video file path from upload.", |
|
|
[], |
|
|
[], |
|
|
None, |
|
|
) |
|
|
|
|
|
return analyze_football_video(video_path) |
|
|
|
|
|
|
|
|
# Gradio user interface.
# Inside each ``with`` context, components appear on the page in the order
# they are created, so the statement order below defines the layout.
with gr.Blocks(title="โฝ Football Performance Analyzer", theme=gr.themes.Soft()) as iface:
    # Page header describing the pipeline stages.
    gr.Markdown(
        """
        # โฝ Advanced Football Video Analyzer
        ### Complete Pipeline Implementation

        This application:
        1. **Player Detection** - Collect player crops using Roboflow
        2. **Team Classification** - Train SigLIP-based team classifier
        3. **Persistent Tracking** - ByteTrack with stable ID assignment
        4. **Field Transformation** - Project players onto pitch coordinates
        5. **Ball Trajectory** - Track and clean ball path with outlier removal
        6. **Performance Analytics** - Heatmaps, stats, possession, and event detection

        Upload a football match video to get comprehensive performance analytics!
        """
    )

    with gr.Row():
        # Sole input of the app; its value is passed to run_pipeline.
        video_input = gr.Video(label="๐ค Upload Football Video")

    # Triggers the full analysis (wired to run_pipeline below).
    analyze_btn = gr.Button("๐ Start Analysis Pipeline", variant="primary", size="lg")

    with gr.Row():
        # Receives the markdown-style summary text (6th pipeline output).
        status_output = gr.Textbox(label="๐ Analysis Summary & Statistics", lines=25)

    # One tab per visual/tabular output of the pipeline.
    with gr.Tabs():
        with gr.Tab("๐น Annotated Video"):
            gr.Markdown(
                "### Full video with player tracking, team colors, ball detection, and events overlay"
            )
            video_output = gr.Video(label="Processed Video")

        with gr.Tab("๐ Performance Comparison"):
            gr.Markdown("### Interactive charts comparing player performance metrics")
            comparison_output = gr.Plot(label="Team Performance Metrics")

        with gr.Tab("๐บ๏ธ Team Heatmaps"):
            gr.Markdown("### Combined activity heatmaps showing team positioning")
            team_heatmaps_output = gr.Image(label="Team Activity Heatmaps")

        with gr.Tab("๐ค Individual Heatmaps"):
            gr.Markdown("### Top 6 players with detailed activity analysis")
            individual_heatmaps_output = gr.Image(label="Top Players Heatmaps")

        with gr.Tab("๐ฎ Game Radar View"):
            gr.Markdown("### Game-style tactical view with ball trail")
            radar_output = gr.Image(label="Tactical Radar View")

        with gr.Tab("๐ Player Stats"):
            gr.Markdown("### Per-player totals: distance, speeds, zones, possession")
            # PLAYER_STATS_HEADERS is defined elsewhere in this file (not shown
            # here). It should name the 11 fields of each player row built by
            # analyze_football_video: id, team, distance, avg/max velocity,
            # frames visible, three zone counts, possession seconds and %.
            player_stats_output = gr.Dataframe(
                headers=PLAYER_STATS_HEADERS,
                col_count=len(PLAYER_STATS_HEADERS),
                row_count=0,
                interactive=False,
            )

        with gr.Tab("โฑ๏ธ Event Timeline"):
            gr.Markdown(
                "### Detected passes, tackles, interceptions, shots, clearances"
            )
            # EVENT_HEADERS is defined elsewhere in this file (not shown here).
            # It should name the 9 fields of each event row: time, type, team,
            # from/to tracker ids, speed, ball distance, player distance,
            # description.
            events_output = gr.Dataframe(
                headers=EVENT_HEADERS,
                col_count=len(EVENT_HEADERS),
                row_count=0,
                interactive=False,
            )
            # Download link for the events JSON written by the pipeline.
            events_json_output = gr.File(
                label="Download events JSON",
                file_types=[".json"],
            )

    # run_pipeline returns a 9-tuple. The order of ``outputs`` must match it:
    # (video path, comparison figure, team heatmaps path, individual heatmaps
    # path, radar path, summary text, player rows, event rows, events JSON).
    analyze_btn.click(
        fn=run_pipeline,
        inputs=[video_input],
        outputs=[
            video_output,
            comparison_output,
            team_heatmaps_output,
            individual_heatmaps_output,
            radar_output,
            status_output,
            player_stats_output,
            events_output,
            events_json_output,
        ],
    )

    # Static footer describing models, metrics and event types.
    gr.Markdown(
        """
        ---
        ### ๐ง Technical Details:

        **Detection Models:**
        - Player/Ball/Referee Detection: `football-players-detection-3zvbc/11`
        - Field Keypoint Detection: `football-field-detection-f07vi/14`

        **Tracking & Classification:**
        - ByteTrack for persistent player IDs
        - SigLIP embeddings for team classification
        - Majority voting for stable team assignments

        **Performance Metrics:**
        - Distance covered (meters)
        - Average & maximum speed (km/h)
        - Zone activity (defensive/middle/attacking thirds)
        - Position heatmaps with Gaussian smoothing
        - Possession per player & per team

        **Ball Tracking:**
        - Field homography transformation
        - Outlier removal (500 cm threshold)
        - Transformation matrix smoothing (5-frame window)

        **Events:**
        - Passes, tackles, interceptions, shots, clearances
        - Event banner overlay in video
        - Full event list downloadable as JSON
        """
    )


# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    iface.launch()
|
|
|