Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# app.py — football broadcast analysis: detect/track players, assign teams,
# annotate the video, and visualize SigLIP player embeddings in 3-D.

import os

from collections import deque
from typing import Optional, List
from io import BytesIO
import base64

import cv2
import numpy as np
from PIL import Image

import torch
from tqdm import tqdm

import supervision as sv
from inference_sdk import get_model  # Modified import
from sports.common.team import TeamClassifier
from sports.common.view import ViewTransformer
from sports.annotators.soccer import draw_pitch, draw_points_on_pitch
from sports.configs.soccer import SoccerPitchConfiguration

import gradio as gr
import plotly.graph_objects as go
from transformers import AutoProcessor, SiglipVisionModel
from more_itertools import chunked
from sklearn.cluster import KMeans
import umap

# ==============================================
# Environment variables
# ==============================================
# NOTE(review): HF_TOKEN is read but never referenced elsewhere in this file —
# presumably consumed implicitly by hub downloads; confirm before removing.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Required by the two Roboflow models loaded below.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# ==============================================
# Load Detection Models
# ==============================================
# Both models are fetched at import time: a missing or invalid
# ROBOFLOW_API_KEY fails here, not on the first request.
PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
PLAYER_DETECTION_MODEL = get_model(model_id=PLAYER_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)

FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
FIELD_DETECTION_MODEL = get_model(model_id=FIELD_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)

# NOTE(review): team_classifier.predict() is called in the video loop but the
# classifier is never fit in this file — verify TeamClassifier handles that.
team_classifier = TeamClassifier(device=DEVICE)
# Canonical pitch geometry; CONFIG.vertices feeds the frame->pitch homography.
CONFIG = SoccerPitchConfiguration()

# ==============================================
# Load SigLIP Model
# ==============================================
SIGLIP_MODEL_PATH = 'google/siglip-base-patch16-224'
EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(DEVICE)
EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)

# ==============================================
# Helper Functions
# ==============================================
def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
    """Assign each goalkeeper to the team whose player centroid is closest.

    Uses the bottom-center anchor of every box; team membership of players is
    read from ``players.class_id`` (0 or 1).

    Returns:
        Array of 0/1 team ids, one per goalkeeper.
    """
    anchor = sv.Position.BOTTOM_CENTER
    gk_points = goalkeepers.get_anchors_coordinates(anchor)
    player_points = players.get_anchors_coordinates(anchor)

    # Centroid of each team's players on the image plane.
    centroid_0 = player_points[players.class_id == 0].mean(axis=0)
    centroid_1 = player_points[players.class_id == 1].mean(axis=0)

    assignments = []
    for point in gk_points:
        dist_0 = np.linalg.norm(point - centroid_0)
        dist_1 = np.linalg.norm(point - centroid_1)
        assignments.append(0 if dist_0 < dist_1 else 1)
    return np.array(assignments)
def pil_image_to_data_uri(image: Image.Image) -> str:
    """Serialize *image* as a PNG data URI (for embedding in Plotly/HTML)."""
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
def create_umap_3d_plot(crops: List[np.ndarray]):
    """Embed player crops with SigLIP, project to 3-D with UMAP, cluster with KMeans.

    Args:
        crops: OpenCV (BGR ndarray) player crops. The function converts them to
            PIL images itself — the previous ``List[Image.Image]`` annotation
            was wrong, as the first statement calls ``sv.cv2_to_pillow``.

    Returns:
        Tuple of (plotly 3-D scatter Figure, dict mapping ``image_<i>`` to a
        PNG data URI of the corresponding crop).
    """
    # Guard: a video with zero detected players would otherwise crash on
    # np.concatenate([]) below.
    if not crops:
        return go.Figure(), {}

    BATCH_SIZE = 32
    crops = [sv.cv2_to_pillow(crop) for crop in crops]
    batches = list(chunked(crops, BATCH_SIZE))
    data = []
    with torch.no_grad():
        for batch in tqdm(batches, desc='embedding extraction'):
            inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(DEVICE)
            outputs = EMBEDDINGS_MODEL(**inputs)
            # Mean-pool the patch tokens into one embedding vector per crop.
            embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
            data.append(embeddings)
    data = np.concatenate(data)

    # UMAP and clustering — fitted per call, so lowercase local names
    # (the former REDUCER / CLUSTERING_MODEL were not module constants).
    reducer = umap.UMAP(n_components=3)
    clustering_model = KMeans(n_clusters=2)
    projections = reducer.fit_transform(data)
    clusters = clustering_model.fit_predict(projections)

    # Prepare image data URIs (hover metadata keyed by synthetic image ids).
    image_data_uris = {f"image_{i}": pil_image_to_data_uri(image) for i, image in enumerate(crops)}
    image_ids = np.array([f"image_{i}" for i in range(len(crops))])

    # Plotly 3D scatter — one trace per cluster so each gets a legend entry.
    traces = []
    for lbl in np.unique(clusters):
        mask = clusters == lbl
        trace = go.Scatter3d(
            x=projections[mask][:, 0],
            y=projections[mask][:, 1],
            z=projections[mask][:, 2],
            mode='markers+text',
            text=clusters[mask],
            customdata=image_ids[mask],
            name=str(lbl),
            marker=dict(size=6),
            hovertemplate="<b>Cluster: %{text}</b><br>Image ID: %{customdata}<extra></extra>"
        )
        traces.append(trace)

    fig = go.Figure(data=traces)
    fig.update_layout(width=800, height=800)
    return fig, image_data_uris
