Simon9 commited on
Commit
3d11bd1
ยท
verified ยท
1 Parent(s): 655c5fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +759 -403
app.py CHANGED
@@ -1,463 +1,819 @@
1
  import os
 
 
 
 
2
 
3
- # Suppress inference model warnings (must be at the very top)
4
- os.environ["CORE_MODEL_SAM_ENABLED"] = "False"
5
- os.environ["CORE_MODEL_SAM2_ENABLED"] = "False"
6
- os.environ["CORE_MODEL_SAM3_ENABLED"] = "False"
7
- os.environ["CORE_MODEL_GAZE_ENABLED"] = "False"
8
- os.environ["CORE_MODEL_GROUNDINGDINO_ENABLED"] = "False"
9
- os.environ["CORE_MODEL_YOLO_WORLD_ENABLED"] = "False"
 
 
 
 
 
10
 
11
  import gradio as gr
12
- import uuid
13
- import shutil
14
- import json
15
- import time
16
- from pathlib import Path
17
- from PIL import Image
18
 
19
- # Import your existing pipeline
20
- from pipeline_full import run_full_pipeline
21
 
22
- BASE_RESULTS_DIR = "jobs"
23
- Path(BASE_RESULTS_DIR).mkdir(exist_ok=True)
 
 
 
24
 
 
 
25
 
26
- def monitor_progress(job_dir):
27
- """
28
- Monitor the status.json file and yield progress updates
29
- """
30
- status_file = os.path.join(job_dir, "status.json")
31
-
32
- last_progress = 0
33
- last_stage = ""
34
-
35
- # Keep checking until done or error
36
- max_wait = 300 # 5 minutes max
37
- start_time = time.time()
38
-
39
- while (time.time() - start_time) < max_wait:
40
- if os.path.exists(status_file):
41
- try:
42
- with open(status_file, 'r') as f:
43
- status = json.load(f)
44
-
45
- stage = status.get('stage', '')
46
- progress = status.get('progress', 0)
47
- message = status.get('message', '')
48
-
49
- # Only yield if something changed
50
- if stage != last_stage or progress != last_progress:
51
- last_stage = stage
52
- last_progress = progress
53
-
54
- # Map stages to friendly names with emojis
55
- stage_map = {
56
- 'initializing': '๐Ÿš€ Initializing',
57
- 'siglip': '๐Ÿค– SigLIP Clustering',
58
- 'team_classifier': '๐Ÿ‘• Team Classification',
59
- 'basic_frames': '๐ŸŽฏ Basic Detection',
60
- 'advanced_views': '๐ŸŽจ Tactical Views',
61
- 'ball_path': 'โšฝ Ball Tracking',
62
- 'stats': '๐Ÿ“Š Statistics',
63
- 'done': 'โœ… Complete',
64
- 'error': 'โŒ Error'
65
- }
66
-
67
- friendly_stage = stage_map.get(stage, stage)
68
- percentage = int(progress * 100)
69
-
70
- status_msg = f"**[{percentage}%]** {friendly_stage}"
71
- if message:
72
- status_msg += f"\n\n_{message}_"
73
-
74
- yield status_msg
75
-
76
- # Stop if done or error
77
- if stage in ['done', 'error']:
78
- break
79
-
80
- except Exception as e:
81
- print(f"Error reading status: {e}")
82
-
83
- time.sleep(1) # Poll every second
84
-
85
- # Final status
86
- yield "โœ… **Processing complete!** Loading results..."
87
-
88
-
89
- def analyze_video(video_file):
90
- """
91
- Main analysis function with real-time progress monitoring
92
- """
93
- # Initial empty outputs
94
- empty_outputs = (None, None, None, None, None, None, None, None)
95
 
96
- if video_file is None:
97
- yield (*empty_outputs, "โŒ **Please upload a video file**")
98
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- try:
101
- # Create job directory
102
- job_id = str(uuid.uuid4())
103
- job_dir = os.path.join(BASE_RESULTS_DIR, job_id)
104
- os.makedirs(job_dir, exist_ok=True)
105
-
106
- yield (*empty_outputs, f"๐Ÿ”ง **[0%]** Setting up job `{job_id[:8]}...`")
107
-
108
- # Copy video file
109
- video_input_path = video_file
110
 
111
- if not os.path.exists(video_input_path):
112
- yield (*empty_outputs, f"โŒ **Video file not found:** `{video_input_path}`")
113
- return
114
 
115
- video_filename = f"input_{uuid.uuid4().hex[:8]}.mp4"
116
- video_path = os.path.join(job_dir, video_filename)
117
- shutil.copy2(video_input_path, video_path)
118
 
119
- yield (*empty_outputs, f"๐Ÿ“ **[5%]** Video copied successfully")
 
 
 
 
 
120
 
121
- # Start pipeline in separate thread
122
- import threading
123
 
124
- pipeline_error = []
125
- pipeline_complete = []
 
126
 
127
- def run_pipeline_thread():
128
- try:
129
- result = run_full_pipeline(video_path, job_dir)
130
- pipeline_complete.append(result)
131
- except Exception as e:
132
- pipeline_error.append(str(e))
133
-
134
- thread = threading.Thread(target=run_pipeline_thread, daemon=True)
135
- thread.start()
136
-
137
- # Monitor progress while pipeline runs
138
- for progress_msg in monitor_progress(job_dir):
139
- yield (*empty_outputs, progress_msg)
140
-
141
- # Wait for thread to complete
142
- thread.join(timeout=10)
143
-
144
- # Check for errors
145
- if pipeline_error:
146
- error_msg = f"""โŒ **Error during analysis:**
147
-
148
- ```
149
- {pipeline_error[0]}
150
- ```
151
-
152
- **Troubleshooting:**
153
- - Ensure ROBOFLOW_API_KEY is set in Space secrets
154
- - Try a shorter video clip
155
- - Check container logs for details
156
- """
157
- yield (*empty_outputs, error_msg)
158
- return
159
 
160
- # Load results
161
- result_path = os.path.join(job_dir, "result.json")
162
 
