Simon9 commited on
Commit
633d62b
·
verified ·
1 Parent(s): 421b656

Update pipeline_full.py

Browse files
Files changed (1) hide show
  1. pipeline_full.py +357 -31
pipeline_full.py CHANGED
@@ -311,9 +311,7 @@ def step_siglip_clustering(video_path: str, out_dir: str) -> Dict[str, str]:
311
  imageElement.style.display = 'block';
312
  placeholderText.style.display = 'none';
313
  }}
314
-
315
  var chartElement = document.getElementById('scatter-plot-3d');
316
-
317
  chartElement.on('plotly_click', function(data) {{
318
  var customdata = data.points[0].customdata;
319
  displayImage(customdata);
@@ -741,36 +739,360 @@ def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
741
  }
742
 
743
 
744
- # -------------------- 8. stats-only process_video --------------------
745
 
746
 
747
- def process_video_stats(video_path: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
748
  ensure_models_loaded()
 
 
 
 
 
749
 
750
  tracker = sv.ByteTrack()
751
  tracker.reset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
  stats = {
753
- "distance_covered": defaultdict(float),
754
- "team_classifications": {},
755
- "field_key_points": None,
 
 
756
  }
757
 
758
- frame_gen = sv.get_video_frames_generator(video_path)
759
- for frame_idx, frame in enumerate(frame_gen):
760
- result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
761
- detections = sv.Detections.from_inference(result)
762
- field_result = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
763
- field_key_points = sv.KeyPoints.from_inference(field_result)
764
- if stats["field_key_points"] is None:
765
- stats["field_key_points"] = field_key_points.xy.tolist()
766
- detections = tracker.update_with_detections(detections)
767
- player_crops = [sv.crop_image(frame, xyxy) for xyxy in detections.xyxy]
768
- if player_crops:
769
- preds = TEAM_CLASSIFIER.predict(player_crops)
770
- for tid, team in zip(detections.tracker_id, preds):
771
- stats["team_classifications"][str(tid)] = int(team)
772
- stats["distance_covered"] = dict(stats["distance_covered"])
773
- return stats
774
 
775
 
776
  # -------------------- 9. full pipeline entrypoint --------------------
@@ -787,29 +1109,33 @@ def run_full_pipeline(video_path: str, job_dir: str) -> Dict[str, Any]:
787
 
788
  os.makedirs(job_dir, exist_ok=True)
789
 
790
- update_progress("siglip", 0.15, "Running SigLIP clustering...")
791
  siglip_out = step_siglip_clustering(video_path, os.path.join(job_dir, "siglip"))
792
 
793
- update_progress("team_classifier", 0.30, "Training TeamClassifier...")
794
  train_team_classifier_on_video(video_path)
795
 
796
- update_progress("basic_frames", 0.45, "Generating basic annotated frames...")
797
  basic_paths = step_basic_frames(video_path, os.path.join(job_dir, "frames"))
798
 
799
- update_progress("advanced_views", 0.60, "Generating advanced radar / Voronoi views...")
800
  adv_paths = step_single_frame_advanced(video_path, os.path.join(job_dir, "advanced"))
801
 
802
- update_progress("ball_path", 0.80, "Computing ball path and heatmap...")
803
  ball_paths = step_ball_path(video_path, os.path.join(job_dir, "ball_path"))
804
 
805
- update_progress("stats", 0.90, "Calculating stats...")
806
- stats = process_video_stats(video_path)
 
 
807
 
808
  result = {
809
  "basic": basic_paths,
810
  "advanced": adv_paths,
811
  "ball": ball_paths,
812
- "stats": stats,
 
 
813
  "siglip_html": siglip_out["plot_html"],
814
  }
815
 
 
311
  imageElement.style.display = 'block';
312
  placeholderText.style.display = 'none';
313
  }}
 
314
  var chartElement = document.getElementById('scatter-plot-3d');
 
315
  chartElement.on('plotly_click', function(data) {{
316
  var customdata = data.points[0].customdata;
317
  displayImage(customdata);
 
739
  }
740
 
741
 
742
+ # -------------------- 8. NEW: full-match analysis + event-annotated video --------------------
743
 
744
 
745
+ def step_analyze_and_annotate_video(video_path: str, out_dir: str) -> Dict[str, Any]:
746
+ """
747
+ Single pass over the video that:
748
+ * tracks players & ball
749
+ * computes distance & speed per player (pitch coordinates)
750
+ * estimates ball possession per team & per player
751
+ * detects simple events (pass, tackle/interception, clearance, shot)
752
+ * renders an annotated MP4 with overlays
753
+ """
754
  ensure_models_loaded()
755
+ os.makedirs(out_dir, exist_ok=True)
756
+
757
+ video_info = sv.VideoInfo.from_video_path(video_path)
758
+ fps = video_info.fps
759
+ dt = 1.0 / fps
760
 
761
  tracker = sv.ByteTrack()
762
  tracker.reset()
763
+
764
+ # homography smoothing
765
+ Ms = deque(maxlen=5)
766
+
767
+ # stats
768
+ distance_covered_m = defaultdict(float) # tid -> meters
769
+ possession_time_player = defaultdict(float) # tid -> seconds
770
+ possession_time_team = defaultdict(float) # team_id -> seconds
771
+ team_of_player = {} # tid -> team_id
772
+ events: List[Dict[str, Any]] = []
773
+
774
+ # last positions for speed / distance
775
+ last_pitch_pos: Dict[int, np.ndarray] = {}
776
+ prev_owner_tid: Optional[int] = None
777
+ prev_ball_pos_pitch: Optional[np.ndarray] = None
778
+
779
+ # simple goal centers in pitch coordinates (x is length, y is width)
780
+ goal_centers = {
781
+ 0: np.array([0.0, PITCH_CONFIG.width / 2.0]),
782
+ 1: np.array([PITCH_CONFIG.length, PITCH_CONFIG.width / 2.0]),
783
+ }
784
+
785
+ # annotators
786
+ ellipse_annotator = sv.EllipseAnnotator(
787
+ color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
788
+ thickness=2,
789
+ )
790
+ label_annotator = sv.LabelAnnotator(
791
+ color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
792
+ text_color=sv.Color.from_hex("#000000"),
793
+ text_position=sv.Position.BOTTOM_CENTER,
794
+ )
795
+ triangle_annotator = sv.TriangleAnnotator(
796
+ color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
797
+ )
798
+
799
+ sink_path = os.path.join(out_dir, "annotated_events.mp4")
800
+ sink = sv.VideoSink(sink_path, video_info)
801
+
802
+ # text overlay control
803
+ current_event_text = ""
804
+ event_text_frames_left = 0
805
+ EVENT_TEXT_DURATION_S = 2.0
806
+ EVENT_TEXT_DURATION_FRAMES = int(EVENT_TEXT_DURATION_S * fps)
807
+
808
+ frame_generator = sv.get_video_frames_generator(video_path)
809
+
810
+ with sink:
811
+ for frame_idx, frame in enumerate(tqdm(frame_generator, total=video_info.total_frames,
812
+ desc="analyze + annotate")):
813
+ t = frame_idx * dt
814
+
815
+ # --- detections + tracking ---
816
+ det_result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
817
+ detections = sv.Detections.from_inference(det_result)
818
+
819
+ ball_dets = detections[detections.class_id == BALL_ID]
820
+ ball_dets.xyxy = sv.pad_boxes(xyxy=ball_dets.xyxy, px=10)
821
+
822
+ non_ball = detections[detections.class_id != BALL_ID]
823
+ non_ball = non_ball.with_nms(threshold=0.5, class_agnostic=True)
824
+ tracked = tracker.update_with_detections(non_ball)
825
+
826
+ goalkeepers_dets = tracked[tracked.class_id == GOALKEEPER_ID]
827
+ players_dets = tracked[tracked.class_id == PLAYER_ID]
828
+ referees_dets = tracked[tracked.class_id == REFEREE_ID]
829
+
830
+ # --- field homography ---
831
+ field_result = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
832
+ key_points = sv.KeyPoints.from_inference(field_result)
833
+ filt = key_points.confidence[0] > 0.5
834
+ frame_ref = key_points.xy[0][filt]
835
+ pitch_ref = np.array(PITCH_CONFIG.vertices)[filt]
836
+
837
+ transformer = ViewTransformer(source=frame_ref, target=pitch_ref)
838
+ Ms.append(transformer.m)
839
+ transformer.m = np.mean(np.array(Ms), axis=0)
840
+
841
+ # --- team classification & pitch positions ---
842
+ frame_players_xy_pitch = None
843
+ frame_ball_pos_pitch = None
844
+
845
+ if len(players_dets) > 0:
846
+ crops = [sv.crop_image(frame, xyxy) for xyxy in players_dets.xyxy]
847
+ team_preds = TEAM_CLASSIFIER.predict(crops)
848
+ players_dets.class_id = team_preds # now class_id = team_id (0/1)
849
+
850
+ frame_players_xy_img = players_dets.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
851
+ frame_players_xy_pitch = transformer.transform_points(points=frame_players_xy_img)
852
+
853
+ for tid, team_id, pos_pitch in zip(
854
+ players_dets.tracker_id, players_dets.class_id, frame_players_xy_pitch
855
+ ):
856
+ tid_int = int(tid)
857
+ team_of_player[tid_int] = int(team_id)
858
+
859
+ prev_pos = last_pitch_pos.get(tid_int)
860
+ speed_kmh = 0.0
861
+ if prev_pos is not None:
862
+ dist_m = float(np.linalg.norm(pos_pitch - prev_pos))
863
+ distance_covered_m[tid_int] += dist_m
864
+ speed_kmh = (dist_m / dt) * 3.6 # m/s -> km/h
865
+ last_pitch_pos[tid_int] = pos_pitch
866
+
867
+ if len(ball_dets) > 0:
868
+ frame_ball_xy_img = ball_dets.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
869
+ frame_ball_xy_pitch = transformer.transform_points(points=frame_ball_xy_img)
870
+ frame_ball_pos_pitch = frame_ball_xy_pitch[0]
871
+
872
+ # --- possession owner ---
873
+ owner_tid: Optional[int] = None
874
+ POSSESSION_RADIUS_M = 5.0
875
+
876
+ if frame_ball_pos_pitch is not None and frame_players_xy_pitch is not None:
877
+ dists = np.linalg.norm(frame_players_xy_pitch - frame_ball_pos_pitch, axis=1)
878
+ j = int(np.argmin(dists))
879
+ if dists[j] < POSSESSION_RADIUS_M:
880
+ owner_tid = int(players_dets.tracker_id[j])
881
+
882
+ # accumulate possession time
883
+ if owner_tid is not None:
884
+ possession_time_player[owner_tid] += dt
885
+ owner_team = team_of_player.get(owner_tid)
886
+ if owner_team is not None:
887
+ possession_time_team[owner_team] += dt
888
+
889
+ # --- event detection (simple heuristics) ---
890
+ def register_event(ev: Dict[str, Any], text: str):
891
+ nonlocal current_event_text, event_text_frames_left
892
+ events.append(ev)
893
+ current_event_text = text
894
+ event_text_frames_left = EVENT_TEXT_DURATION_FRAMES
895
+
896
+ # possession change events, passes, tackles, interceptions
897
+ if owner_tid != prev_owner_tid:
898
+ if owner_tid is not None and prev_owner_tid is not None:
899
+ prev_team = team_of_player.get(prev_owner_tid)
900
+ cur_team = team_of_player.get(owner_tid)
901
+
902
+ travel_m = 0.0
903
+ if prev_ball_pos_pitch is not None and frame_ball_pos_pitch is not None:
904
+ travel_m = float(np.linalg.norm(frame_ball_pos_pitch - prev_ball_pos_pitch))
905
+
906
+ MIN_PASS_TRAVEL_M = 3.0
907
+
908
+ if prev_team is not None and cur_team is not None:
909
+ if prev_team == cur_team and travel_m > MIN_PASS_TRAVEL_M:
910
+ # pass
911
+ register_event(
912
+ {
913
+ "type": "pass",
914
+ "t": float(t),
915
+ "from_tid": int(prev_owner_tid),
916
+ "to_tid": int(owner_tid),
917
+ "team_id": int(cur_team),
918
+ "extra": {"distance_m": travel_m},
919
+ },
920
+ f"Pass: #{prev_owner_tid} → #{owner_tid} (Team {cur_team})",
921
+ )
922
+ elif prev_team != cur_team:
923
+ # tackle vs interception
924
+ d_pp = 999.0
925
+ if frame_players_xy_pitch is not None:
926
+ # get current positions
927
+ pos_prev = last_pitch_pos.get(int(prev_owner_tid))
928
+ pos_cur = last_pitch_pos.get(int(owner_tid))
929
+ if pos_prev is not None and pos_cur is not None:
930
+ d_pp = float(np.linalg.norm(pos_prev - pos_cur))
931
+ ev_type = "tackle" if d_pp < 3.0 else "interception"
932
+ label = "Tackle" if ev_type == "tackle" else "Interception"
933
+ register_event(
934
+ {
935
+ "type": ev_type,
936
+ "t": float(t),
937
+ "from_tid": int(prev_owner_tid),
938
+ "to_tid": int(owner_tid),
939
+ "team_id": int(cur_team),
940
+ "extra": {"player_distance_m": d_pp, "ball_travel_m": travel_m},
941
+ },
942
+ f"{label}: #{owner_tid} wins ball from #{prev_owner_tid}",
943
+ )
944
+
945
+ # generic possession-change event
946
+ register_event(
947
+ {
948
+ "type": "possession_change",
949
+ "t": float(t),
950
+ "from_tid": int(prev_owner_tid) if prev_owner_tid is not None else None,
951
+ "to_tid": int(owner_tid) if owner_tid is not None else None,
952
+ "team_id": int(team_of_player.get(owner_tid)) if owner_tid is not None else None,
953
+ "extra": {},
954
+ },
955
+ "" if owner_tid is None else f"Team {team_of_player.get(owner_tid)} in possession",
956
+ )
957
+
958
+ # shot / clearance based on ball speed & direction
959
+ if (
960
+ prev_ball_pos_pitch is not None
961
+ and frame_ball_pos_pitch is not None
962
+ and owner_tid is not None
963
+ ):
964
+ v = (frame_ball_pos_pitch - prev_ball_pos_pitch) / dt # m/s
965
+ speed_mps = float(np.linalg.norm(v))
966
+ speed_kmh = speed_mps * 3.6
967
+ HIGH_SPEED_KMH = 18.0
968
+
969
+ if speed_kmh > HIGH_SPEED_KMH:
970
+ shooter_team = team_of_player.get(owner_tid)
971
+ if shooter_team is not None:
972
+ target_goal = goal_centers[1 - shooter_team]
973
+ direction = target_goal - frame_ball_pos_pitch
974
+ cos_angle = float(
975
+ np.dot(v, direction)
976
+ / (np.linalg.norm(v) * np.linalg.norm(direction) + 1e-6)
977
+ )
978
+
979
+ if cos_angle > 0.8:
980
+ register_event(
981
+ {
982
+ "type": "shot",
983
+ "t": float(t),
984
+ "from_tid": int(owner_tid),
985
+ "to_tid": None,
986
+ "team_id": int(shooter_team),
987
+ "extra": {"speed_kmh": speed_kmh},
988
+ },
989
+ f"Shot by #{owner_tid} (Team {shooter_team}) – {speed_kmh:.1f} km/h",
990
+ )
991
+ else:
992
+ register_event(
993
+ {
994
+ "type": "clearance",
995
+ "t": float(t),
996
+ "from_tid": int(owner_tid),
997
+ "to_tid": None,
998
+ "team_id": int(shooter_team),
999
+ "extra": {"speed_kmh": speed_kmh},
1000
+ },
1001
+ f"Clearance by #{owner_tid} (Team {shooter_team})",
1002
+ )
1003
+
1004
+ prev_owner_tid = owner_tid
1005
+ prev_ball_pos_pitch = frame_ball_pos_pitch
1006
+
1007
+ # --- frame drawing ---
1008
+ annotated = frame.copy()
1009
+
1010
+ # build labels for players: id + speed + distance
1011
+ player_labels: List[str] = []
1012
+ if frame_players_xy_pitch is not None and len(players_dets) > 0:
1013
+ for tid, pos_pitch in zip(players_dets.tracker_id, frame_players_xy_pitch):
1014
+ tid_int = int(tid)
1015
+ prev_pos = last_pitch_pos.get(tid_int)
1016
+ speed_kmh = 0.0
1017
+ if prev_pos is not None:
1018
+ dist_m = float(np.linalg.norm(pos_pitch - prev_pos))
1019
+ speed_kmh = (dist_m / dt) * 3.6
1020
+ d_total = distance_covered_m[tid_int]
1021
+ team_id = team_of_player.get(tid_int, -1)
1022
+ player_labels.append(
1023
+ f"#{tid_int} T{team_id} {speed_kmh:4.1f} km/h {d_total:.1f} m"
1024
+ )
1025
+
1026
+ annotated = ellipse_annotator.annotate(
1027
+ scene=annotated, detections=players_dets
1028
+ )
1029
+ annotated = label_annotator.annotate(
1030
+ scene=annotated, detections=players_dets, labels=player_labels
1031
+ )
1032
+
1033
+ # draw ball
1034
+ annotated = triangle_annotator.annotate(scene=annotated, detections=ball_dets)
1035
+
1036
+ # --- HUD: possession percentages ---
1037
+ total_poss_time = sum(possession_time_team.values()) + 1e-6
1038
+ team0_pct = 100.0 * possession_time_team.get(0, 0.0) / total_poss_time
1039
+ team1_pct = 100.0 * possession_time_team.get(1, 0.0) / total_poss_time
1040
+
1041
+ hud_text = f"Team 0 Ball Control: {team0_pct:5.2f}% Team 1 Ball Control: {team1_pct:5.2f}%"
1042
+ cv2.rectangle(
1043
+ annotated,
1044
+ (20, annotated.shape[0] - 60),
1045
+ (annotated.shape[1] - 20, annotated.shape[0] - 20),
1046
+ (255, 255, 255),
1047
+ -1,
1048
+ )
1049
+ cv2.putText(
1050
+ annotated,
1051
+ hud_text,
1052
+ (30, annotated.shape[0] - 30),
1053
+ cv2.FONT_HERSHEY_SIMPLEX,
1054
+ 0.8,
1055
+ (0, 0, 0),
1056
+ 2,
1057
+ cv2.LINE_AA,
1058
+ )
1059
+
1060
+ # --- event banner ---
1061
+ if event_text_frames_left > 0 and current_event_text:
1062
+ cv2.rectangle(annotated, (20, 20), (annotated.shape[1] - 20, 90), (255, 255, 255), -1)
1063
+ cv2.putText(
1064
+ annotated,
1065
+ current_event_text,
1066
+ (30, 70),
1067
+ cv2.FONT_HERSHEY_SIMPLEX,
1068
+ 1.0,
1069
+ (0, 0, 0),
1070
+ 2,
1071
+ cv2.LINE_AA,
1072
+ )
1073
+ event_text_frames_left -= 1
1074
+
1075
+ sink.write_frame(annotated)
1076
+
1077
+ # finalize stats
1078
+ total_poss = sum(possession_time_team.values()) + 1e-6
1079
+ possession_percent_team = {
1080
+ int(team): 100.0 * t_sec / total_poss for team, t_sec in possession_time_team.items()
1081
+ }
1082
+
1083
  stats = {
1084
+ "distance_covered_m": {str(tid): float(d) for tid, d in distance_covered_m.items()},
1085
+ "possession_time_player_s": {str(tid): float(t_sec) for tid, t_sec in possession_time_player.items()},
1086
+ "possession_time_team_s": {int(team): float(t_sec) for team, t_sec in possession_time_team.items()},
1087
+ "possession_percent_team": possession_percent_team,
1088
+ "team_of_player": {str(tid): int(team) for tid, team in team_of_player.items()},
1089
  }
1090
 
1091
+ return {
1092
+ "annotated_video": sink_path,
1093
+ "stats": stats,
1094
+ "events": events,
1095
+ }
 
 
 
 
 
 
 
 
 
 
 
1096
 
1097
 
1098
  # -------------------- 9. full pipeline entrypoint --------------------
 
1109
 
1110
  os.makedirs(job_dir, exist_ok=True)
1111
 
1112
+ update_progress("siglip", 0.10, "Running SigLIP clustering...")
1113
  siglip_out = step_siglip_clustering(video_path, os.path.join(job_dir, "siglip"))
1114
 
1115
+ update_progress("team_classifier", 0.25, "Training TeamClassifier...")
1116
  train_team_classifier_on_video(video_path)
1117
 
1118
+ update_progress("basic_frames", 0.35, "Generating basic annotated frames...")
1119
  basic_paths = step_basic_frames(video_path, os.path.join(job_dir, "frames"))
1120
 
1121
+ update_progress("advanced_views", 0.45, "Generating advanced radar / Voronoi views...")
1122
  adv_paths = step_single_frame_advanced(video_path, os.path.join(job_dir, "advanced"))
1123
 
1124
+ update_progress("ball_path", 0.60, "Computing ball path and heatmap...")
1125
  ball_paths = step_ball_path(video_path, os.path.join(job_dir, "ball_path"))
1126
 
1127
+ update_progress("events_video", 0.80, "Analyzing match and rendering event-annotated video...")
1128
+ analysis_out = step_analyze_and_annotate_video(
1129
+ video_path, os.path.join(job_dir, "analysis")
1130
+ )
1131
 
1132
  result = {
1133
  "basic": basic_paths,
1134
  "advanced": adv_paths,
1135
  "ball": ball_paths,
1136
+ "stats": analysis_out["stats"],
1137
+ "events": analysis_out["events"],
1138
+ "annotated_video": analysis_out["annotated_video"],
1139
  "siglip_html": siglip_out["plot_html"],
1140
  }
1141