Simon9 commited on
Commit
a1b01ef
ยท
verified ยท
1 Parent(s): c16f347

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +593 -328
app.py CHANGED
@@ -9,6 +9,7 @@ import numpy as np
9
  from PIL import Image
10
  import torch
11
  from tqdm import tqdm
 
12
 
13
  import supervision as sv
14
  from sports.common.team import TeamClassifier
@@ -18,6 +19,7 @@ from sports.configs.soccer import SoccerPitchConfiguration
18
 
19
  import gradio as gr
20
  import plotly.graph_objects as go
 
21
  from transformers import AutoProcessor, SiglipVisionModel
22
  from more_itertools import chunked
23
  from sklearn.cluster import KMeans
@@ -60,6 +62,317 @@ EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH, token=HF
60
  # ==============================================
61
  CONFIG = SoccerPitchConfiguration()
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # ==============================================
64
  # HELPER FUNCTIONS
65
  # ==============================================
@@ -75,129 +388,73 @@ def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detectio
75
  for gk in goalkeepers_xy
76
  ])
77
 
 
78
  def pil_image_to_data_uri(image: Image.Image) -> str:
79
  buffered = BytesIO()
80
  image.save(buffered, format="PNG")
81
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
82
  return f"data:image/png;base64,{img_str}"
83
 
84
- def create_umap_3d_plot(crops: List[np.ndarray]) -> Tuple[go.Figure, dict]:
85
- if len(crops) == 0:
86
- return go.Figure(), {}
87
- BATCH_SIZE = 32
88
- crops_pil = [sv.cv2_to_pillow(crop) for crop in crops]
89
- batches = list(chunked(crops_pil, BATCH_SIZE))
90
- data = []
91
- with torch.no_grad():
92
- for batch in tqdm(batches, desc="Extracting embeddings"):
93
- inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(DEVICE)
94
- outputs = EMBEDDINGS_MODEL(**inputs)
95
- embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
96
- data.append(embeddings)
97
- data = np.concatenate(data)
98
- reducer = umap.UMAP(n_components=3, random_state=42)
99
- projections = reducer.fit_transform(data)
100
- clustering_model = KMeans(n_clusters=2, n_init=10, random_state=42)
101
- clusters = clustering_model.fit_predict(projections)
102
-
103
- traces = []
104
- for lbl in np.unique(clusters):
105
- mask = clusters == lbl
106
- trace = go.Scatter3d(
107
- x=projections[mask][:, 0],
108
- y=projections[mask][:, 1],
109
- z=projections[mask][:, 2],
110
- mode="markers",
111
- name=f"Team {lbl}",
112
- marker=dict(size=6, opacity=0.8),
113
- hovertemplate="<b>Team %{text}</b><extra></extra>",
114
- text=clusters[mask]
115
- )
116
- traces.append(trace)
117
- fig = go.Figure(data=traces)
118
- fig.update_layout(
119
- width=800, height=800,
120
- title="3D UMAP: Player Embeddings",
121
- scene=dict(
122
- xaxis_title="UMAP 1",
123
- yaxis_title="UMAP 2",
124
- zaxis_title="UMAP 3",
125
- camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
126
- )
127
- )
128
- return fig, {}
129
 
130
- def create_heatmap(positions: np.ndarray, config: SoccerPitchConfiguration, title: str = "Heatmap") -> np.ndarray:
131
- """Create a heatmap visualization of player positions on the pitch."""
132
- pitch = draw_pitch(config)
133
- if len(positions) == 0:
134
- return pitch
135
-
136
- # Create 2D histogram
137
- h, w = pitch.shape[:2]
138
- heatmap = np.zeros((h, w), dtype=np.float32)
139
-
140
- scale = 0.1 # Default scale from draw_pitch
141
- padding = 50
142
-
143
- for pos in positions:
144
- x = int(pos[0] * scale) + padding
145
- y = int(pos[1] * scale) + padding
146
- if 0 <= x < w and 0 <= y < h:
147
- # Gaussian blur around position
148
- cv2.circle(heatmap, (x, y), 30, 1.0, -1)
149
 
150
- # Apply Gaussian blur
151
- heatmap = cv2.GaussianBlur(heatmap, (51, 51), 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- # Normalize
154
- if heatmap.max() > 0:
155
- heatmap = heatmap / heatmap.max()
 
 
 
 
 
156
 
157
- # Apply colormap
158
- heatmap_colored = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
 
 
 
 
 
 
 
 
159
 
160
- # Blend with pitch
161
- result = cv2.addWeighted(pitch, 0.6, heatmap_colored, 0.4, 0)
 
 
 
 
 
 
162
 
163
- return result
164
 
165
- def calculate_player_stats(tracker_positions: Dict[int, List[np.ndarray]]) -> Dict[int, Dict]:
166
- """Calculate statistics for each tracked player."""
167
- stats = {}
168
- for tracker_id, positions in tracker_positions.items():
169
- if len(positions) < 2:
170
- continue
171
-
172
- positions_array = np.array(positions)
173
-
174
- # Calculate distance traveled
175
- distances = np.sqrt(np.sum(np.diff(positions_array, axis=0)**2, axis=1))
176
- total_distance = np.sum(distances)
177
-
178
- # Calculate average speed (assuming 30 fps)
179
- avg_speed = np.mean(distances) * 30 # pixels per second
180
-
181
- # Calculate area covered (bounding box of positions)
182
- min_x, min_y = positions_array.min(axis=0)
183
- max_x, max_y = positions_array.max(axis=0)
184
- area_covered = (max_x - min_x) * (max_y - min_y)
185
-
186
- stats[tracker_id] = {
187
- 'total_distance': total_distance,
188
- 'avg_speed': avg_speed,
189
- 'area_covered': area_covered,
190
- 'num_positions': len(positions)
191
- }
192
-
193
- return stats
194
 
195
  # ==============================================
196
  # MAIN ANALYSIS PIPELINE
197
  # ==============================================
198
- def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str, go.Figure, str, str, str, str]:
199
  """
200
- Complete football video analysis pipeline with proper pitch transformation and performance tracking.
201
  """
202
  if not video_path:
203
  return None, None, None, None, None, "โŒ Please upload a video file."
@@ -205,27 +462,36 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str
205
  try:
206
  progress(0, desc="๐Ÿ”ง Initializing...")
207
 
208
- # Detection IDs
209
  BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3
210
  STRIDE = 30
211
  MAXLEN = 5
212
 
213
- # Annotators
 
 
214
  ellipse_annotator = sv.EllipseAnnotator(
215
- color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']), thickness=2
 
216
  )
217
  label_annotator = sv.LabelAnnotator(
218
  color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
219
- text_color=sv.Color.from_hex('#FFFFFF'), text_thickness=2
 
220
  )