# ==============================================
# Main Video Processing
# ==============================================
def analyze_football_video(video_path: str):
    """Run the full pipeline over a video file.

    Per frame: detect ball/players/goalkeepers/referees, track with ByteTrack,
    assign players to teams, annotate the frame, and estimate a smoothed
    frame-to-pitch homography. Writes the annotated video to /tmp and builds a
    3-D UMAP plot of the collected player crops.

    Args:
        video_path: path to the uploaded video (as supplied by gr.Video).

    Returns:
        Tuple of (path to the annotated mp4, plotly Figure of embeddings).
    """
    # Class ids emitted by PLAYER_DETECTION_MODEL.
    BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3
    MAXLEN = 5
    M = deque(maxlen=MAXLEN)  # rolling window of homographies for smoothing
    path_raw = []  # NOTE(review): never appended to — dead state; keep or wire up

    # Annotators — palette order: team 0, team 1, referee.
    ellipse_annotator = sv.EllipseAnnotator(color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']), thickness=2)
    label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']), text_color=sv.Color.from_hex('#000000'))
    triangle_annotator = sv.TriangleAnnotator(color=sv.Color.from_hex('#FFD700'), base=20, height=17)
    tracker = sv.ByteTrack()
    tracker.reset()

    cap = cv2.VideoCapture(video_path)
    width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter("/tmp/annotated_video.mp4", fourcc, fps, (width, height))

    player_crops = []  # BGR crops of every classified player, for the UMAP plot

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Player/ball/referee detection
        result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
        detections = sv.Detections.from_inference(result)
        ball_detections = detections[detections.class_id == BALL_ID]
        # Pad the tiny ball box so the triangle marker stays visible.
        ball_detections.xyxy = sv.pad_boxes(ball_detections.xyxy, px=10)
        all_detections = detections[detections.class_id != BALL_ID].with_nms(threshold=0.5, class_agnostic=True)
        all_detections = tracker.update_with_detections(all_detections)

        goalkeepers_detections = all_detections[all_detections.class_id == GOALKEEPER_ID]
        players_detections = all_detections[all_detections.class_id == PLAYER_ID]
        referees_detections = all_detections[all_detections.class_id == REFEREE_ID]

        # Team classification (skipped when no players are visible this frame)
        if len(players_detections.xyxy) > 0:
            crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
            players_detections.class_id = team_classifier.predict(crops)
            player_crops += crops

        # Goalkeepers inherit the team whose player centroid is nearest.
        goalkeepers_detections.class_id = resolve_goalkeepers_team_id(players_detections, goalkeepers_detections)
        referees_detections.class_id -= 1  # remap referee id 3 onto palette slot 2
        all_detections = sv.Detections.merge([players_detections, goalkeepers_detections, referees_detections])
        labels = [f"#{tid}" for tid in all_detections.tracker_id]
        all_detections.class_id = all_detections.class_id.astype(int)

        # Annotate frame
        annotated_frame = frame.copy()
        annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
        annotated_frame = label_annotator.annotate(annotated_frame, all_detections, labels=labels)
        annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
        out.write(annotated_frame)

        # Field detection and projection
        result_field = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
        key_points = sv.KeyPoints.from_inference(result_field)
        # Renamed from `filter`, which shadowed the builtin.
        confident = key_points.confidence[0] > 0.5
        frame_ref_points = key_points.xy[0][confident]
        pitch_ref_points = np.array(CONFIG.vertices)[confident]
        transformer = ViewTransformer(source=frame_ref_points, target=pitch_ref_points)
        # Average the homography over the last MAXLEN frames to reduce jitter.
        M.append(transformer.m)
        transformer.m = np.mean(np.array(M), axis=0)

        # Ball & players projected onto pitch coordinates.
        # NOTE(review): these projections are currently unused (no pitch view is
        # rendered or returned) — either wire them into an output or drop the
        # field-detection pass to halve per-frame inference cost.
        pitch_ball_xy = transformer.transform_points(ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER))
        pitch_players_xy = transformer.transform_points(players_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER))

    cap.release()
    out.release()

    # Create UMAP 3D plot
    umap_fig, image_data_uris = create_umap_3d_plot(player_crops)

    return "/tmp/annotated_video.mp4", umap_fig
# ==============================================
# Gradio Interface
# ==============================================
iface = gr.Interface(
    fn=analyze_football_video,
    inputs=gr.Video(label="Upload Football Video"),
    outputs=[gr.Video(label="Annotated Video"), gr.Plot(label="3D Player Embeddings")],
    title="Football Video Analyzer with SigLIP Player Embeddings"
)

# BUG FIX: the interface was constructed but never launched, so running
# `python app.py` (as Hugging Face Spaces does) exited without serving the app.
if __name__ == "__main__":
    iface.launch()