# Hugging Face Space app (page-status residue removed)
# --- Standard library ---
import os
import subprocess
import sys
import tempfile
from collections import defaultdict

# --- EMERGENCY BOOTSTRAP ---
# NOTE: this guard must run BEFORE any unconditional `import cv2`. The
# original placed `import cv2` ahead of this block, so a missing OpenCV
# install crashed the app before the auto-install could ever execute.
try:
    import cv2
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "opencv-python-headless"])
    import cv2

# --- Third-party ---
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from scipy.spatial.distance import cosine
from torchvision import models, transforms
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors
# --- PAGE CONFIG ---
st.set_page_config(page_title="Hockey Analytics Re-ID", layout="wide")
st.title("π Hockey Player Re-ID & Possession Tracker")

# --- CONFIGURATION ---
PLAYER_CONF = 0.3            # YOLO confidence threshold for player detections
PUCK_CONF = 0.15             # lower threshold: the puck is small and easily missed
DIST_THRESHOLD = 130         # max pixel distance (puck -> player feet) that counts as possession
MAX_POOL_SIZE = 10           # max number of distinct re-ID appearance embeddings kept
SIMILARITY_THRESHOLD = 0.75  # cosine similarity needed to re-attach a track to a known identity
TARGET_FPS = 15              # frames per second actually processed and written out
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- MODEL LOADING ---
@st.cache_resource
def load_models():
    """Download and initialise the detection and re-ID models.

    Cached with ``st.cache_resource``: Streamlit re-executes the whole
    script on every widget interaction, and without the cache all three
    models were re-downloaded and rebuilt on every rerun.

    Returns:
        tuple: (player_model, puck_model, reid_model) — two YOLO
        detectors and a headless ResNet-18 used as an appearance
        embedding extractor.
    """
    player_path = hf_hub_download(repo_id="omarkashif/hockey-yolo-models", filename="best v8s.pt")
    puck_path = hf_hub_download(repo_id="omarkashif/hockey-yolo-models", filename="best (puck only).pt")
    player_model = YOLO(player_path)
    puck_model = YOLO(puck_path)
    # ImageNet-pretrained ResNet-18 with the classifier head removed,
    # yielding feature vectors for appearance matching.
    reid_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    reid_model.fc = nn.Identity()
    reid_model.to(device).eval()
    return player_model, puck_model, reid_model

player_model, puck_model, reid_model = load_models()
# --- RE-ID EMBEDDING HELPERS ---
# Preprocessing pipeline for player crops: 128x64 (H x W) is the usual
# person re-ID input aspect ratio; normalisation uses ImageNet statistics
# to match the pretrained ResNet-18 backbone loaded above.
reid_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def get_embedding(img):
    """Return a flat appearance-feature vector for a BGR crop.

    Returns None for an empty crop (zero-size arrays cannot be
    transformed); otherwise a 1-D numpy array from the re-ID backbone.
    """
    if img.size == 0:
        return None
    batch = reid_transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        feats = reid_model(batch)
    return feats.cpu().numpy().flatten()
# --- CORE PROCESSING FUNCTION ---
def _sanitize_input(input_path, sanitized_path):
    """Re-encode the upload to baseline H.264 MP4 so OpenCV can read it.

    Some browser uploads arrive in codecs (e.g. AV1) the bundled
    OpenCV/FFmpeg cannot decode. On failure the original path is
    returned and a direct read is attempted instead.
    """
    st.info("Checking video compatibility...")
    try:
        subprocess.run([
            'ffmpeg', '-y', '-i', input_path,
            '-c:v', 'libx264', '-preset', 'ultrafast', '-crf', '28',
            '-an', sanitized_path  # strip audio to speed up processing
        ], check=True, capture_output=True)
        return sanitized_path
    except Exception as e:
        st.warning(f"Pre-processing skipped or failed: {e}. Attempting direct read...")
        return input_path