221
  triangle_annotator = sv.TriangleAnnotator(
222
- color=sv.Color.from_hex('#FFD700'), base=20, height=17
 
 
223
  )
224
 
225
- tracker = sv.ByteTrack()
 
 
 
 
 
226
  tracker.reset()
227
 
228
- # Video setup
229
  cap = cv2.VideoCapture(video_path)
230
  if not cap.isOpened():
231
  return None, None, None, None, None, f"โŒ Failed to open video: {video_path}"
@@ -240,16 +506,12 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str
240
  output_path = "/tmp/annotated_football.mp4"
241
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
242
 
243
- # --------------------------
244
- # Step 1: Collect player crops for classifier
245
- # --------------------------
246
  progress(0.05, desc="๐Ÿƒ Collecting player samples...")
247
  player_crops = []
248
  frame_count = 0
249
- cap_temp = cv2.VideoCapture(video_path)
250
-
251
  while frame_count < min(total_frames, 300):
252
- ret, frame = cap_temp.read()
253
  if not ret:
254
  break
255
 
@@ -263,35 +525,23 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str
263
  player_crops.extend(crops)
264
 
265
  frame_count += 1
266
-
267
- cap_temp.release()
268
 
269
  if len(player_crops) == 0:
270
- return None, None, None, None, None, "โŒ No player crops collected. Check video or detection confidence."
271
 
272
- print(f"โœ… Collected {len(player_crops)} player samples for team classifier")
273
 
274
- # --------------------------
275
- # Step 2: Fit TeamClassifier
276
- # --------------------------
277
  progress(0.15, desc="๐ŸŽฏ Training team classifier...")
278
  team_classifier = TeamClassifier(device=DEVICE)
279
  team_classifier.fit(player_crops)
280
  print("โœ… Team classifier trained")
281
 
282
- # --------------------------
283
- # Step 3: Process entire video with tracking
284
- # --------------------------
285
  frame_count = 0
286
  M = deque(maxlen=MAXLEN)
287
-
288
- # Tracking data structures
289
  ball_path_raw = []
290
- team_0_positions = []
291
- team_1_positions = []
292
- referee_positions = []
293
- player_tracker_positions = defaultdict(list) # tracker_id -> list of pitch positions
294
- player_tracker_teams = {} # tracker_id -> team_id
295
 
296
  progress(0.2, desc="๐ŸŽฌ Processing video frames...")
297
  while True:
@@ -300,10 +550,12 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str
300
  break
301
 
302
  frame_count += 1
 
 
303
  if frame_count % 30 == 0:
304
- progress(0.2 + 0.5 * (frame_count / total_frames), desc=f"๐ŸŽฌ Frame {frame_count}/{total_frames}")
 
305
 
306
- # Player detection
307
  result = CLIENT.infer(frame, model_id=PLAYER_DETECTION_MODEL_ID)
308
  detections = sv.Detections.from_inference(result)
309
 
@@ -311,9 +563,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str
311
  out.write(frame)
312
  continue
313
 
314
- # Separate detections
315
  ball_detections = detections[detections.class_id == BALL_ID]
316
-
317
  all_detections = detections[detections.class_id != BALL_ID]
318
  all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
319
  all_detections = tracker.update_with_detections(detections=all_detections)
@@ -322,201 +572,174 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str
322
  players_detections = all_detections[all_detections.class_id == PLAYER_ID]
323
  referees_detections = all_detections[all_detections.class_id == REFEREE_ID]
324
 
325
- # Predict team IDs
326
  if len(players_detections.xyxy) > 0:
327
  crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
328
- players_detections.class_id = team_classifier.predict(crops).astype(int)
329
-
330
- # Assign goalkeeper teams
331
- if len(goalkeepers_detections) > 0 and len(players_detections) > 0:
332
- goalkeepers_detections.class_id = resolve_goalkeepers_team_id(players_detections, goalkeepers_detections).astype(int)
 
 
 
 
333
 
334
- # Set referees to class 2
335
- if len(referees_detections) > 0:
336
- referees_detections.class_id = np.full(len(referees_detections), 2, dtype=int)
 
337
 
338
- # Merge all detections
339
- all_detections = sv.Detections.merge([players_detections, goalkeepers_detections, referees_detections])
340
- all_detections.class_id = all_detections.class_id.astype(int)
341
 
342
- # Field detection & transformation
343
- try:
344
- result_field = CLIENT.infer(frame, model_id=FIELD_DETECTION_MODEL_ID)
345
- key_points = sv.KeyPoints.from_inference(result_field)
346
- filter_mask = key_points.confidence[0] > 0.5
347
-
348
- if np.sum(filter_mask) >= 4: # Need at least 4 points for transformation
349
- frame_ref_pts = key_points.xy[0][filter_mask]
350
- pitch_ref_pts = np.array(CONFIG.vertices)[filter_mask]
351
-
352
- transformer = ViewTransformer(source=frame_ref_pts, target=pitch_ref_pts)
353
- M.append(transformer.m)
354
- transformer.m = np.mean(np.array(M), axis=0)
355
-
356
- # Transform ball position
357
- if len(ball_detections) > 0:
358
- frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
359
- pitch_ball_xy = transformer.transform_points(frame_ball_xy)
360
- ball_path_raw.append(pitch_ball_xy)
361
-
362
- # Transform player positions
363
- if len(all_detections) > 0:
364
- frame_players_xy = all_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
365
- pitch_players_xy = transformer.transform_points(frame_players_xy)
366
-
367
- # Store positions by team and tracker
368
- for i, (pitch_pos, class_id, tracker_id) in enumerate(
369
- zip(pitch_players_xy, all_detections.class_id, all_detections.tracker_id)
370
- ):
371
- if class_id == 0:
372
- team_0_positions.append(pitch_pos)
373
- elif class_id == 1:
374
- team_1_positions.append(pitch_pos)
375
- elif class_id == 2:
376
- referee_positions.append(pitch_pos)
377
-
378
- # Track individual players
379
- if class_id in [0, 1]:
380
- player_tracker_positions[tracker_id].append(pitch_pos)
381
- player_tracker_teams[tracker_id] = class_id
382
-
383
- except Exception as e:
384
- print(f"โš ๏ธ Transformation failed at frame {frame_count}: {e}")
385
-
386
- # Annotate frame
387
  labels = [f"#{tid}" for tid in all_detections.tracker_id]
 
388
  annotated_frame = frame.copy()
389
  annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
390
  annotated_frame = label_annotator.annotate(annotated_frame, all_detections, labels=labels)
391
  annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
392
  out.write(annotated_frame)
