# HockeyAI — src/streamlit_app.py
# Hockey player re-identification & possession tracking Streamlit app.
import os
import subprocess
import sys
import tempfile
from collections import defaultdict

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from scipy.spatial.distance import cosine
from torchvision import models, transforms
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors

# --- EMERGENCY BOOTSTRAP ---
# Install headless OpenCV on the fly if the runtime image lacks it.
# NOTE: this must run BEFORE the first `import cv2`; previously cv2 was
# imported unconditionally above, which made this fallback unreachable.
try:
    import cv2
except ImportError:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "opencv-python-headless"]
    )
    import cv2
# --- PAGE CONFIG ---
st.set_page_config(page_title="Hockey Analytics Re-ID", layout="wide")
st.title("πŸ’ Hockey Player Re-ID & Possession Tracker")

# --- CONFIGURATION ---
# Minimum detection confidence for the player and puck YOLO models.
PLAYER_CONF = 0.3
PUCK_CONF = 0.15
# Max pixel distance between the puck and a player's feet to count as possession.
DIST_THRESHOLD = 130
# Max number of distinct re-ID identities kept in the appearance pool.
MAX_POOL_SIZE = 10
# Cosine-similarity cutoff for matching a new track to a pooled identity.
SIMILARITY_THRESHOLD = 0.75
# Frames per second sampled from the input video (also the output FPS).
TARGET_FPS = 15
# Run inference on GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- MODEL LOADING ---
@st.cache_resource
def load_models():
    """Download and initialise all models once per session (cached).

    Returns a ``(player_model, puck_model, reid_model)`` tuple: two YOLO
    detectors pulled from the Hugging Face Hub, plus an ImageNet-pretrained
    ResNet-18 whose classifier head is stripped so it emits pooled
    appearance embeddings for re-identification.
    """
    repo = "omarkashif/hockey-yolo-models"
    player_weights = hf_hub_download(repo_id=repo, filename="best v8s.pt")
    puck_weights = hf_hub_download(repo_id=repo, filename="best (puck only).pt")

    player_detector = YOLO(player_weights)
    puck_detector = YOLO(puck_weights)

    # Replacing the fully-connected head with Identity turns the classifier
    # into a feature extractor (512-d embedding per crop).
    embedder = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    embedder.fc = nn.Identity()
    embedder.to(device).eval()

    return player_detector, puck_detector, embedder


player_model, puck_model, reid_model = load_models()
# --- RE-ID EMBEDDING HELPERS ---
# Preprocessing for player crops: resize to the 128x64 portrait shape that is
# conventional for person re-ID, then normalise with ImageNet statistics to
# match the pretrained ResNet-18 backbone.
reid_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def get_embedding(img):
    """Return a flat NumPy appearance embedding for a player crop, or None.

    Zero-size crops (degenerate boxes) yield None; otherwise the crop is
    pushed through the re-ID backbone without gradient tracking.
    """
    if img.size == 0:
        return None
    batch = reid_transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = reid_model(batch)
    return embedding.cpu().numpy().flatten()