def _resolve_identity(tracker_id, crop, id_mapping, appearance_pool):
    """Map a raw ByteTrack id onto a stable re-ID identity (in place).

    Unknown tracker ids are matched by cosine similarity against the
    pool of known appearance embeddings; unmatched ids open a new pool
    slot until MAX_POOL_SIZE, after which the closest existing identity
    is reused even below the similarity threshold.
    """
    if tracker_id in id_mapping or crop.size == 0:
        return
    feat = get_embedding(crop)
    best_id, best_sim = None, -1
    for known_id, known_feat in appearance_pool.items():
        sim = 1 - cosine(feat, known_feat)
        if sim > best_sim:
            best_sim, best_id = sim, known_id
    if best_id is not None and best_sim > SIMILARITY_THRESHOLD:
        id_mapping[tracker_id] = best_id
    elif len(appearance_pool) < MAX_POOL_SIZE:
        new_id = len(appearance_pool) + 1
        appearance_pool[new_id] = feat
        id_mapping[tracker_id] = new_id
    else:
        # Pool is full: fall back to the closest known identity.
        id_mapping[tracker_id] = best_id


def _interpolate_pucks(raw_positions):
    """Fill missed puck detections with the midpoint of the nearest hits."""
    filled = []
    for i, pos in enumerate(raw_positions):
        if pos is not None:
            filled.append(pos)
            continue
        prev_v = next((raw_positions[j] for j in range(i - 1, -1, -1) if raw_positions[j]), None)
        next_v = next((raw_positions[j] for j in range(i + 1, len(raw_positions)) if raw_positions[j]), None)
        if prev_v and next_v:
            filled.append(((prev_v[0] + next_v[0]) / 2, (prev_v[1] + next_v[1]) / 2))
        else:
            filled.append(None)
    return filled


def process_video(input_path):
    """Run detection, re-ID tracking and possession analytics on a video.

    Returns:
        tuple: (final_output_path, possession_seconds) on success, where
        possession_seconds maps stable player id -> seconds of puck
        possession; (None, None) on any failure.
    """
    # --- STEP 0: PRE-PROCESS INPUT (codec fix) ---
    # Clean up only this known temp path later — never the caller's file
    # (which _sanitize_input may return on fallback).
    sanitized_tmp = os.path.join(tempfile.gettempdir(), "sanitized_input.mp4")
    input_path = _sanitize_input(input_path, sanitized_tmp)

    # --- STEP 1: OPEN (POSSIBLY SANITIZED) INPUT ---
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        st.error("OpenCV could not open the video. Try a different format.")
        return None, None
    ret, first_frame = cap.read()
    if not ret:
        st.error("Failed to read video stream.")
        cap.release()
        return None, None
    h, w, _ = first_frame.shape
    orig_fps = cap.get(cv2.CAP_PROP_FPS)
    frame_skip = max(1, int(orig_fps / TARGET_FPS))  # subsample to ~TARGET_FPS
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

    # Intermediate writer (XVID/AVI is the most portable OpenCV codec).
    tfile_raw = tempfile.NamedTemporaryFile(delete=False, suffix='.avi')
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(tfile_raw.name, fourcc, TARGET_FPS, (w, h))

    id_mapping = {}        # raw tracker id -> stable re-ID identity
    appearance_pool = {}   # stable identity -> appearance embedding
    frame_data = []
    raw_puck_positions = []
    total_possession = defaultdict(float)

    progress_bar = st.progress(0)
    # Some containers report 0 frames; clamp to avoid ZeroDivisionError below.
    total_frames = max(1, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))

    # --- STEP 2: DETECTION & TRACKING PASS ---
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % frame_skip == 0:
            p_results = player_model.track(frame, persist=True, tracker="bytetrack.yaml", conf=PLAYER_CONF, verbose=False)
            pk_results = puck_model.predict(frame, conf=PUCK_CONF, imgsz=640, verbose=False)

            current_players = {}
            if p_results[0].boxes.id is not None:
                boxes = p_results[0].boxes.xyxy.cpu().numpy()
                ids = p_results[0].boxes.id.int().cpu().numpy()
                for box, tracker_id in zip(boxes, ids):
                    x1, y1, x2, y2 = map(int, box)
                    crop = frame[max(0, y1):min(h, y2), max(0, x1):min(w, x2)]
                    _resolve_identity(tracker_id, crop, id_mapping, appearance_pool)
                    m_id = id_mapping.get(tracker_id, tracker_id)
                    feet = ((box[0] + box[2]) / 2, box[3])  # bottom-centre of the box
                    current_players[m_id] = {"feet": feet, "box": box}

            p_pos = None
            if len(pk_results[0].boxes) > 0:
                pb = pk_results[0].boxes.xyxy.cpu().numpy()[0]
                p_pos = ((pb[0] + pb[2]) / 2, (pb[1] + pb[3]) / 2)
            raw_puck_positions.append(p_pos)
            frame_data.append({'players': current_players, 'frame': frame.copy()})
            progress_bar.progress(min(count / total_frames, 1.0))
        count += 1
    cap.release()

    # --- STEP 3: ANALYTICS & FINAL RENDERING ---
    processed_pucks = _interpolate_pucks(raw_puck_positions)

    for i, data in enumerate(frame_data):
        frame, players, puck = data['frame'], data['players'], processed_pucks[i]
        annotator = Annotator(frame, line_width=2)
        curr_poss = None
        if puck:
            min_dist = float('inf')
            for pid, pdata in players.items():
                dist = np.linalg.norm(np.array(puck) - np.array(pdata['feet']))
                if dist < min_dist and dist < DIST_THRESHOLD:
                    min_dist, curr_poss = dist, pid
        if curr_poss is not None:
            # Each rendered frame represents 1/TARGET_FPS seconds of play.
            total_possession[curr_poss] += (1 / TARGET_FPS)
        for pid, pdata in players.items():
            is_poss = (curr_poss == pid)
            color = (0, 255, 0) if is_poss else colors(pid, True)
            annotator.box_label(pdata['box'], f"ID {pid}", color=color)
            if is_poss:
                cv2.circle(frame, (int(pdata['feet'][0]), int(pdata['feet'][1])), 35, (0, 255, 0), 2)
        if puck:
            cv2.circle(frame, (int(puck[0]), int(puck[1])), 5, (0, 0, 255), -1)
        # VideoWriter silently drops frames whose size differs from (w, h).
        if frame.shape[0] != h or frame.shape[1] != w:
            frame = cv2.resize(frame, (w, h))
        out.write(frame)
    out.release()

    # --- STEP 4: FINAL WEB-COMPATIBLE ENCODING ---
    final_output = os.path.join(os.getcwd(), "processed_video_final.mp4")
    try:
        cmd = [
            'ffmpeg', '-y', '-i', tfile_raw.name,
            '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
            '-preset', 'ultrafast', '-crf', '23',
            final_output
        ]
        subprocess.run(cmd, check=True, capture_output=True)
    except Exception as e:
        st.error(f"Final encoding failed: {e}")
        return None, None
    finally:
        # Clean up intermediate files regardless of encoding outcome.
        if os.path.exists(tfile_raw.name):
            os.remove(tfile_raw.name)
        if os.path.exists(sanitized_tmp):
            os.remove(sanitized_tmp)
    return final_output, total_possession