393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  cap.release()
395
  out.release()
396
  print(f"โœ… Processed {frame_count} frames")
397
 
398
- # --------------------------
399
- # Step 4: UMAP Embeddings Visualization
400
- # --------------------------
401
- progress(0.75, desc="๐ŸŽจ Creating player embeddings...")
402
- umap_fig, _ = create_umap_3d_plot(player_crops[:500])
403
-
404
- # --------------------------
405
- # Step 5: Radar View with all positions
406
- # --------------------------
407
- progress(0.85, desc="๐Ÿ—บ๏ธ Creating radar view...")
408
- radar_path = "/tmp/radar_view.png"
409
- try:
410
- annotated_frame = draw_pitch(CONFIG)
411
-
412
- # Draw ball path
413
- if len(ball_path_raw) > 0:
414
- ball_path = np.concatenate(ball_path_raw)
415
- # Filter valid coordinates
416
- ball_path = [coord.flatten() for coord in ball_path_raw if coord.shape[0] > 0]
417
- if len(ball_path) > 0:
418
- annotated_frame = draw_paths_on_pitch(
419
- config=CONFIG,
420
- paths=[ball_path],
421
- color=sv.Color.WHITE,
422
- pitch=annotated_frame
423
- )
424
-
425
- # Draw team positions
426
- if len(team_0_positions) > 0:
427
- annotated_frame = draw_points_on_pitch(
428
- CONFIG, np.array(team_0_positions),
429
- face_color=sv.Color.from_hex("00BFFF"),
430
- edge_color=sv.Color.BLACK, radius=8, pitch=annotated_frame
431
- )
432
-
433
- if len(team_1_positions) > 0:
434
- annotated_frame = draw_points_on_pitch(
435
- CONFIG, np.array(team_1_positions),
436
- face_color=sv.Color.from_hex("FF1493"),
437
- edge_color=sv.Color.BLACK, radius=8, pitch=annotated_frame
438
- )
439
-
440
- if len(referee_positions) > 0:
441
- annotated_frame = draw_points_on_pitch(
442
- CONFIG, np.array(referee_positions),
443
- face_color=sv.Color.from_hex("FFD700"),
444
- edge_color=sv.Color.BLACK, radius=8, pitch=annotated_frame
445
- )
446
 
