Simon9 committed on
Commit
e1b8afc
·
verified ·
1 Parent(s): b9cd6da

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import deque
3
+ from typing import Optional, List
4
+ from io import BytesIO
5
+ import base64
6
+
7
+ import cv2
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+ import supervision as sv
15
+ from inference_sdk import get_model # Modified import
16
+ from sports.common.team import TeamClassifier
17
+ from sports.common.view import ViewTransformer
18
+ from sports.annotators.soccer import draw_pitch, draw_points_on_pitch
19
+ from sports.configs.soccer import SoccerPitchConfiguration
20
+
21
+ import gradio as gr
22
+ import plotly.graph_objects as go
23
+ from transformers import AutoProcessor, SiglipVisionModel
24
+ from more_itertools import chunked
25
+ from sklearn.cluster import KMeans
26
+ import umap
27
+
28
# ==============================================
# Environment variables
# ==============================================
# NOTE(review): both tokens may be None if unset in the environment; the
# Roboflow key is passed straight to get_model below.
HF_TOKEN = os.environ.get("HF_TOKEN")
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")

# Run torch models on GPU when one is available, otherwise CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# ==============================================
# Load Detection Models
# ==============================================
# Roboflow-hosted model detecting ball / goalkeeper / player / referee.
PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
PLAYER_DETECTION_MODEL = get_model(model_id=PLAYER_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)

# Roboflow-hosted model detecting pitch keypoints, used to build the
# frame-to-pitch homography in analyze_football_video.
FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
FIELD_DETECTION_MODEL = get_model(model_id=FIELD_DETECTION_MODEL_ID, api_key=ROBOFLOW_API_KEY)

# NOTE(review): team_classifier is only ever .predict()-ed in this file and
# appears never to be fitted — confirm whether TeamClassifier needs .fit().
team_classifier = TeamClassifier(device=DEVICE)
CONFIG = SoccerPitchConfiguration()

# ==============================================
# Load SigLIP Model
# ==============================================
# SigLIP vision tower used to embed player crops for the 3D UMAP plot.
SIGLIP_MODEL_PATH = 'google/siglip-base-patch16-224'
EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(DEVICE)
EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)