# --- UI LOGIC ---
uploaded_file = st.file_uploader("Upload Your Video", type=['mp4', 'avi', 'mov'])
temp_input_path = None
if uploaded_file:
    # Persist the upload to a fixed temp path so OpenCV/ffmpeg can read it.
    temp_input_path = os.path.join(tempfile.gettempdir(), "user_upload_input.mp4")
    with open(temp_input_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.success("Upload ready!")

# 2. Trigger Processing
if temp_input_path:
    if st.button("π Run Analytics", use_container_width=True):
        with st.spinner("Processing... This may take a minute depending on video length."):
            res_path, poss_data = process_video(temp_input_path)
        # process_video returns (None, None) on failure; only render on success.
        if res_path:
            st.balloons()
            st.success("β Analysis Complete!")
            col_v, col_t = st.columns([2, 1])
            with col_v:
                with open(res_path, "rb") as f:
                    video_bytes = f.read()
                st.video(video_bytes)
                st.download_button("π₯ Download Analyzed Video", video_bytes, "analyzed_hockey.mp4", "video/mp4")
            with col_t:
                st.subheader("π Possession Time")
                if poss_data:
                    df = pd.DataFrame(poss_data.items(), columns=['ID', 'Seconds'])
                    df = df.sort_values('Seconds', ascending=False).reset_index(drop=True)
                    df['Time'] = df['Seconds'].apply(lambda x: f"{x:.2f}s")
                    st.table(df[['ID', 'Time']])
                else:
                    st.warning("No possession detected.")
else:
    st.warning("Please upload a video or enable the sample video toggle to begin.")
# NOTE: an earlier, fully commented-out duplicate of the UI logic was removed
# here; the active implementation above supersedes it.