447
- cv2.imwrite(radar_path, annotated_frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  except Exception as e:
449
  print(f"โš ๏ธ Radar view creation failed: {e}")
450
  radar_path = None
451
 
452
- # --------------------------
453
- # Step 6: Heatmaps
454
- # --------------------------
455
- progress(0.90, desc="๐Ÿ”ฅ Creating heatmaps...")
456
- heatmap_team0_path = "/tmp/heatmap_team0.png"
457
- heatmap_team1_path = "/tmp/heatmap_team1.png"
458
 
459
- try:
460
- if len(team_0_positions) > 0:
461
- heatmap0 = create_heatmap(np.array(team_0_positions), CONFIG, "Team 0 Heatmap")
462
- cv2.imwrite(heatmap_team0_path, heatmap0)
 
 
 
463
 
464
- if len(team_1_positions) > 0:
465
- heatmap1 = create_heatmap(np.array(team_1_positions), CONFIG, "Team 1 Heatmap")
466
- cv2.imwrite(heatmap_team1_path, heatmap1)
467
- except Exception as e:
468
- print(f"โš ๏ธ Heatmap creation failed: {e}")
469
- heatmap_team0_path = None
470
- heatmap_team1_path = None
471
-
472
- # --------------------------
473
- # Step 7: Player Statistics
474
- # --------------------------
475
- progress(0.95, desc="๐Ÿ“Š Calculating statistics...")
476
- player_stats = calculate_player_stats(player_tracker_positions)
477
-
478
- # Find top performers
479
- if player_stats:
480
- top_distance = max(player_stats.items(), key=lambda x: x[1]['total_distance'])
481
- top_speed = max(player_stats.items(), key=lambda x: x[1]['avg_speed'])
482
- top_area = max(player_stats.items(), key=lambda x: x[1]['area_covered'])
483
 
484
- # Team statistics
485
- team_0_distances = sum(stats['total_distance'] for tid, stats in player_stats.items()
486
- if player_tracker_teams.get(tid) == 0)
487
- team_1_distances = sum(stats['total_distance'] for tid, stats in player_stats.items()
488
- if player_tracker_teams.get(tid) == 1)
489
 
490
- stats_msg = f"""๐Ÿ“Š **Performance Statistics**
491
-
492
- **Top Players:**
493
- - Most Distance: Player #{top_distance[0]} - {top_distance[1]['total_distance']:.0f} units
494
- - Fastest: Player #{top_speed[0]} - {top_speed[1]['avg_speed']:.2f} units/s
495
- - Most Area Covered: Player #{top_area[0]} - {top_area[1]['area_covered']:.0f} sq units
496
-
497
- **Team Statistics:**
498
- - Team 0 (Blue) Total Distance: {team_0_distances:.0f} units
499
- - Team 1 (Pink) Total Distance: {team_1_distances:.0f} units
500
- - Better Team: {"Team 0 (Blue)" if team_0_distances > team_1_distances else "Team 1 (Pink)"}
501
-
502
- **General:**
503
- - Total Tracked Players: {len(player_stats)}
504
- - Team 0 Players: {sum(1 for t in player_tracker_teams.values() if t == 0)}
505
- - Team 1 Players: {sum(1 for t in player_tracker_teams.values() if t == 1)}
506
- """
507
- else:
508
- stats_msg = "โš ๏ธ Insufficient tracking data for statistics"
509
 
510
- progress(1.0, desc="โœ… Analysis Complete!")
511
 
512
- success_msg = f"""โœ… **Analysis Complete!**
513
- - Total Frames: {frame_count}
514
- - Player Samples: {len(player_crops)}
515
- - Ball Path Points: {len(ball_path_raw)}
516
- - Team 0 Positions: {len(team_0_positions)}
517
- - Team 1 Positions: {len(team_1_positions)}
518
- """
519
- return output_path, umap_fig, radar_path, heatmap_team0_path, heatmap_team1_path, success_msg + "\n" + stats_msg
520
 
521
  except Exception as e:
522
  error_msg = f"โŒ Error: {str(e)}"
@@ -525,30 +748,72 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple[str
525
  traceback.print_exc()
526
  return None, None, None, None, None, error_msg
527
 
 
528
  # ==============================================
529
  # GRADIO INTERFACE
530
  # ==============================================
531
- iface = gr.Interface(
532
- fn=analyze_football_video,
533
- inputs=gr.Video(label="Upload Football Video"),
534
- outputs=[
535
- gr.Video(label="Annotated Video"),
536
- gr.Plot(label="3D Player Embeddings (UMAP)"),
537
- gr.Image(label="Radar View (All Positions)"),
538
- gr.Image(label="Team 0 Heatmap"),
539
- gr.Image(label="Team 1 Heatmap"),
540
- gr.Textbox(label="Statistics & Status", lines=20)
541
- ],
542
- title="โšฝ Football Video Analyzer - Complete Pipeline",
543
- description="""
544
- Upload a football video for comprehensive analysis:
545
- - Player detection and team classification
546
- - Ball tracking with trajectory visualization
547
- - Player movement heatmaps
548
- - Performance statistics and rankings
549
- - 3D embeddings visualization
550
- """
551
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
  if __name__ == "__main__":
554
  iface.launch()
 
9
  from PIL import Image
10
  import torch
11
  from tqdm import tqdm
12
+ from scipy.ndimage import gaussian_filter
13
 
14
  import supervision as sv
15
  from sports.common.team import TeamClassifier
 
19
 
20
  import gradio as gr
21
  import plotly.graph_objects as go
22
+ from plotly.subplots import make_subplots
23
  from transformers import AutoProcessor, SiglipVisionModel
24
  from more_itertools import chunked
25
  from sklearn.cluster import KMeans
 
62
  # ==============================================
63
  CONFIG = SoccerPitchConfiguration()
64
 
65
+ # ==============================================
66
+ # PLAYER PERFORMANCE TRACKING
67
+ # ==============================================
68
+ class PlayerPerformanceTracker:
69
+ """Track individual player performance metrics and generate heatmaps"""
70
+
71
+ def __init__(self, pitch_config):
72
+ self.config = pitch_config
73
+ self.player_positions = defaultdict(list) # tracker_id -> list of (x, y, frame)
74
+ self.player_velocities = defaultdict(list) # tracker_id -> list of velocities
75
+ self.player_distances = defaultdict(float) # tracker_id -> total distance
76
+ self.player_team = {} # tracker_id -> team_id
77
+ self.player_stats = defaultdict(lambda: {
78
+ 'frames_visible': 0,
79
+ 'avg_velocity': 0,
80
+ 'max_velocity': 0,
81
+ 'time_in_attacking_third': 0,
82
+ 'time_in_defensive_third': 0,
83
+ 'time_in_middle_third': 0
84
+ })
85
+
86
+ def update(self, tracker_id: int, position: np.ndarray, team_id: int, frame: int):
87
+ """Update player position and calculate metrics"""
88
+ if len(position) != 2:
89
+ return
90
+
91
+ self.player_team[tracker_id] = team_id
92
+ self.player_positions[tracker_id].append((position[0], position[1], frame))
93
+ self.player_stats[tracker_id]['frames_visible'] += 1
94
+
95
+ # Calculate velocity if we have previous position
96
+ if len(self.player_positions[tracker_id]) > 1:
97
+ prev_pos = np.array(self.player_positions[tracker_id][-2][:2])
98
+ curr_pos = np.array(position)
99
+ distance = np.linalg.norm(curr_pos - prev_pos)
100
+ self.player_distances[tracker_id] += distance
101
+
102
+ # Velocity (assuming 30 fps)
103
+ velocity = distance * 30 # cm/s
104
+ self.player_velocities[tracker_id].append(velocity)
105
+
106
+ # Update velocity stats
107
+ if velocity > self.player_stats[tracker_id]['max_velocity']:
108
+ self.player_stats[tracker_id]['max_velocity'] = velocity
109
+
110
+ # Track position zones (thirds of the pitch)
111
+ pitch_length = self.config.length
112
+ if position[0] < pitch_length / 3:
113
+ self.player_stats[tracker_id]['time_in_defensive_third'] += 1
114
+ elif position[0] < 2 * pitch_length / 3:
115
+ self.player_stats[tracker_id]['time_in_middle_third'] += 1
116
+ else:
117
+ self.player_stats[tracker_id]['time_in_attacking_third'] += 1
118
+
119
+ def get_player_stats(self, tracker_id: int) -> dict:
120
+ """Get comprehensive stats for a player"""
121
+ stats = self.player_stats[tracker_id].copy()
122
+
123
+ if len(self.player_velocities[tracker_id]) > 0:
124
+ stats['avg_velocity'] = np.mean(self.player_velocities[tracker_id])
125
+
126
+ stats['total_distance'] = self.player_distances[tracker_id]
127
+ stats['total_distance_meters'] = self.player_distances[tracker_id] / 100 # Convert to meters
128
+ stats['team_id'] = self.player_team.get(tracker_id, -1)
129
+
130
+ return stats
131
+
132
+ def generate_heatmap(self, tracker_id: int, resolution: int = 100) -> np.ndarray:
133
+ """Generate heatmap for a specific player"""
134
+ if tracker_id not in self.player_positions or len(self.player_positions[tracker_id]) == 0:
135
+ return np.zeros((resolution, resolution))
136
+
137
+ positions = np.array([(x, y) for x, y, _ in self.player_positions[tracker_id]])
138
+
139
+ # Create 2D histogram
140
+ pitch_length = self.config.length
141
+ pitch_width = self.config.width
142
+
143
+ heatmap, xedges, yedges = np.histogram2d(
144
+ positions[:, 0], positions[:, 1],
145
+ bins=[resolution, resolution],
146
+ range=[[0, pitch_length], [0, pitch_width]]
147
+ )
148
+
149
+ # Apply Gaussian smoothing for better visualization
150
+ heatmap = gaussian_filter(heatmap, sigma=3)
151
+
152
+ return heatmap.T # Transpose for correct orientation
153
+
154
+ def get_all_players_by_team(self) -> Dict[int, List[int]]:
155
+ """Get all player IDs grouped by team"""
156
+ teams = defaultdict(list)
157
+ for tracker_id, team_id in self.player_team.items():
158
+ teams[team_id].append(tracker_id)
159
+ return teams
160
+
161
+
162
+ # ==============================================
163
+ # TRACKING MANAGER
164
+ # ==============================================
165
+ class PlayerTrackingManager:
166
+ """Manages persistent player tracking with team assignment stability"""
167
+
168
+ def __init__(self, max_history=10):
169
+ self.tracker_team_history: Dict[int, List[int]] = defaultdict(list)
170
+ self.max_history = max_history
171
+ self.active_trackers = set()
172
+
173
+ def update_team_assignment(self, tracker_id: int, team_id: int):
174
+ """Store team assignment history for each tracker"""
175
+ self.tracker_team_history[tracker_id].append(team_id)
176
+ if len(self.tracker_team_history[tracker_id]) > self.max_history:
177
+ self.tracker_team_history[tracker_id].pop(0)
178
+ self.active_trackers.add(tracker_id)
179
+
180
+ def get_stable_team_id(self, tracker_id: int, current_team_id: int) -> int:
181
+ """Get stable team ID using majority voting from history"""
182
+ if tracker_id not in self.tracker_team_history or len(self.tracker_team_history[tracker_id]) < 3:
183
+ return current_team_id
184
+
185
+ history = self.tracker_team_history[tracker_id]
186
+ team_counts = np.bincount(history)
187
+ stable_team = np.argmax(team_counts)
188
+ return stable_team
189
+
190
+ def get_player_count_by_team(self) -> Dict[int, int]:
191
+ """Get current count of players per team"""
192
+ team_counts = defaultdict(int)
193
+ for tracker_id in self.active_trackers:
194
+ if tracker_id in self.tracker_team_history and len(self.tracker_team_history[tracker_id]) > 0:
195
+ stable_team = self.get_stable_team_id(tracker_id, self.tracker_team_history[tracker_id][-1])
196
+ team_counts[stable_team] += 1
197
+ return team_counts
198
+
199
+ def reset_frame(self):
200
+ """Reset active trackers for new frame"""
201
+ self.active_trackers = set()
202
+
203
+
204
+ # ==============================================
205
+ # VISUALIZATION FUNCTIONS
206
+ # ==============================================
207
+ def create_player_heatmap_visualization(performance_tracker: PlayerPerformanceTracker,
208
+ tracker_id: int) -> np.ndarray:
209
+ """Create a single player heatmap overlay on pitch"""
210
+ pitch = draw_pitch(CONFIG)
211
+ heatmap = performance_tracker.generate_heatmap(tracker_id, resolution=150)
212
+
213
+ # Normalize heatmap
214
+ if heatmap.max() > 0:
215
+ heatmap = heatmap / heatmap.max()
216
+
217
+ # Create colored heatmap
218
+ scale = 0.1 # Same scale as pitch
219
+ padding = 50
220
+
221
+ pitch_height, pitch_width = pitch.shape[:2]
222
+ heatmap_resized = cv2.resize(heatmap, (pitch_width - 2*padding, pitch_height - 2*padding))
223
+
224
+ # Apply colormap (red = high activity, blue = low activity)
225
+ heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
226
+
227
+ # Create overlay
228
+ overlay = pitch.copy()
229
+ overlay[padding:pitch_height-padding, padding:pitch_width-padding] = heatmap_colored
230
+
231
+ # Blend with pitch
232
+ result = cv2.addWeighted(pitch, 0.6, overlay, 0.4, 0)
233
+
234
+ # Add stats text
235
+ stats = performance_tracker.get_player_stats(tracker_id)
236
+ team_color = "Blue" if stats['team_id'] == 0 else "Pink"
237
+
238
+ text_lines = [
239
+ f"Player #{tracker_id} ({team_color} Team)",
240
+ f"Distance: {stats['total_distance_meters']:.1f}m",
241
+ f"Avg Speed: {stats['avg_velocity']/100:.2f}m/s",
242
+ f"Max Speed: {stats['max_velocity']/100:.2f}m/s",
243
+ f"Frames: {stats['frames_visible']}"
244
+ ]
245
+
246
+ y_offset = 30
247
+ for line in text_lines:
248
+ cv2.putText(result, line, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX,
249
+ 0.6, (255, 255, 255), 2, cv2.LINE_AA)
250
+ y_offset += 25
251
+
252
+ return result
253
+
254
+
255
+ def create_team_comparison_plot(performance_tracker: PlayerPerformanceTracker) -> go.Figure:
256
+ """Create interactive performance comparison plots"""
257
+ teams = performance_tracker.get_all_players_by_team()
258
+
259
+ fig = make_subplots(
260
+ rows=2, cols=2,
261
+ subplot_titles=('Distance Covered', 'Average Speed', 'Max Speed', 'Activity by Zone'),
262
+ specs=[[{'type': 'bar'}, {'type': 'bar'}],
263
+ [{'type': 'bar'}, {'type': 'bar'}]]
264
+ )
265
+
266
+ colors = {0: '#00BFFF', 1: '#FF1493'}
267
+ team_names = {0: 'Team 0 (Blue)', 1: 'Team 1 (Pink)'}
268
+
269
+ for team_id, player_ids in teams.items():
270
+ if team_id not in [0, 1]:
271
+ continue
272
+
273
+ distances = []
274
+ avg_speeds = []
275
+ max_speeds = []
276
+ attacking_time = []
277
+
278
+ for pid in player_ids:
279
+ stats = performance_tracker.get_player_stats(pid)
280
+ distances.append(stats['total_distance_meters'])
281
+ avg_speeds.append(stats['avg_velocity']/100)
282
+ max_speeds.append(stats['max_velocity']/100)
283
+ attacking_time.append(stats['time_in_attacking_third'])
284
+
285
+ player_labels = [f"#{pid}" for pid in player_ids]
286
+
287
+ # Distance covered
288
+ fig.add_trace(
289
+ go.Bar(x=player_labels, y=distances, name=team_names[team_id],
290
+ marker_color=colors[team_id], showlegend=True),
291
+ row=1, col=1
292
+ )
293
+
294
+ # Average speed
295
+ fig.add_trace(
296
+ go.Bar(x=player_labels, y=avg_speeds, name=team_names[team_id],
297
+ marker_color=colors[team_id], showlegend=False),
298
+ row=1, col=2
299
+ )
300
+
301
+ # Max speed
302
+ fig.add_trace(
303
+ go.Bar(x=player_labels, y=max_speeds, name=team_names[team_id],
304
+ marker_color=colors[team_id], showlegend=False),
305
+ row=2, col=1
306
+ )
307
+
308
+ # Attacking third time
309
+ fig.add_trace(
310
+ go.Bar(x=player_labels, y=attacking_time, name=team_names[team_id],
311
+ marker_color=colors[team_id], showlegend=False),
312
+ row=2, col=2
313
+ )
314
+
315
+ fig.update_xaxes(title_text="Players", row=1, col=1)
316
+ fig.update_xaxes(title_text="Players", row=1, col=2)
317
+ fig.update_xaxes(title_text="Players", row=2, col=1)
318
+ fig.update_xaxes(title_text="Players", row=2, col=2)
319
+
320
+ fig.update_yaxes(title_text="Distance (m)", row=1, col=1)
321
+ fig.update_yaxes(title_text="Speed (m/s)", row=1, col=2)
322
+ fig.update_yaxes(title_text="Speed (m/s)", row=2, col=1)
323
+ fig.update_yaxes(title_text="Frames in Zone", row=2, col=2)
324
+
325
+ fig.update_layout(height=800, title_text="Team Performance Comparison", barmode='group')
326
+
327
+ return fig
328
+
329
+
330
+ def create_combined_heatmaps(performance_tracker: PlayerPerformanceTracker) -> np.ndarray:
331
+ """Create side-by-side team heatmaps"""
332
+ teams = performance_tracker.get_all_players_by_team()
333
+
334
+ team_heatmaps = []
335
+ for team_id in [0, 1]:
336
+ if team_id not in teams:
337
+ continue
338
+
339
+ # Combine all players from this team
340
+ combined_heatmap = np.zeros((150, 150))
341
+ for pid in teams[team_id]:
342
+ player_heatmap = performance_tracker.generate_heatmap(pid, resolution=150)
343
+ combined_heatmap += player_heatmap
344
+
345
+ if combined_heatmap.max() > 0:
346
+ combined_heatmap = combined_heatmap / combined_heatmap.max()
347
+
348
+ # Create visualization
349
+ pitch = draw_pitch(CONFIG)
350
+ padding = 50
351
+ pitch_height, pitch_width = pitch.shape[:2]
352
+ heatmap_resized = cv2.resize(combined_heatmap,
353
+ (pitch_width - 2*padding, pitch_height - 2*padding))
354
+
355
+ colormap = cv2.COLORMAP_JET if team_id == 0 else cv2.COLORMAP_HOT
356
+ heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), colormap)
357
+
358
+ overlay = pitch.copy()
359
+ overlay[padding:pitch_height-padding, padding:pitch_width-padding] = heatmap_colored
360
+ result = cv2.addWeighted(pitch, 0.5, overlay, 0.5, 0)
361
+
362
+ team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
363
+ cv2.putText(result, team_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
364
+ 1, (255, 255, 255), 2, cv2.LINE_AA)
365
+
366
+ team_heatmaps.append(result)
367
+
368
+ if len(team_heatmaps) == 2:
369
+ return np.hstack(team_heatmaps)
370
+ elif len(team_heatmaps) == 1:
371
+ return team_heatmaps[0]
372
+ else:
373
+ return draw_pitch(CONFIG)
374
+
375
+
376
  # ==============================================
377
  # HELPER FUNCTIONS
378
  # ==============================================
 
388
  for gk in goalkeepers_xy
389
  ])