# ==============================================
# Helper Functions
# ==============================================
59
+ def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
60
+ goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
61
+ players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
62
+ team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
63
+ team_1_centroid = players_xy[players.class_id == 1].mean(axis=0)
64
+ return np.array([0 if np.linalg.norm(gk - team_0_centroid) < np.linalg.norm(gk - team_1_centroid) else 1 for gk in goalkeepers_xy])
65
+
66
+ def pil_image_to_data_uri(image: Image.Image) -> str:
67
+ buffered = BytesIO()
68
+ image.save(buffered, format="PNG")
69
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
70
+ return f"data:image/png;base64,{img_str}"
71
+
72
+ def create_umap_3d_plot(crops: List[Image.Image]):
73
+ BATCH_SIZE = 32
74
+ crops = [sv.cv2_to_pillow(crop) for crop in crops]
75
+ batches = list(chunked(crops, BATCH_SIZE))
76
+ data = []
77
+ with torch.no_grad():
78
+ for batch in tqdm(batches, desc='embedding extraction'):
79
+ inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(DEVICE)
80
+ outputs = EMBEDDINGS_MODEL(**inputs)
81
+ embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
82
+ data.append(embeddings)
83
+ data = np.concatenate(data)
84
+
85
+ # UMAP and clustering
86
+ REDUCER = umap.UMAP(n_components=3)
87
+ CLUSTERING_MODEL = KMeans(n_clusters=2)
88
+ projections = REDUCER.fit_transform(data)
89
+ clusters = CLUSTERING_MODEL.fit_predict(projections)
90
+
91
+ # Prepare image data URIs
92
+ image_data_uris = {f"image_{i}": pil_image_to_data_uri(image) for i, image in enumerate(crops)}
93
+ image_ids = np.array([f"image_{i}" for i in range(len(crops))])
94
+
95
+ # Plotly 3D scatter
96
+ traces = []
97
+ for lbl in np.unique(clusters):
98
+ mask = clusters == lbl
99
+ trace = go.Scatter3d(
100
+ x=projections[mask][:,0],
101
+ y=projections[mask][:,1],
102
+ z=projections[mask][:,2],
103
+ mode='markers+text',
104
+ text=clusters[mask],
105
+ customdata=image_ids[mask],
106
+ name=str(lbl),
107
+ marker=dict(size=6),
108
+ hovertemplate="<b>Cluster: %{text}</b><br>Image ID: %{customdata}<extra></extra>"
109
+ )
110
+ traces.append(trace)
111
+
112
+ fig = go.Figure(data=traces)
113
+ fig.update_layout(width=800, height=800)
114
+ return fig, image_data_uris
115
+
116
+ # ==============================================
117
+ # Main Video Processing
118
+ # ==============================================
119
+ def analyze_football_video(video_path: str):
120
+ BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0,1,2,3
121
+ MAXLEN = 5
122
+ M = deque(maxlen=MAXLEN)
123
+ path_raw = []
124
+
125
+ # Annotators
126
+ ellipse_annotator = sv.EllipseAnnotator(color=sv.ColorPalette.from_hex(['#00BFFF','#FF1493','#FFD700']), thickness=2)
127
+ label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.from_hex(['#00BFFF','#FF1493','#FFD700']), text_color=sv.Color.from_hex('#000000'))
128
+ triangle_annotator = sv.TriangleAnnotator(color=sv.Color.from_hex('#FFD700'), base=20, height=17)
129
+ tracker = sv.ByteTrack()
130
+ tracker.reset()
131
+
132
+ cap = cv2.VideoCapture(video_path)
133
+ width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
134
+ fps = cap.get(cv2.CAP_PROP_FPS)
135
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
136
+ out = cv2.VideoWriter("/tmp/annotated_video.mp4", fourcc, fps, (width, height))
137
+
138
+ player_crops = []
139
+
140
+ while True:
141
+ ret, frame = cap.read()
142
+ if not ret: break
143
+
144
+ # Player/ball/referee detection
145
+ result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
146
+ detections = sv.Detections.from_inference(result)
147
+ ball_detections = detections[detections.class_id==BALL_ID]
148
+ ball_detections.xyxy = sv.pad_boxes(ball_detections.xyxy, px=10)
149
+ all_detections = detections[detections.class_id != BALL_ID].with_nms(threshold=0.5, class_agnostic=True)
150
+ all_detections = tracker.update_with_detections(all_detections)
151
+
152
+ goalkeepers_detections = all_detections[all_detections.class_id==GOALKEEPER_ID]
153
+ players_detections = all_detections[all_detections.class_id==PLAYER_ID]
154
+ referees_detections = all_detections[all_detections.class_id==REFEREE_ID]
155
+
156
+ # Team classification
157
+ if len(players_detections.xyxy) > 0:
158
+ crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
159
+ players_detections.class_id = team_classifier.predict(crops)
160
+ player_crops += crops
161
+
162
+ goalkeepers_detections.class_id = resolve_goalkeepers_team_id(players_detections, goalkeepers_detections)
163
+ referees_detections.class_id -= 1
164
+ all_detections = sv.Detections.merge([players_detections, goalkeepers_detections, referees_detections])
165
+ labels = [f"#{tid}" for tid in all_detections.tracker_id]
166
+ all_detections.class_id = all_detections.class_id.astype(int)
167
+
168
+ # Annotate frame
169
+ annotated_frame = frame.copy()
170
+ annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
171
+ annotated_frame = label_annotator.annotate(annotated_frame, all_detections, labels=labels)
172
+ annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
173
+ out.write(annotated_frame)
174
+
175
+ # Field detection and projection
176
+ result_field = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
177
+ key_points = sv.KeyPoints.from_inference(result_field)
178
+ filter = key_points.confidence[0] > 0.5
179
+ frame_ref_points = key_points.xy[0][filter]
180
+ pitch_ref_points = np.array(CONFIG.vertices)[filter]
181
+ transformer = ViewTransformer(source=frame_ref_points, target=pitch_ref_points)
182
+ M.append(transformer.m)
183
+ transformer.m = np.mean(np.array(M), axis=0)
184
+
185
+ # Ball & players projected
186
+ pitch_ball_xy = transformer.transform_points(ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER))
187
+ pitch_players_xy = transformer.transform_points(players_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER))
188
+
189
+ cap.release()
190
+ out.release()
191
+
192
+ # Create UMAP 3D plot
193
+ umap_fig, image_data_uris = create_umap_3d_plot(player_crops)
194
+
195
+ return "/tmp/annotated_video.mp4", umap_fig
196
+
197
+ # ==============================================
198
+ # Gradio Interface
199
+ # ==============================================
200
+ iface = gr.Interface(
201
+ fn=analyze_football_video,
202
+ inputs=gr.Video(label="Upload Football Video"),
203
+ outputs=[gr.Video(label="Annotated Video"), gr.Plot(label="3D Player Embeddings")],
204
+ title="Football Video Analyzer with SigLIP Player Embeddings"
205
+ )