# --- CORE PROCESSING FUNCTION ---
def process_video(input_path):
    """Run the full analytics pipeline on an uploaded video.

    Pipeline:
      0. Re-encode the input to baseline H.264 so OpenCV can read it on any
         platform (fixes AV1/exotic-codec uploads).
      1. On frames sampled at TARGET_FPS: track players (ByteTrack) and
         detect the puck; stitch broken tracks back to stable identities via
         appearance-embedding similarity.
      2. Interpolate missing puck positions, attribute possession to the
         nearest player within DIST_THRESHOLD, and render annotations.
      3. Re-encode the annotated output to a web-playable H.264 MP4.

    Args:
        input_path: Filesystem path of the uploaded video.

    Returns:
        ``(final_output_path, total_possession)`` on success, where
        ``total_possession`` maps player id -> possession seconds;
        ``(None, None)`` on any failure (errors are shown via Streamlit).
    """
    # --- STEP 0: PRE-PROCESS INPUT (codec fix) ---
    # Re-encode to a very standard H.264 MP4; audio is dropped for speed.
    sanitized_input = os.path.join(tempfile.gettempdir(), "sanitized_input.mp4")
    st.info("Checking video compatibility...")
    try:
        subprocess.run([
            'ffmpeg', '-y', '-i', input_path,
            '-c:v', 'libx264', '-preset', 'ultrafast', '-crf', '28',
            '-an', sanitized_input
        ], check=True, capture_output=True)
        input_path = sanitized_input
    except Exception as e:
        # Best-effort: fall back to reading the original file directly.
        st.warning(f"Pre-processing skipped or failed: {e}. Attempting direct read...")

    # --- STEP 1: OPEN SANITIZED INPUT ---
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        st.error("OpenCV could not open the video. Try a different format.")
        return None, None
    ret, first_frame = cap.read()
    if not ret:
        st.error("Failed to read video stream.")
        cap.release()
        return None, None
    # Probe the actual first frame for dimensions (container metadata can lie).
    h, w, _ = first_frame.shape
    orig_fps = cap.get(cv2.CAP_PROP_FPS)
    frame_skip = max(1, int(orig_fps / TARGET_FPS))
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind after the probe read

    # Intermediate writer: XVID/AVI is the most stable combo in Docker.
    tfile_raw = tempfile.NamedTemporaryFile(delete=False, suffix='.avi')
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(tfile_raw.name, fourcc, TARGET_FPS, (w, h))

    id_mapping = {}          # tracker id -> stable re-ID identity
    appearance_pool = {}     # stable identity -> reference embedding
    frame_data = []          # per-sampled-frame player dict + raw frame copy
    raw_puck_positions = []  # per-sampled-frame puck centre, or None
    total_possession = defaultdict(float)
    progress_bar = st.progress(0)
    # Guard against 0 so the progress computation never divides by zero.
    total_frames = max(1, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))

    # --- Pass 1: detection, tracking and re-identification ---
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % frame_skip == 0:
            p_results = player_model.track(frame, persist=True, tracker="bytetrack.yaml", conf=PLAYER_CONF, verbose=False)
            pk_results = puck_model.predict(frame, conf=PUCK_CONF, imgsz=640, verbose=False)

            current_players = {}
            if p_results[0].boxes.id is not None:
                boxes = p_results[0].boxes.xyxy.cpu().numpy()
                ids = p_results[0].boxes.id.int().cpu().numpy()
                for box, tracker_id in zip(boxes, ids):
                    x1, y1, x2, y2 = map(int, box)
                    crop = frame[max(0, y1):min(h, y2), max(0, x1):min(w, x2)]
                    if tracker_id not in id_mapping and crop.size > 0:
                        # New track: match its appearance against the pool.
                        current_feat = get_embedding(crop)
                        best_match_id, max_sim = None, -1
                        for p_id, p_feat in appearance_pool.items():
                            sim = 1 - cosine(current_feat, p_feat)
                            if sim > max_sim:
                                max_sim, best_match_id = sim, p_id
                        if best_match_id is not None and max_sim > SIMILARITY_THRESHOLD:
                            id_mapping[tracker_id] = best_match_id
                        elif len(appearance_pool) < MAX_POOL_SIZE:
                            # No confident match and room left: mint an identity.
                            new_id = len(appearance_pool) + 1
                            appearance_pool[new_id] = current_feat
                            id_mapping[tracker_id] = new_id
                        else:
                            # Pool is full: settle for the closest identity.
                            id_mapping[tracker_id] = best_match_id
                    m_id = id_mapping.get(tracker_id, tracker_id)
                    feet = ((box[0] + box[2]) / 2, box[3])  # bottom-centre of box
                    current_players[m_id] = {"feet": feet, "box": box}

            p_pos = None
            if len(pk_results[0].boxes) > 0:
                pb = pk_results[0].boxes.xyxy.cpu().numpy()[0]
                p_pos = ((pb[0] + pb[2]) / 2, (pb[1] + pb[3]) / 2)
            raw_puck_positions.append(p_pos)
            frame_data.append({'players': current_players, 'frame': frame.copy()})
        progress_bar.progress(min(count / total_frames, 1.0))
        count += 1
    cap.release()

    # --- Pass 2: puck interpolation, possession and rendering ---
    # Fill detection gaps with the midpoint of the nearest detections on
    # either side; gaps at the edges of the clip stay None.
    processed_pucks = []
    for i in range(len(raw_puck_positions)):
        if raw_puck_positions[i] is not None:
            processed_pucks.append(raw_puck_positions[i])
        else:
            prev_v = next((raw_puck_positions[j] for j in range(i - 1, -1, -1) if raw_puck_positions[j]), None)
            next_v = next((raw_puck_positions[j] for j in range(i + 1, len(raw_puck_positions)) if raw_puck_positions[j]), None)
            processed_pucks.append(((prev_v[0] + next_v[0]) / 2, (prev_v[1] + next_v[1]) / 2) if prev_v and next_v else None)

    for i, data in enumerate(frame_data):
        frame, players, puck = data['frame'], data['players'], processed_pucks[i]
        annotator = Annotator(frame, line_width=2)

        # Possession: nearest player whose feet are within DIST_THRESHOLD px.
        curr_poss = None
        if puck is not None:
            min_dist = float('inf')
            for pid, pdata in players.items():
                dist = np.linalg.norm(np.array(puck) - np.array(pdata['feet']))
                if dist < min_dist and dist < DIST_THRESHOLD:
                    min_dist, curr_poss = dist, pid
        # `is not None` (not truthiness) so a player with id 0 still counts.
        if curr_poss is not None:
            total_possession[curr_poss] += (1 / TARGET_FPS)

        for pid, pdata in players.items():
            is_poss = (curr_poss == pid)
            color = (0, 255, 0) if is_poss else colors(pid, True)
            annotator.box_label(pdata['box'], f"ID {pid}", color=color)
            if is_poss:
                cv2.circle(frame, (int(pdata['feet'][0]), int(pdata['feet'][1])), 35, (0, 255, 0), 2)
        if puck is not None:
            cv2.circle(frame, (int(puck[0]), int(puck[1])), 5, (0, 0, 255), -1)

        # VideoWriter silently drops frames whose size differs from (w, h).
        if frame.shape[0] != h or frame.shape[1] != w:
            frame = cv2.resize(frame, (w, h))
        out.write(frame)
    out.release()  # cap was already released after pass 1

    # --- STEP 3: FINAL WEB ENCODING ---
    # Re-encode the XVID intermediate to H.264/yuv420p so browsers play it.
    final_output = os.path.join(os.getcwd(), "processed_video_final.mp4")
    if not os.path.exists(tfile_raw.name) or os.path.getsize(tfile_raw.name) == 0:
        st.error("Internal Error: Output video stream is empty.")
        return None, None
    try:
        subprocess.run([
            'ffmpeg', '-y', '-i', tfile_raw.name,
            '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
            '-preset', 'ultrafast', '-crf', '23',
            final_output
        ], check=True, capture_output=True)
    except Exception as e:
        st.error(f"Final encoding failed: {e}")
        return None, None
    finally:
        # Cleanup temp files regardless of encoding outcome.
        if os.path.exists(tfile_raw.name):
            os.remove(tfile_raw.name)
        if os.path.exists(sanitized_input):
            os.remove(sanitized_input)
    return final_output, total_possession