390
 
391
+
392
  def pil_image_to_data_uri(image: Image.Image) -> str:
393
  buffered = BytesIO()
394
  image.save(buffered, format="PNG")
395
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
396
  return f"data:image/png;base64,{img_str}"
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
+ def create_game_style_radar(pitch_ball_xy, pitch_players_xy, players_class_id,
400
+ pitch_referees_xy, ball_path=None):
401
+ """Create game-style radar view with ball trail effect"""
402
+ annotated_frame = draw_pitch(CONFIG)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
+ if ball_path is not None and len(ball_path) > 0:
405
+ valid_path = [coords for coords in ball_path if len(coords) > 0]
406
+ if len(valid_path) > 1:
407
+ for i, coords in enumerate(valid_path[-20:]):
408
+ if len(coords) == 0:
409
+ continue
410
+ alpha = (i + 1) / min(20, len(valid_path))
411
+ color = sv.Color(int(255 * alpha), int(255 * alpha), int(255 * alpha))
412
+ annotated_frame = draw_points_on_pitch(
413
+ CONFIG, coords,
414
+ face_color=color,
415
+ edge_color=sv.Color.BLACK,
416
+ radius=int(6 + alpha * 4),
417
+ pitch=annotated_frame
418
+ )
419
 
420
+ if len(pitch_ball_xy) > 0:
421
+ annotated_frame = draw_points_on_pitch(
422
+ CONFIG, pitch_ball_xy,
423
+ face_color=sv.Color.WHITE,
424
+ edge_color=sv.Color.BLACK,
425
+ radius=10,
426
+ pitch=annotated_frame
427
+ )
428
 