163
- if not os.path.exists(result_path):
164
- yield (*empty_outputs, "โš ๏ธ **Result file not found.** Pipeline may still be processing.")
165
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- with open(result_path, 'r') as f:
168
- result = json.load(f)
169
-
170
- # Extract results
171
- basic = result["basic"]
172
- adv = result["advanced"]
173
- ball = result["ball"]
174
- stats = result["stats"]
175
- siglip_html = result["siglip_html"]
176
-
177
- # Load images
178
- def load_image(path, name="image"):
179
- if not path or not os.path.exists(path):
180
- print(f"โš ๏ธ {name} not found at {path}")
181
- return None
182
- try:
183
- img = Image.open(path)
184
- if img.mode != 'RGB':
185
- img = img.convert('RGB')
186
- print(f"โœ… Loaded {name}: {img.size}")
187
- return img
188
- except Exception as e:
189
- print(f"โŒ Error loading {name}: {e}")
190
- return None
191
-
192
- print("=" * 60)
193
- print("LOADING RESULT IMAGES...")
194
-
195
- raw_frame = load_image(basic["raw_frame"], "Raw frame")
196
- boxes_labels = load_image(basic["boxes_labels"], "Boxes/labels")
197
- ball_players = load_image(basic["ball_players"], "Ball/players")
198
-
199
- frame_advanced = load_image(adv["frame_advanced"], "Advanced frame")
200
- radar = load_image(adv["radar"], "Radar")
201
- voronoi = load_image(adv["voronoi"], "Voronoi")
202
- voronoi_blended = load_image(adv["voronoi_blended"], "Voronoi blended")
203
-
204
- ball_path_cleaned = load_image(ball["ball_path_cleaned_img"], "Ball path")
205
-
206
- # Format stats
207
- stats_text = json.dumps(stats, indent=2)
208
-
209
- # Count loaded images
210
- images_loaded = sum([
211
- raw_frame is not None,
212
- boxes_labels is not None,
213
- ball_players is not None,
214
- frame_advanced is not None,
215
- radar is not None,
216
- voronoi_blended is not None,
217
- ball_path_cleaned is not None
218
- ])
219
-
220
- print(f"Images loaded: {images_loaded}/7")
221
- print("=" * 60)
222
-
223
- # Create success message
224
- clustering_link = ""
225
- if siglip_html and os.path.exists(siglip_html):
226
- rel_path = os.path.relpath(siglip_html, ".")
227
- clustering_link = f'\n\n๐Ÿ“Š <a href="file/{rel_path}" target="_blank">**View 3D Clustering Visualization**</a>'
228
-
229
- success_msg = f"""โœ… **[100%] Analysis Complete!**
230
-
231
- **Job ID:** `{job_id}`
232
-
233
- **Results Generated:**
234
- - โœ… Player detections ({images_loaded}/7 images loaded)
235
- - โœ… Team classifications
236
- - โœ… Tactical visualizations
237
- - โœ… Ball trajectory analysis
238
- - โœ… Match statistics
239
-
240
- {clustering_link}
241
-
242
- ---
243
- *Scroll through the tabs above to see all visualizations*
244
-
245
- **Note:** If images appear dark, try uploading a brighter video with good lighting.
246
- """
247
-
248
- # Final yield with all results
249
- yield (
250
- raw_frame,
251
- boxes_labels,
252
- ball_players,
253
- frame_advanced,
254
- radar,
255
- voronoi_blended,
256
- ball_path_cleaned,
257
- stats_text,
258
- success_msg
259
  )
260
 
261
- except Exception as e:
262
- import traceback
263
- error_detail = traceback.format_exc()
 
 
 
264
 
265
- error_msg = f"""โŒ **Unexpected Error:**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
- ```python
268
- {str(e)}
269
- ```
270
 
271
- **Debug Info:**
272
- - Video: `{video_file if video_file else 'None'}`
273
- - Job: `{job_dir if 'job_dir' in locals() else 'Not created'}`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- **Full Traceback:**
276
- ```
277
- {error_detail}
278
- ```
279
- """
280
- print("=" * 60)
281
- print("UNEXPECTED ERROR:")
282
- print(error_detail)
283
- print("=" * 60)
284
-
285
- yield (*empty_outputs, error_msg)
286
 
 
 
 
 
 
287
 
288
- # Create Gradio interface
289
- with gr.Blocks(
290
- title="โšฝ Afrigoals - Football Analytics",
291
- theme=gr.themes.Soft()
292
- ) as demo:
293
 
294
- gr.Markdown("""
295
- # โšฝ Afrigoals - Football Analytics Platform
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
- AI-powered match analysis using computer vision and machine learning.
 
 
 
 
 
 
 
298
 
299
- **Features:**
300
- - ๐ŸŽฏ Player detection and tracking
301
- - ๐Ÿ‘• Automatic team classification
302
- - โšฝ Ball path tracking
303
- - ๐ŸŽจ Tactical visualizations (radar, Voronoi)
304
- - ๐Ÿ“Š 3D clustering analysis
305
- - ๐Ÿ“ˆ Match statistics
 
 
 
306
 
307
- ---
308
- """)
 
 
 
 
 
 
309
 