# --- UI LOGIC ---
# 1. Upload: persist the file to a fixed temp path that OpenCV/ffmpeg can open.
uploaded_file = st.file_uploader("Upload Your Video", type=['mp4', 'avi', 'mov'])

temp_input_path = None
if uploaded_file:
    temp_input_path = os.path.join(tempfile.gettempdir(), "user_upload_input.mp4")
    with open(temp_input_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.success("Upload ready!")

# 2. Trigger processing and render results.
if temp_input_path is None:
    st.warning("Please upload a video or enable the sample video toggle to begin.")
elif st.button("πŸš€ Run Analytics", use_container_width=True):
    with st.spinner("Processing... This may take a minute depending on video length."):
        res_path, poss_data = process_video(temp_input_path)
    # process_video returns (None, None) on failure and reports its own errors.
    if res_path:
        st.balloons()
        st.success("βœ… Analysis Complete!")
        col_v, col_t = st.columns([2, 1])
        with col_v:
            with open(res_path, "rb") as f:
                video_bytes = f.read()
            st.video(video_bytes)
            st.download_button("πŸ“₯ Download Analyzed Video", video_bytes, "analyzed_hockey.mp4", "video/mp4")
        with col_t:
            st.subheader("πŸ“Š Possession Time")
            if poss_data:
                df = pd.DataFrame(poss_data.items(), columns=['ID', 'Seconds'])
                df = df.sort_values('Seconds', ascending=False).reset_index(drop=True)
                df['Time'] = df['Seconds'].apply(lambda x: f"{x:.2f}s")
                st.table(df[['ID', 'Time']])
            else:
                st.warning("No possession detected.")