429
+ for team_id, color_hex in zip([0, 1], ["00BFFF", "FF1493"]):
430
+ mask = players_class_id == team_id
431
+ if np.any(mask):
432
+ annotated_frame = draw_points_on_pitch(
433
+ CONFIG, pitch_players_xy[mask],
434
+ face_color=sv.Color.from_hex(color_hex),
435
+ edge_color=sv.Color.BLACK,
436
+ radius=16,
437
+ pitch=annotated_frame
438
+ )
439
 
440
+ if len(pitch_referees_xy) > 0:
441
+ annotated_frame = draw_points_on_pitch(
442
+ CONFIG, pitch_referees_xy,
443
+ face_color=sv.Color.from_hex("FFD700"),
444
+ edge_color=sv.Color.BLACK,
445
+ radius=16,
446
+ pitch=annotated_frame
447
+ )
448
 
449
+ return annotated_frame
450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
  # ==============================================
453
  # MAIN ANALYSIS PIPELINE
454
  # ==============================================
455
+ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
456
  """
457
+ Complete football analysis with performance tracking and heatmaps
458
  """
459
  if not video_path:
460
  return None, None, None, None, None, "โŒ Please upload a video file."
 
462
  try:
463
  progress(0, desc="๐Ÿ”ง Initializing...")
464
 
 
465
  BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3
466
  STRIDE = 30
467
  MAXLEN = 5
468
 
469
+ tracking_manager = PlayerTrackingManager(max_history=10)
470
+ performance_tracker = PlayerPerformanceTracker(CONFIG)
471
+
472
  ellipse_annotator = sv.EllipseAnnotator(
473
+ color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
474
+ thickness=2
475
  )
476
  label_annotator = sv.LabelAnnotator(
477
  color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
478
+ text_color=sv.Color.from_hex('#FFFFFF'),
479
+ text_thickness=2
480
  )
481
  triangle_annotator = sv.TriangleAnnotator(
482
+ color=sv.Color.from_hex('#FFD700'),
483
+ base=20,
484
+ height=17
485
  )
486
 
487
+ tracker = sv.ByteTrack(
488
+ track_activation_threshold=0.4,
489
+ lost_track_buffer=60,
490
+ minimum_matching_threshold=0.85,
491
+ frame_rate=30
492
+ )
493
  tracker.reset()
494
 
 
495
  cap = cv2.VideoCapture(video_path)
496
  if not cap.isOpened():
497
  return None, None, None, None, None, f"โŒ Failed to open video: {video_path}"
 
506
  output_path = "/tmp/annotated_football.mp4"
507
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
508
 
509
+ # Collect player crops
 
 
510
  progress(0.05, desc="๐Ÿƒ Collecting player samples...")
511
  player_crops = []
512
  frame_count = 0
 
 