310
- with gr.Row():
311
- with gr.Column(scale=1):
312
- video_input = gr.Video(
313
- label="๐Ÿ“น Upload Match Video",
314
- sources=["upload"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
- gr.Markdown("""
318
- **Supported formats:** MP4, AVI, MOV, MKV
 
319
 
320
- **Best results with:**
321
- - 30-60 second clips
322
- - Good lighting (daytime matches)
323
- - Clear view of the pitch
324
- - 720p or higher resolution
325
 
326
- **Progress updates appear below** ๐Ÿ‘‡
327
- """)
 
 
328
 
329
- analyze_btn = gr.Button(
330
- "๐Ÿ” Analyze Video",
331
- variant="primary",
332
- size="lg"
333
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
  with gr.Row():
336
- with gr.Column():
337
- gr.Markdown("### ๐Ÿ“Š Progress & Status")
338
- status_output = gr.Markdown(
339
- value="โณ **Ready** - Upload a video and click Analyze to start"
340
- )
341
 
342
- gr.Markdown("---")
343
- gr.Markdown("## ๐Ÿ“Š Analysis Results")
344
- gr.Markdown("*Results will appear here after analysis (typically 30-90 seconds)*")
345
 
346
  with gr.Tabs():
347
- with gr.Tab("๐ŸŽฏ Basic Detections"):
348
- gr.Markdown("### Player and Ball Detection Results")
349
- with gr.Row():
350
- with gr.Column():
351
- gr.Markdown("#### Raw Frame")
352
- raw_frame_output = gr.Image(
353
- label="Original Frame",
354
- type="pil"
355
- )
356
- with gr.Column():
357
- gr.Markdown("#### Detections with Boxes")
358
- boxes_output = gr.Image(
359
- label="Bounding Boxes + Labels",
360
- type="pil"
361
- )
362
- with gr.Row():
363
- with gr.Column():
364
- gr.Markdown("#### Ball vs Players")
365
- ball_players_output = gr.Image(
366
- label="Ball & Players Markers",
367
- type="pil"
368
- )
369
 
370
- with gr.Tab("๐ŸŽจ Advanced Views"):
371
- gr.Markdown("### Tactical Analysis")
372
- with gr.Row():
373
- with gr.Column():
374
- gr.Markdown("#### Annotated Frame with Teams")
375
- advanced_output = gr.Image(
376
- label="Teams + Player IDs",
377
- type="pil"
378
- )
379
- with gr.Column():
380
- gr.Markdown("#### Radar View")
381
- radar_output = gr.Image(
382
- label="Top-Down Tactical View",
383
- type="pil"
384
- )
385
- with gr.Row():
386
- with gr.Column():
387
- gr.Markdown("#### Voronoi Diagram - Space Control")
388
- voronoi_output = gr.Image(
389
- label="Territory Control",
390
- type="pil"
391
- )
392
 
393
- with gr.Tab("โšฝ Ball Analysis"):
394
- gr.Markdown("### Ball Movement Tracking")
395
- with gr.Row():
396
- with gr.Column():
397
- gr.Markdown("#### Ball Path (Cleaned)")
398
- ball_path_output = gr.Image(
399
- label="Ball Trajectory",
400
- type="pil"
401
- )
402
 
403
- with gr.Tab("๐Ÿ“ˆ Statistics"):
404
- gr.Markdown("### Match Statistics")
405
- stats_output = gr.Code(
406
- label="JSON Data",
407
- language="json",
408
- lines=25
409
- )
410
 
411
- # Wire up the button
412
  analyze_btn.click(
413
- fn=analyze_video,
414
  inputs=[video_input],
415
  outputs=[
416
- raw_frame_output,
417
- boxes_output,
418
- ball_players_output,
419
- advanced_output,
420
  radar_output,
421
- voronoi_output,
422
- ball_path_output,
423
- stats_output,
424
  status_output
425
  ]
426
  )
427
 
428
  gr.Markdown("""
429
  ---
430
- ## ๐Ÿ“ Information
431
-
432
- ### โฑ๏ธ Processing Time
433
- | Video Length | Est. Time |
434
- |--------------|-----------|
435
- | 30 seconds | 30-60s |
436
- | 1 minute | 60-120s |
437
- | 2 minutes | 2-4 min |
438
-
439
- ### ๐Ÿ”ง Technical Stack
440
- - **Detection**: Roboflow YOLO models
441
- - **Tracking**: ByteTrack algorithm
442
- - **Embeddings**: SigLIP
443
- - **Clustering**: UMAP + K-means
444
- - **Visualization**: Supervision, Plotly, OpenCV
445
-
446
- ### โš ๏ธ Requirements
447
- - ROBOFLOW_API_KEY must be set in Space secrets
448
- - GPU recommended for faster processing
449
- - Good lighting improves detection quality
450
 
451
- ---
452
- **Built with โค๏ธ for African Football Analytics**
 
 
 
453
  """)
454
 
455
-
456
- # Launch
457
  if __name__ == "__main__":
458
- demo.queue() # Enable queuing for progress
459
- demo.launch(
460
- server_name="0.0.0.0",
461
- server_port=7860,
462
- show_error=True
463
- )
 
1
  import os
2
+ from collections import deque, defaultdict
3
+ from typing import List, Tuple, Dict
4
+ from io import BytesIO
5
+ import base64
6
 
7
+ import cv2
8
+ import numpy as np
9
+ from PIL import Image
10
+ import torch
11
+ from tqdm import tqdm
12
+ from scipy.ndimage import gaussian_filter
13
+
14
+ import supervision as sv
15
+ from sports.common.team import TeamClassifier
16
+ from sports.common.view import ViewTransformer
17
+ from sports.annotators.soccer import draw_pitch, draw_points_on_pitch, draw_paths_on_pitch
18
+ from sports.configs.soccer import SoccerPitchConfiguration
19
 
20
  import gradio as gr
21
+ import plotly.graph_objects as go
22
+ from plotly.subplots import make_subplots
23
+ from transformers import AutoProcessor, SiglipVisionModel
24
+ from more_itertools import chunked
25
+ from sklearn.cluster import KMeans
26
+ import umap
27
 
28
+ from inference_sdk import InferenceHTTPClient
 
29
 
30
+ # ==============================================
31
+ # ENVIRONMENT VARIABLES
32
+ # ==============================================
33
+ HF_TOKEN = os.environ.get("HF_TOKEN")
34
+ ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
35
 
36
+ if not HF_TOKEN or not ROBOFLOW_API_KEY:
37
+ raise ValueError("โŒ HF_TOKEN and ROBOFLOW_API_KEY must be set as environment variables.")
38
 
39
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
+ print(f"๐Ÿ–ฅ๏ธ Using device: {DEVICE}")
41
+
42
+ # ==============================================
43
+ # ROBOFLOW INFERENCE CLIENT
44
+ # ==============================================
45
+ CLIENT = InferenceHTTPClient(
46
+ api_url="https://detect.roboflow.com",
47
+ api_key=ROBOFLOW_API_KEY
48
+ )
49
+
50
+ PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
51
+ FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
52
+
53
+ # ==============================================
54
+ # SIGLIP MODEL (Embeddings)
55
+ # ==============================================
56
+ SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"
57
+ EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH, token=HF_TOKEN).to(DEVICE)
58
+ EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH, token=HF_TOKEN)
59
+
60
+ # ==============================================
61
+ # TEAM CLASSIFIER & CONFIG
62
+ # ==============================================
63
+ CONFIG = SoccerPitchConfiguration()
64
+
65
+ # ==============================================
66
+ # PLAYER PERFORMANCE TRACKING
67
+ # ==============================================
68
+ class PlayerPerformanceTracker:
69
+ """Track individual player performance metrics and generate heatmaps"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ def __init__(self, pitch_config):
72
+ self.config = pitch_config
73
+ self.player_positions = defaultdict(list) # tracker_id -> list of (x, y, frame)
74
+ self.player_velocities = defaultdict(list) # tracker_id -> list of velocities
75
+ self.player_distances = defaultdict(float) # tracker_id -> total distance
76
+ self.player_team = {} # tracker_id -> team_id
77
+ self.player_stats = defaultdict(lambda: {
78
+ 'frames_visible': 0,
79
+ 'avg_velocity': 0,
80
+ 'max_velocity': 0,
81
+ 'time_in_attacking_third': 0,
82
+ 'time_in_defensive_third': 0,
83
+ 'time_in_middle_third': 0
84
+ })
85
+
86
+ def update(self, tracker_id: int, position: np.ndarray, team_id: int, frame: int):
87
+ """Update player position and calculate metrics"""
88
+ if len(position) != 2:
89
+ return
90
+
91
+ self.player_team[tracker_id] = team_id
92
+ self.player_positions[tracker_id].append((position[0], position[1], frame))
93
+ self.player_stats[tracker_id]['frames_visible'] += 1
94
+
95
+ # Calculate velocity if we have previous position
96
+ if len(self.player_positions[tracker_id]) > 1:
97
+ prev_pos = np.array(self.player_positions[tracker_id][-2][:2])
98
+ curr_pos = np.array(position)
99
+ distance = np.linalg.norm(curr_pos - prev_pos)
100
+ self.player_distances[tracker_id] += distance
101
+
102
+ # Velocity (assuming 30 fps)
103
+ velocity = distance * 30 # cm/s
104
+ self.player_velocities[tracker_id].append(velocity)
105
+
106
+ # Update velocity stats
107
+ if velocity > self.player_stats[tracker_id]['max_velocity']:
108
+ self.player_stats[tracker_id]['max_velocity'] = velocity
109
+
110
+ # Track position zones (thirds of the pitch)
111
+ pitch_length = self.config.length
112
+ if position[0] < pitch_length / 3:
113
+ self.player_stats[tracker_id]['time_in_defensive_third'] += 1
114
+ elif position[0] < 2 * pitch_length / 3:
115
+ self.player_stats[tracker_id]['time_in_middle_third'] += 1
116
+ else:
117
+ self.player_stats[tracker_id]['time_in_attacking_third'] += 1
118
 
119
+ def get_player_stats(self, tracker_id: int) -> dict:
120
+ """Get comprehensive stats for a player"""
121
+ stats = self.player_stats[tracker_id].copy()
 
 
 
 
 
 
 
122
 
123
+ if len(self.player_velocities[tracker_id]) > 0:
124
+ stats['avg_velocity'] = np.mean(self.player_velocities[tracker_id])
 
125
 
126
+ stats['total_distance'] = self.player_distances[tracker_id]
127
+ stats['total_distance_meters'] = self.player_distances[tracker_id] / 100 # Convert to meters
128
+ stats['team_id'] = self.player_team.get(tracker_id, -1)
129
 
130
+ return stats
131
+
132
+ def generate_heatmap(self, tracker_id: int, resolution: int = 100) -> np.ndarray:
133
+ """Generate heatmap for a specific player"""
134
+ if tracker_id not in self.player_positions or len(self.player_positions[tracker_id]) == 0:
135
+ return np.zeros((resolution, resolution))
136
 
137
+ positions = np.array([(x, y) for x, y, _ in self.player_positions[tracker_id]])
 
138
 
139
+ # Create 2D histogram
140
+ pitch_length = self.config.length
141
+ pitch_width = self.config.width
142
 
143
+ heatmap, xedges, yedges = np.histogram2d(
144
+ positions[:, 0], positions[:, 1],
145
+ bins=[resolution, resolution],
146
+ range=[[0, pitch_length], [0, pitch_width]]
147
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ # Apply Gaussian smoothing for better visualization
150
+ heatmap = gaussian_filter(heatmap, sigma=3)
151
 
152
+ return heatmap.T # Transpose for correct orientation
153
+
154
+ def get_all_players_by_team(self) -> Dict[int, List[int]]:
155
+ """Get all player IDs grouped by team"""
156
+ teams = defaultdict(list)
157
+ for tracker_id, team_id in self.player_team.items():
158
+ teams[team_id].append(tracker_id)
159
+ return teams
160
+
161
+
162
+ # ==============================================
163
+ # TRACKING MANAGER
164
+ # ==============================================
165
+ class PlayerTrackingManager:
166
+ """Manages persistent player tracking with team assignment stability"""
167
+
168
+ def __init__(self, max_history=10):
169
+ self.tracker_team_history: Dict[int, List[int]] = defaultdict(list)
170
+ self.max_history = max_history
171
+ self.active_trackers = set()
172
+
173
+ def update_team_assignment(self, tracker_id: int, team_id: int):
174
+ """Store team assignment history for each tracker"""
175
+ self.tracker_team_history[tracker_id].append(team_id)
176
+ if len(self.tracker_team_history[tracker_id]) > self.max_history:
177
+ self.tracker_team_history[tracker_id].pop(0)
178
+ self.active_trackers.add(tracker_id)
179
+
180
+ def get_stable_team_id(self, tracker_id: int, current_team_id: int) -> int:
181
+ """Get stable team ID using majority voting from history"""
182
+ if tracker_id not in self.tracker_team_history or len(self.tracker_team_history[tracker_id]) < 3:
183
+ return current_team_id
184
+
185
+ history = self.tracker_team_history[tracker_id]
186
+ team_counts = np.bincount(history)
187
+ stable_team = np.argmax(team_counts)
188
+ return stable_team
189
+
190
+ def get_player_count_by_team(self) -> Dict[int, int]:
191
+ """Get current count of players per team"""
192
+ team_counts = defaultdict(int)
193
+ for tracker_id in self.active_trackers:
194
+ if tracker_id in self.tracker_team_history and len(self.tracker_team_history[tracker_id]) > 0:
195
+ stable_team = self.get_stable_team_id(tracker_id, self.tracker_team_history[tracker_id][-1])
196
+ team_counts[stable_team] += 1
197
+ return team_counts
198
+
199
+ def reset_frame(self):
200
+ """Reset active trackers for new frame"""
201
+ self.active_trackers = set()
202
+
203
+
204
+ # ==============================================
205
+ # VISUALIZATION FUNCTIONS
206
+ # ==============================================
207
+ def create_player_heatmap_visualization(performance_tracker: PlayerPerformanceTracker,
208
+ tracker_id: int) -> np.ndarray:
209
+ """Create a single player heatmap overlay on pitch"""
210
+ pitch = draw_pitch(CONFIG)
211
+ heatmap = performance_tracker.generate_heatmap(tracker_id, resolution=150)
212
+
213
+ # Normalize heatmap
214
+ if heatmap.max() > 0:
215
+ heatmap = heatmap / heatmap.max()
216
+
217
+ # Create colored heatmap
218
+ scale = 0.1 # Same scale as pitch
219
+ padding = 50
220
+
221
+ pitch_height, pitch_width = pitch.shape[:2]
222
+ heatmap_resized = cv2.resize(heatmap, (pitch_width - 2*padding, pitch_height - 2*padding))
223
+
224
+ # Apply colormap (red = high activity, blue = low activity)
225
+ heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
226
+
227
+ # Create overlay
228
+ overlay = pitch.copy()
229
+ overlay[padding:pitch_height-padding, padding:pitch_width-padding] = heatmap_colored
230
+
231
+ # Blend with pitch
232
+ result = cv2.addWeighted(pitch, 0.6, overlay, 0.4, 0)
233
+
234
+ # Add stats text
235
+ stats = performance_tracker.get_player_stats(tracker_id)
236
+ team_color = "Blue" if stats['team_id'] == 0 else "Pink"
237
+
238
+ text_lines = [
239
+ f"Player #{tracker_id} ({team_color} Team)",
240
+ f"Distance: {stats['total_distance_meters']:.1f}m",
241
+ f"Avg Speed: {stats['avg_velocity']/100:.2f}m/s",
242
+ f"Max Speed: {stats['max_velocity']/100:.2f}m/s",
243
+ f"Frames: {stats['frames_visible']}"
244
+ ]
245
+
246
+ y_offset = 30
247
+ for line in text_lines:
248
+ cv2.putText(result, line, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX,
249
+ 0.6, (255, 255, 255), 2, cv2.LINE_AA)
250
+ y_offset += 25
251
+
252
+ return result
253
+
254
+
255
+ def create_team_comparison_plot(performance_tracker: PlayerPerformanceTracker) -> go.Figure:
256
+ """Create interactive performance comparison plots"""
257
+ teams = performance_tracker.get_all_players_by_team()
258
+
259
+ fig = make_subplots(
260
+ rows=2, cols=2,
261
+ subplot_titles=('Distance Covered', 'Average Speed', 'Max Speed', 'Activity by Zone'),
262
+ specs=[[{'type': 'bar'}, {'type': 'bar'}],
263
+ [{'type': 'bar'}, {'type': 'bar'}]]
264
+ )
265
+
266
+ colors = {0: '#00BFFF', 1: '#FF1493'}
267
+ team_names = {0: 'Team 0 (Blue)', 1: 'Team 1 (Pink)'}
268
+
269
+ for team_id, player_ids in teams.items():
270
+ if team_id not in [0, 1]:
271
+ continue
272
+
273
+ distances = []
274
+ avg_speeds = []
275
+ max_speeds = []
276
+ attacking_time = []
277
+
278
+ for pid in player_ids:
279
+ stats = performance_tracker.get_player_stats(pid)
280
+ distances.append(stats['total_distance_meters'])
281
+ avg_speeds.append(stats['avg_velocity']/100)
282
+ max_speeds.append(stats['max_velocity']/100)
283
+ attacking_time.append(stats['time_in_attacking_third'])
284
+
285
+ player_labels = [f"#{pid}" for pid in player_ids]
286
+
287
+ # Distance covered
288
+ fig.add_trace(
289
+ go.Bar(x=player_labels, y=distances, name=team_names[team_id],
290
+ marker_color=colors[team_id], showlegend=True),
291
+ row=1, col=1
292
+ )
293
 
294
+ # Average speed
295
+ fig.add_trace(
296
+ go.Bar(x=player_labels, y=avg_speeds, name=team_names[team_id],
297
+ marker_color=colors[team_id], showlegend=False),
298
+ row=1, col=2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  )
300
 
301
+ # Max speed
302
+ fig.add_trace(
303
+ go.Bar(x=player_labels, y=max_speeds, name=team_names[team_id],
304
+ marker_color=colors[team_id], showlegend=False),
305
+ row=2, col=1
306
+ )
307
 
308
+ # Attacking third time
309
+ fig.add_trace(
310
+ go.Bar(x=player_labels, y=attacking_time, name=team_names[team_id],
311
+ marker_color=colors[team_id], showlegend=False),
312
+ row=2, col=2
313
+ )
314
+
315
+ fig.update_xaxes(title_text="Players", row=1, col=1)
316
+ fig.update_xaxes(title_text="Players", row=1, col=2)
317
+ fig.update_xaxes(title_text="Players", row=2, col=1)
318
+ fig.update_xaxes(title_text="Players", row=2, col=2)
319
+
320
+ fig.update_yaxes(title_text="Distance (m)", row=1, col=1)
321
+ fig.update_yaxes(title_text="Speed (m/s)", row=1, col=2)
322
+ fig.update_yaxes(title_text="Speed (m/s)", row=2, col=1)
323
+ fig.update_yaxes(title_text="Frames in Zone", row=2, col=2)
324
+
325
+ fig.update_layout(height=800, title_text="Team Performance Comparison", barmode='group')
326
+
327
+ return fig
328
 
 
 
 
329
 
330
+ def create_combined_heatmaps(performance_tracker: PlayerPerformanceTracker) -> np.ndarray:
331
+ """Create side-by-side team heatmaps"""
332
+ teams = performance_tracker.get_all_players_by_team()
333
+
334
+ team_heatmaps = []
335
+ for team_id in [0, 1]:
336
+ if team_id not in teams:
337
+ continue
338
+
339
+ # Combine all players from this team
340
+ combined_heatmap = np.zeros((150, 150))
341
+ for pid in teams[team_id]:
342
+ player_heatmap = performance_tracker.generate_heatmap(pid, resolution=150)
343
+ combined_heatmap += player_heatmap
344
+
345
+ if combined_heatmap.max() > 0:
346
+ combined_heatmap = combined_heatmap / combined_heatmap.max()
347
+
348
+ # Create visualization
349
+ pitch = draw_pitch(CONFIG)
350
+ padding = 50
351
+ pitch_height, pitch_width = pitch.shape[:2]
352
+ heatmap_resized = cv2.resize(combined_heatmap,
353
+ (pitch_width - 2*padding, pitch_height - 2*padding))
354
+
355
+ colormap = cv2.COLORMAP_JET if team_id == 0 else cv2.COLORMAP_HOT
356
+ heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), colormap)
357
+
358
+ overlay = pitch.copy()
359
+ overlay[padding:pitch_height-padding, padding:pitch_width-padding] = heatmap_colored
360
+ result = cv2.addWeighted(pitch, 0.5, overlay, 0.5, 0)
361
+
362
+ team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
363
+ cv2.putText(result, team_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
364
+ 1, (255, 255, 255), 2, cv2.LINE_AA)
365
+
366
+ team_heatmaps.append(result)
367
+
368
+ if len(team_heatmaps) == 2:
369
+ return np.hstack(team_heatmaps)
370
+ elif len(team_heatmaps) == 1:
371
+ return team_heatmaps[0]
372
+ else:
373
+ return draw_pitch(CONFIG)
374
+
375
+
376
+ # ==============================================
377
+ # HELPER FUNCTIONS
378
+ # ==============================================
379
+ def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
380
+ if len(goalkeepers) == 0 or len(players) == 0:
381
+ return np.array([])
382
+ goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
383
+ players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
384
+ team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
385
+ team_1_centroid = players_xy[players.class_id == 1].mean(axis=0)
386
+ return np.array([
387
+ 0 if np.linalg.norm(gk - team_0_centroid) < np.linalg.norm(gk - team_1_centroid) else 1
388
+ for gk in goalkeepers_xy
389
+ ])
390
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
+ def pil_image_to_data_uri(image: Image.Image) -> str:
393
+ buffered = BytesIO()
394
+ image.save(buffered, format="PNG")
395
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
396
+ return f"data:image/png;base64,{img_str}"
397
 
398
+
399
+ def create_game_style_radar(pitch_ball_xy, pitch_players_xy, players_class_id,
400
+ pitch_referees_xy, ball_path=None):
401
+ """Create game-style radar view with ball trail effect"""
402
+ annotated_frame = draw_pitch(CONFIG)
403
 
404
+ if ball_path is not None and len(ball_path) > 0:
405
+ valid_path = [coords for coords in ball_path if len(coords) > 0]
406
+ if len(valid_path) > 1:
407
+ for i, coords in enumerate(valid_path[-20:]):
408
+ if len(coords) == 0:
409
+ continue
410
+ alpha = (i + 1) / min(20, len(valid_path))
411
+ color = sv.Color(int(255 * alpha), int(255 * alpha), int(255 * alpha))
412
+ annotated_frame = draw_points_on_pitch(
413
+ CONFIG, coords,
414
+ face_color=color,
415
+ edge_color=sv.Color.BLACK,
416
+ radius=int(6 + alpha * 4),
417
+ pitch=annotated_frame
418
+ )
419
 
420
+ if len(pitch_ball_xy) > 0:
421
+ annotated_frame = draw_points_on_pitch(
422
+ CONFIG, pitch_ball_xy,
423
+ face_color=sv.Color.WHITE,
424
+ edge_color=sv.Color.BLACK,
425
+ radius=10,
426
+ pitch=annotated_frame
427
+ )
428
 
429
+ for team_id, color_hex in zip([0, 1], ["00BFFF", "FF1493"]):
430
+ mask = players_class_id == team_id
431
+ if np.any(mask):
432
+ annotated_frame = draw_points_on_pitch(
433
+ CONFIG, pitch_players_xy[mask],
434
+ face_color=sv.Color.from_hex(color_hex),
435
+ edge_color=sv.Color.BLACK,
436
+ radius=16,
437
+ pitch=annotated_frame
438
+ )
439
 
440
+ if len(pitch_referees_xy) > 0:
441
+ annotated_frame = draw_points_on_pitch(
442
+ CONFIG, pitch_referees_xy,
443
+ face_color=sv.Color.from_hex("FFD700"),
444
+ edge_color=sv.Color.BLACK,
445
+ radius=16,
446
+ pitch=annotated_frame
447
+ )
448
 
449
+ return annotated_frame
450
+
451
+
452
+ # ==============================================
453
+ # MAIN ANALYSIS PIPELINE
454
+ # ==============================================
455
+ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
456
+ """
457
+ Complete football analysis with performance tracking and heatmaps
458
+ """
459
+ if not video_path:
460
+ return None, None, None, None, None, "โŒ Please upload a video file."
461
+
462
+ try:
463
+ progress(0, desc="๐Ÿ”ง Initializing...")
464
+
465
+ BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3
466
+ STRIDE = 30
467
+ MAXLEN = 5
468
+
469
+ tracking_manager = PlayerTrackingManager(max_history=10)
470
+ performance_tracker = PlayerPerformanceTracker(CONFIG)
471
+
472
+ ellipse_annotator = sv.EllipseAnnotator(
473
+ color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
474
+ thickness=2
475
+ )
476
+ label_annotator = sv.LabelAnnotator(
477
+ color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
478
+ text_color=sv.Color.from_hex('#FFFFFF'),
479
+ text_thickness=2
480
+ )
481
+ triangle_annotator = sv.TriangleAnnotator(
482
+ color=sv.Color.from_hex('#FFD700'),
483
+ base=20,
484
+ height=17
485
+ )
486
+
487
+ tracker = sv.ByteTrack(
488
+ track_activation_threshold=0.4,
489
+ lost_track_buffer=60,
490
+ minimum_matching_threshold=0.85,
491
+ frame_rate=30
492
+ )
493
+ tracker.reset()
494
+
495
+ cap = cv2.VideoCapture(video_path)
496
+ if not cap.isOpened():
497
+ return None, None, None, None, None, f"โŒ Failed to open video: {video_path}"
498
+
499
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
500
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
501
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
502
+ fps = cap.get(cv2.CAP_PROP_FPS)
503
+ print(f"๐Ÿ“น Video: {width}x{height}, {fps}fps, {total_frames} frames")
504
+
505
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
506
+ output_path = "/tmp/annotated_football.mp4"
507
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
508
+
509
+ # Collect player crops
510
+ progress(0.05, desc="๐Ÿƒ Collecting player samples...")
511
+ player_crops = []
512
+ frame_count = 0
513
+ while frame_count < min(total_frames, 300):
514
+ ret, frame = cap.read()
515
+ if not ret:
516
+ break
517
+
518
+ if frame_count % STRIDE == 0:
519
+ result = CLIENT.infer(frame, model_id=PLAYER_DETECTION_MODEL_ID)
520
+ detections = sv.Detections.from_inference(result)
521
+ players_detections = detections[detections.class_id == PLAYER_ID]
522
+
523
+ if len(players_detections.xyxy) > 0:
524
+ crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
525
+ player_crops.extend(crops)
526
+
527
+ frame_count += 1
528
+
529
+ if len(player_crops) == 0:
530
+ return None, None, None, None, None, "โŒ No player crops collected."
531
+
532
+ print(f"โœ… Collected {len(player_crops)} player samples")
533
+
534
+ # Train classifier
535
+ progress(0.15, desc="๐ŸŽฏ Training team classifier...")
536
+ team_classifier = TeamClassifier(device=DEVICE)
537
+ team_classifier.fit(player_crops)
538
+ print("โœ… Team classifier trained")
539
+
540
+ # Process video
541
+ cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
542
+ frame_count = 0
543
+ M = deque(maxlen=MAXLEN)
544
+ ball_path_raw = []
545
+
546
+ progress(0.2, desc="๐ŸŽฌ Processing video frames...")
547
+ while True:
548
+ ret, frame = cap.read()
549
+ if not ret:
550
+ break
551
+
552
+ frame_count += 1
553
+ tracking_manager.reset_frame()
554
+
555
+ if frame_count % 30 == 0:
556
+ progress(0.2 + 0.5 * (frame_count / total_frames),
557
+ desc=f"๐ŸŽฌ Frame {frame_count}/{total_frames}")
558
+
559
+ result = CLIENT.infer(frame, model_id=PLAYER_DETECTION_MODEL_ID)
560
+ detections = sv.Detections.from_inference(result)
561
+
562
+ if len(detections.xyxy) == 0:
563
+ out.write(frame)
564
+ continue
565
+
566
+ ball_detections = detections[detections.class_id == BALL_ID]
567
+ all_detections = detections[detections.class_id != BALL_ID]
568
+ all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
569
+ all_detections = tracker.update_with_detections(detections=all_detections)
570
+
571
+ goalkeepers_detections = all_detections[all_detections.class_id == GOALKEEPER_ID]
572
+ players_detections = all_detections[all_detections.class_id == PLAYER_ID]
573
+ referees_detections = all_detections[all_detections.class_id == REFEREE_ID]
574
+
575
+ if len(players_detections.xyxy) > 0:
576
+ crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
577
+ predicted_teams = team_classifier.predict(crops)
578
+
579
+ for idx, tracker_id in enumerate(players_detections.tracker_id):
580
+ tracking_manager.update_team_assignment(tracker_id, predicted_teams[idx])
581
+ predicted_teams[idx] = tracking_manager.get_stable_team_id(
582
+ tracker_id, predicted_teams[idx]
583
+ )
584
+
585
+ players_detections.class_id = predicted_teams
586
+
587
+ goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
588
+ players_detections, goalkeepers_detections
589
  )
590
+ referees_detections.class_id -= 1
591
+
592
+ all_detections = sv.Detections.merge([
593
+ players_detections, goalkeepers_detections, referees_detections
594
+ ])
595
+
596
+ labels = [f"#{tid}" for tid in all_detections.tracker_id]
597
+
598
+ annotated_frame = frame.copy()
599
+ annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
600
+ annotated_frame = label_annotator.annotate(annotated_frame, all_detections, labels=labels)
601
+ annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
602
+ out.write(annotated_frame)
603
+
604
+ # Performance tracking with field transformation
605
+ try:
606
+ result_field = CLIENT.infer(frame, model_id=FIELD_DETECTION_MODEL_ID)
607
+ key_points = sv.KeyPoints.from_inference(result_field)
608
+ filter_mask = key_points.confidence[0] > 0.5
609
+ frame_ref_pts = key_points.xy[0][filter_mask]
610
+ pitch_ref_pts = np.array(CONFIG.vertices)[filter_mask]
611
+ transformer = ViewTransformer(source=frame_ref_pts, target=pitch_ref_pts)
612
+ M.append(transformer.m)
613
+ transformer.m = np.mean(np.array(M), axis=0)
614
+
615
+ # Track ball
616
+ frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
617
+ pitch_ball_xy = transformer.transform_points(frame_ball_xy)
618
+ ball_path_raw.append(pitch_ball_xy)
619
+
620
+ # Track all players (including goalkeepers)
621
+ all_players = sv.Detections.merge([players_detections, goalkeepers_detections])
622
+ players_xy = all_players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
623
+ pitch_players_xy = transformer.transform_points(players_xy)
624
+
625
+ for idx, tracker_id in enumerate(all_players.tracker_id):
626
+ if idx < len(pitch_players_xy):
627
+ performance_tracker.update(
628
+ tracker_id,
629
+ pitch_players_xy[idx],
630
+ all_players.class_id[idx],
631
+ frame_count
632
+ )
633
+ except:
634
+ ball_path_raw.append(np.empty((0, 2)))
635
+
636
+ cap.release()
637
+ out.release()
638
+ print(f"โœ… Processed {frame_count} frames")
639
+
640
+ # Generate visualizations
641
+ progress(0.75, desc="๐Ÿ“Š Generating performance analytics...")
642
+
643
+ # Team comparison
644
+ comparison_fig = create_team_comparison_plot(performance_tracker)
645
+
646
+ # Combined team heatmaps
647
+ team_heatmaps_path = "/tmp/team_heatmaps.png"
648
+ team_heatmaps = create_combined_heatmaps(performance_tracker)
649
+ cv2.imwrite(team_heatmaps_path, team_heatmaps)
650
+
651
+ # Individual player heatmaps (top 6 players by distance)
652
+ progress(0.85, desc="๐Ÿ—บ๏ธ Creating individual heatmaps...")
653
+ teams = performance_tracker.get_all_players_by_team()
654
+ top_players = []
655
+ for team_id in [0, 1]:
656
+ if team_id in teams:
657
+ team_players = teams[team_id]
658
+ player_distances = [(pid, performance_tracker.get_player_stats(pid)['total_distance'])
659
+ for pid in team_players]
660
+ player_distances.sort(key=lambda x: x[1], reverse=True)
661
+ top_players.extend([pid for pid, _ in player_distances[:3]])
662
+
663
+ individual_heatmaps = []
664
+ for pid in top_players[:6]:
665
+ heatmap = create_player_heatmap_visualization(performance_tracker, pid)
666
+ individual_heatmaps.append(heatmap)
667
+
668
+ # Arrange individual heatmaps in grid
669
+ if len(individual_heatmaps) > 0:
670
+ rows = []
671
+ for i in range(0, len(individual_heatmaps), 3):
672
+ row_maps = individual_heatmaps[i:i+3]
673
+ if len(row_maps) == 3:
674
+ rows.append(np.hstack(row_maps))
675
+ elif len(row_maps) == 2:
676
+ rows.append(np.hstack([row_maps[0], row_maps[1]]))
677
+ else:
678
+ rows.append(row_maps[0])
679
+
680
+ individual_grid = np.vstack(rows) if len(rows) > 1 else rows[0]
681
+ individual_heatmaps_path = "/tmp/individual_heatmaps.png"
682
+ cv2.imwrite(individual_heatmaps_path, individual_grid)
683
+ else:
684
+ individual_heatmaps_path = None
685
+
686
+ # Radar view
687
+ progress(0.9, desc="๐Ÿ—บ๏ธ Creating game-style radar view...")
688
+ radar_path = "/tmp/radar_view_enhanced.png"
689
+ try:
690
+ radar_frame = create_game_style_radar(
691
+ pitch_ball_xy=ball_path_raw[-1] if ball_path_raw else np.empty((0, 2)),
692
+ pitch_players_xy=pitch_players_xy if 'pitch_players_xy' in locals() else np.empty((0, 2)),
693
+ players_class_id=all_players.class_id if 'all_players' in locals() else np.array([]),
694
+ pitch_referees_xy=np.empty((0, 2)),
695
+ ball_path=ball_path_raw
696
+ )
697
+ cv2.imwrite(radar_path, radar_frame)
698
+ except Exception as e:
699
+ print(f"โš ๏ธ Radar view creation failed: {e}")
700
+ radar_path = None
701
+
702
+ # Generate summary stats
703
+ progress(0.95, desc="๐Ÿ“ Generating summary report...")
704
+ teams = performance_tracker.get_all_players_by_team()
705
+
706
+ summary_lines = ["โœ… **Analysis Complete!**\n"]
707
+ summary_lines.append(f"- Total Frames: {frame_count}")
708
+ summary_lines.append(f"- Ball Path Points: {len([p for p in ball_path_raw if len(p) > 0])}\n")
709
+
710
+ for team_id in [0, 1]:
711
+ if team_id not in teams:
712
+ continue
713
 
714
+ team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
715
+ summary_lines.append(f"\n**{team_name}:**")
716
+ summary_lines.append(f"- Players Tracked: {len(teams[team_id])}")
717
 
718
+ total_dist = sum(performance_tracker.get_player_stats(pid)['total_distance_meters']
719
+ for pid in teams[team_id])
720
+ avg_dist = total_dist / len(teams[team_id]) if len(teams[team_id]) > 0 else 0
721
+ summary_lines.append(f"- Team Total Distance: {total_dist:.1f}m")
722
+ summary_lines.append(f"- Average Distance per Player: {avg_dist:.1f}m")
723
 
724
+ # Top 3 performers
725
+ player_distances = [(pid, performance_tracker.get_player_stats(pid)['total_distance_meters'])
726
+ for pid in teams[team_id]]
727
+ player_distances.sort(key=lambda x: x[1], reverse=True)
728
 
729
+ summary_lines.append(f"\n **Top Performers:**")
730
+ for i, (pid, dist) in enumerate(player_distances[:3], 1):
731
+ stats = performance_tracker.get_player_stats(pid)
732
+ summary_lines.append(
733
+ f" {i}. Player #{pid}: {dist:.1f}m, "
734
+ f"Avg Speed: {stats['avg_velocity']/100:.2f}m/s"
735
+ )
736
+
737
+ summary_msg = "\n".join(summary_lines)
738
+
739
+ progress(1.0, desc="โœ… Complete!")
740
+
741
+ return (output_path, comparison_fig, team_heatmaps_path,
742
+ individual_heatmaps_path, radar_path, summary_msg)
743
+
744
+ except Exception as e:
745
+ error_msg = f"โŒ Error: {str(e)}"
746
+ print(error_msg)
747
+ import traceback
748
+ traceback.print_exc()
749
+ return None, None, None, None, None, error_msg
750
+
751
+
752
+ # ==============================================
753
+ # GRADIO INTERFACE
754
+ # ==============================================
755
+ with gr.Blocks(title="โšฝ Football Performance Analyzer") as iface:
756
+ gr.Markdown("""
757
+ # โšฝ Advanced Football Video Analyzer
758
+ Upload a football match video to get comprehensive performance analytics including:
759
+ - Player tracking with persistent IDs
760
+ - Individual and team heatmaps
761
+ - Distance covered and speed metrics
762
+ - Game-style radar view with ball tracking
763
+ """)
764
 
765
  with gr.Row():
766
+ video_input = gr.Video(label="Upload Football Video")
767
+
768
+ analyze_btn = gr.Button("๐Ÿš€ Analyze Video", variant="primary", size="lg")
 
 
769
 
770
+ with gr.Row():
771
+ status_output = gr.Textbox(label="Analysis Status & Summary", lines=20)
 
772
 
773
  with gr.Tabs():
774
+ with gr.Tab("๐Ÿ“น Annotated Video"):
775
+ video_output = gr.Video(label="Annotated Video with Player Tracking")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
 
777
+ with gr.Tab("๐Ÿ“Š Performance Comparison"):
778
+ comparison_output = gr.Plot(label="Team Performance Metrics")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
 
780
+ with gr.Tab("๐Ÿ—บ๏ธ Team Heatmaps"):
781
+ team_heatmaps_output = gr.Image(label="Combined Team Activity Heatmaps")
 
 
 
 
 
 
 
782
 
783
+ with gr.Tab("๐Ÿ‘ค Individual Heatmaps"):
784
+ individual_heatmaps_output = gr.Image(label="Top Players Individual Heatmaps")
785
+
786
+ with gr.Tab("๐ŸŽฎ Game Radar View"):
787
+ radar_output = gr.Image(label="Game-Style Radar with Ball Trail")
 
 
788
 
 
789
  analyze_btn.click(
790
+ fn=analyze_football_video,
791
  inputs=[video_input],
792
  outputs=[
793
+ video_output,
794
+ comparison_output,
795
+ team_heatmaps_output,
796
+ individual_heatmaps_output,
797
  radar_output,
 
 
 
798
  status_output
799
  ]
800
  )
801
 
802
  gr.Markdown("""
803
  ---
804
+ ### ๐Ÿ“‹ Features:
805
+ - **Persistent Player Tracking**: IDs remain consistent throughout the video
806
+ - **Performance Metrics**: Distance covered, average/max speed, zone activity
807
+ - **Team Heatmaps**: Visualize team positioning and movement patterns
808
+ - **Individual Analysis**: Top 6 players by distance with detailed heatmaps
809
+ - **Professional Visualization**: Game-style radar view with ball trail effects
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
 
811
+ ### ๐ŸŽฏ Perfect for:
812
+ - Coaching staff analysis
813
+ - Player performance reports
814
+ - Tactical review sessions
815
+ - Scouting and recruitment
816
  """)
817
 
 
 
818
  if __name__ == "__main__":
819
+ iface.launch()