513
  while frame_count < min(total_frames, 300):
514
+ ret, frame = cap.read()
515
  if not ret:
516
  break
517
 
 
525
  player_crops.extend(crops)
526
 
527
  frame_count += 1
 
 
528
 
529
  if len(player_crops) == 0:
530
+ return None, None, None, None, None, "โŒ No player crops collected."
531
 
532
+ print(f"โœ… Collected {len(player_crops)} player samples")
533
 
534
+ # Train classifier
 
 
535
  progress(0.15, desc="๐ŸŽฏ Training team classifier...")
536
  team_classifier = TeamClassifier(device=DEVICE)
537
  team_classifier.fit(player_crops)
538
  print("โœ… Team classifier trained")
539
 
540
+ # Process video
541
+ cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
 
542
  frame_count = 0
543
  M = deque(maxlen=MAXLEN)
 
 
544
  ball_path_raw = []
 
 
 
 
 
545
 
546
  progress(0.2, desc="๐ŸŽฌ Processing video frames...")
547
  while True:
 
550
  break
551
 
552
  frame_count += 1
553
+ tracking_manager.reset_frame()
554
+
555
  if frame_count % 30 == 0:
556
+ progress(0.2 + 0.5 * (frame_count / total_frames),
557
+ desc=f"๐ŸŽฌ Frame {frame_count}/{total_frames}")
558
 
 
559
  result = CLIENT.infer(frame, model_id=PLAYER_DETECTION_MODEL_ID)
560
  detections = sv.Detections.from_inference(result)
561
 
 
563
  out.write(frame)
564
  continue
565
 
 
566
  ball_detections = detections[detections.class_id == BALL_ID]
 
567
  all_detections = detections[detections.class_id != BALL_ID]
568
  all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
569
  all_detections = tracker.update_with_detections(detections=all_detections)
 
572
  players_detections = all_detections[all_detections.class_id == PLAYER_ID]
573
  referees_detections = all_detections[all_detections.class_id == REFEREE_ID]
574
 
 
575
  if len(players_detections.xyxy) > 0:
576
  crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
577
+ predicted_teams = team_classifier.predict(crops)
578
+
579
+ for idx, tracker_id in enumerate(players_detections.tracker_id):
580
+ tracking_manager.update_team_assignment(tracker_id, predicted_teams[idx])
581
+ predicted_teams[idx] = tracking_manager.get_stable_team_id(
582
+ tracker_id, predicted_teams[idx]
583
+ )
584
+
585
+ players_detections.class_id = predicted_teams
586
 
587
+ goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
588
+ players_detections, goalkeepers_detections
589
+ )
590
+ referees_detections.class_id -= 1
591
 
592
+ all_detections = sv.Detections.merge([
593
+ players_detections, goalkeepers_detections, referees_detections
594
+ ])
595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  labels = [f"#{tid}" for tid in all_detections.tracker_id]
597
+
598
  annotated_frame = frame.copy()
599
  annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
600
  annotated_frame = label_annotator.annotate(annotated_frame, all_detections, labels=labels)
601
  annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
602
  out.write(annotated_frame)
603
 
604
+ # Performance tracking with field transformation
605
+ try:
606
+ result_field = CLIENT.infer(frame, model_id=FIELD_DETECTION_MODEL_ID)
607
+ key_points = sv.KeyPoints.from_inference(result_field)
608
+ filter_mask = key_points.confidence[0] > 0.5
609
+ frame_ref_pts = key_points.xy[0][filter_mask]
610
+ pitch_ref_pts = np.array(CONFIG.vertices)[filter_mask]
611
+ transformer = ViewTransformer(source=frame_ref_pts, target=pitch_ref_pts)
612
+ M.append(transformer.m)
613
+ transformer.m = np.mean(np.array(M), axis=0)
614
+
615
+ # Track ball
616
+ frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
617
+ pitch_ball_xy = transformer.transform_points(frame_ball_xy)
618
+ ball_path_raw.append(pitch_ball_xy)
619
+
620
+ # Track all players (including goalkeepers)
621
+ all_players = sv.Detections.merge([players_detections, goalkeepers_detections])
622
+ players_xy = all_players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
623
+ pitch_players_xy = transformer.transform_points(players_xy)
624
+
625
+ for idx, tracker_id in enumerate(all_players.tracker_id):
626
+ if idx < len(pitch_players_xy):
627
+ performance_tracker.update(
628
+ tracker_id,
629
+ pitch_players_xy[idx],
630
+ all_players.class_id[idx],
631
+ frame_count
632
+ )
633
+ except:
634
+ ball_path_raw.append(np.empty((0, 2)))
635
+
636
  cap.release()
637
  out.release()
638
  print(f"โœ… Processed {frame_count} frames")
639
 
640
+ # Generate visualizations
641
+ progress(0.75, desc="๐Ÿ“Š Generating performance analytics...")
642
+
643
+ # Team comparison
644
+ comparison_fig = create_team_comparison_plot(performance_tracker)
645
+
646
+ # Combined team heatmaps
647
+ team_heatmaps_path = "/tmp/team_heatmaps.png"
648
+ team_heatmaps = create_combined_heatmaps(performance_tracker)
649
+ cv2.imwrite(team_heatmaps_path, team_heatmaps)
650
+
651
+ # Individual player heatmaps (top 6 players by distance)
652
+ progress(0.85, desc="๐Ÿ—บ๏ธ Creating individual heatmaps...")
653
+ teams = performance_tracker.get_all_players_by_team()
654
+ top_players = []
655
+ for team_id in [0, 1]:
656
+ if team_id in teams:
657
+ team_players = teams[team_id]
658
+ player_distances = [(pid, performance_tracker.get_player_stats(pid)['total_distance'])
659
+ for pid in team_players]
660
+ player_distances.sort(key=lambda x: x[1], reverse=True)
661
+ top_players.extend([pid for pid, _ in player_distances[:3]])
662
+
663
+ individual_heatmaps = []
664
+ for pid in top_players[:6]:
665
+ heatmap = create_player_heatmap_visualization(performance_tracker, pid)
666
+ individual_heatmaps.append(heatmap)
667
+
668
+ # Arrange individual heatmaps in grid
669
+ if len(individual_heatmaps) > 0:
670
+ rows = []
671
+ for i in range(0, len(individual_heatmaps), 3):
672
+ row_maps = individual_heatmaps[i:i+3]
673
+ if len(row_maps) == 3:
674
+ rows.append(np.hstack(row_maps))
675
+ elif len(row_maps) == 2:
676
+ rows.append(np.hstack([row_maps[0], row_maps[1]]))
677
+ else:
678
+ rows.append(row_maps[0])
 
 
 
 
 
 
 
 
 
679
 
680
+ individual_grid = np.vstack(rows) if len(rows) > 1 else rows[0]
681
+ individual_heatmaps_path = "/tmp/individual_heatmaps.png"
682
+ cv2.imwrite(individual_heatmaps_path, individual_grid)
683
+ else:
684
+ individual_heatmaps_path = None
685
+
686
+ # Radar view
687
+ progress(0.9, desc="๐Ÿ—บ๏ธ Creating game-style radar view...")
688
+ radar_path = "/tmp/radar_view_enhanced.png"
689
+ try:
690
+ radar_frame = create_game_style_radar(
691
+ pitch_ball_xy=ball_path_raw[-1] if ball_path_raw else np.empty((0, 2)),
692
+ pitch_players_xy=pitch_players_xy if 'pitch_players_xy' in locals() else np.empty((0, 2)),
693
+ players_class_id=all_players.class_id if 'all_players' in locals() else np.array([]),
694
+ pitch_referees_xy=np.empty((0, 2)),
695
+ ball_path=ball_path_raw
696
+ )
697
+ cv2.imwrite(radar_path, radar_frame)
698
  except Exception as e:
699
  print(f"โš ๏ธ Radar view creation failed: {e}")
700
  radar_path = None
701
 
702
+ # Generate summary stats
703
+ progress(0.95, desc="๐Ÿ“ Generating summary report...")
704
+ teams = performance_tracker.get_all_players_by_team()
 
 
 
705
 
706
+ summary_lines = ["โœ… **Analysis Complete!**\n"]
707
+ summary_lines.append(f"- Total Frames: {frame_count}")
708
+ summary_lines.append(f"- Ball Path Points: {len([p for p in ball_path_raw if len(p) > 0])}\n")
709
+
710
+ for team_id in [0, 1]:
711
+ if team_id not in teams:
712
+ continue
713
 
714
+ team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
715
+ summary_lines.append(f"\n**{team_name}:**")
716
+ summary_lines.append(f"- Players Tracked: {len(teams[team_id])}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717
 
718
+ total_dist = sum(performance_tracker.get_player_stats(pid)['total_distance_meters']
719
+ for pid in teams[team_id])
720
+ avg_dist = total_dist / len(teams[team_id]) if len(teams[team_id]) > 0 else 0
721
+ summary_lines.append(f"- Team Total Distance: {total_dist:.1f}m")
722
+ summary_lines.append(f"- Average Distance per Player: {avg_dist:.1f}m")
723
 
724
+ # Top 3 performers
725
+ player_distances = [(pid, performance_tracker.get_player_stats(pid)['total_distance_meters'])
726
+ for pid in teams[team_id]]
727
+ player_distances.sort(key=lambda x: x[1], reverse=True)
728
+
729
+ summary_lines.append(f"\n **Top Performers:**")
730
+ for i, (pid, dist) in enumerate(player_distances[:3], 1):
731
+ stats = performance_tracker.get_player_stats(pid)
732
+ summary_lines.append(
733
+ f" {i}. Player #{pid}: {dist:.1f}m, "
734
+ f"Avg Speed: {stats['avg_velocity']/100:.2f}m/s"
735
+ )
736
+
737
+ summary_msg = "\n".join(summary_lines)
 
 
 
 
 
738
 
739
+ progress(1.0, desc="โœ… Complete!")
740
 
741
+ return (output_path, comparison_fig, team_heatmaps_path,
742
+ individual_heatmaps_path, radar_path, summary_msg)
 
 
 
 
 
 
743
 
744
  except Exception as e:
745
  error_msg = f"โŒ Error: {str(e)}"
 
748
  traceback.print_exc()
749
  return None, None, None, None, None, error_msg
750
 
751
+
752
  # ==============================================
753
  # GRADIO INTERFACE
754
  # ==============================================
755
+ with gr.Blocks(title="โšฝ Football Performance Analyzer") as iface:
756
+ gr.Markdown("""
757
+ # โšฝ Advanced Football Video Analyzer
758
+ Upload a football match video to get comprehensive performance analytics including:
759
+ - Player tracking with persistent IDs
760
+ - Individual and team heatmaps
761
+ - Distance covered and speed metrics
762
+ - Game-style radar view with ball tracking
763
+ """)
764
+
765
+ with gr.Row():
766
+ video_input = gr.Video(label="Upload Football Video")
767
+
768
+ analyze_btn = gr.Button("๐Ÿš€ Analyze Video", variant="primary", size="lg")
769
+
770
+ with gr.Row():
771
+ status_output = gr.Textbox(label="Analysis Status & Summary", lines=20)
772
+
773
+ with gr.Tabs():
774
+ with gr.Tab("๐Ÿ“น Annotated Video"):
775
+ video_output = gr.Video(label="Annotated Video with Player Tracking")
776
+
777
+ with gr.Tab("๐Ÿ“Š Performance Comparison"):
778
+ comparison_output = gr.Plot(label="Team Performance Metrics")
779
+
780
+ with gr.Tab("๐Ÿ—บ๏ธ Team Heatmaps"):
781
+ team_heatmaps_output = gr.Image(label="Combined Team Activity Heatmaps")
782
+
783
+ with gr.Tab("๐Ÿ‘ค Individual Heatmaps"):
784
+ individual_heatmaps_output = gr.Image(label="Top Players Individual Heatmaps")
785
+
786
+ with gr.Tab("๐ŸŽฎ Game Radar View"):
787
+ radar_output = gr.Image(label="Game-Style Radar with Ball Trail")
788
+
789
+ analyze_btn.click(
790
+ fn=analyze_football_video,
791
+ inputs=[video_input],
792
+ outputs=[
793
+ video_output,
794
+ comparison_output,
795
+ team_heatmaps_output,
796
+ individual_heatmaps_output,
797
+ radar_output,
798
+ status_output
799
+ ]
800
+ )
801
+
802
+ gr.Markdown("""
803
+ ---
804
+ ### ๐Ÿ“‹ Features:
805
+ - **Persistent Player Tracking**: IDs remain consistent throughout the video
806
+ - **Performance Metrics**: Distance covered, average/max speed, zone activity
807
+ - **Team Heatmaps**: Visualize team positioning and movement patterns
808
+ - **Individual Analysis**: Top 6 players by distance with detailed heatmaps
809
+ - **Professional Visualization**: Game-style radar view with ball trail effects
810
+
811
+ ### ๐ŸŽฏ Perfect for:
812
+ - Coaching staff analysis
813
+ - Player performance reports
814
+ - Tactical review sessions
815
+ - Scouting and recruitment
816
+ """)
817
 
818
  if __name__ == "__main__":
819
  iface.launch()