# Author: Yaawer12
# Last commit: afa08f8 "Update app.py" (verified)
import streamlit as st
import cv2
import tempfile
import numpy as np
import pandas as pd
import plotly.express as px
from ultralytics import YOLO
from PIL import Image
from openai import OpenAI
import time
import os
# --- Page Configuration ---
# Browser-tab title/icon and the default wide layout with the sidebar open.
st.set_page_config(
    page_title="Pro Table Tennis Analyzer",
    page_icon="🏓",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Custom CSS for Professional UI ---
# Injected once per run: dark page background, full-width red buttons, and the
# `.stat-box` / `.highlight` classes used by the metric cards in Match Stats.
_CUSTOM_CSS = """
<style>
.main {
background-color: #0e1117;
}
.stButton>button {
width: 100%;
background-color: #ff4b4b;
color: white;
border-radius: 8px;
height: 3em;
}
.stat-box {
background-color: #262730;
padding: 20px;
border-radius: 10px;
border: 1px solid #41444e;
text-align: center;
}
.highlight {
color: #ff4b4b;
font-weight: bold;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# --- Sidebar Controls ---
# These widgets define the module-level `conf_threshold` (passed to
# process_video) and `grok_api_key` (passed to get_grok_advice).
with st.sidebar:
    st.title("🏓 Analyzer Config")
    st.subheader("Model Settings")
    conf_threshold = st.slider(
        "Detection Confidence", 0.1, 1.0, 0.25,
        help="Lower values detect more objects but may include errors.",
    )
    st.markdown("---")
    st.subheader("AI Coach (Grok)")
    grok_api_key = st.text_input(
        "xAI / Grok API Key",
        type="password",
        help="Required only for the AI Coach tab. The video analyzer works without this.",
    )
    st.markdown("---")
    st.info("💡 **Tip:** Use a stable video with a clear view of the table for best results.")
# --- Helper Functions ---
@st.cache_resource
def load_model(weights='yolov8n.pt'):
    """Load a YOLO model, cached by Streamlit across reruns.

    ``st.cache_resource`` keys the cache on the arguments, so each distinct
    ``weights`` value is loaded once and then reused.

    Args:
        weights: Ultralytics weights file / model name. Defaults to the
            YOLOv8 Nano checkpoint (small and fast for CPU environments).

    Returns:
        The ``YOLO`` model instance.
    """
    return YOLO(weights)
def _open_video_writer(path, fps, size):
    """Return an opened ``cv2.VideoWriter`` for *path*, or ``None`` if none opens.

    H.264 ('avc1') is tried first because browsers can play it inline via
    ``st.video``; not every OpenCV build ships an H.264 encoder, so 'mp4v'
    (MPEG-4 Part 2) is the fallback.
    """
    for codec in ('avc1', 'mp4v'):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, size)
        if writer.isOpened():
            return writer
        writer.release()
    return None


def process_video(video_path, model, conf_thresh, max_frames=None):
    """Run YOLO over a video, write an annotated copy, and collect detections.

    Only class 0 (person) and class 32 (sports ball) are requested from the
    model. A Streamlit progress bar and status line are shown while frames are
    processed and removed afterwards.

    Args:
        video_path: Path to a video file readable by OpenCV.
        model: Ultralytics YOLO model (anything exposing ``predict``).
        conf_thresh: Minimum detection confidence passed to ``predict``.
        max_frames: If truthy, stop after this many frames.

    Returns:
        Tuple ``(output_path, ball_df, player_df)``: path of the annotated MP4
        and two DataFrames with columns ``frame``, ``x``, ``y`` (box centres,
        in pixels) for ball and person detections respectively.

    Raises:
        ValueError: If the input video cannot be opened.
        RuntimeError: If no MP4 writer could be created for the output.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates (e.g. 29.97) so the output plays at source speed;
    # containers reporting 0 (or NaN) fall back to 30 fps.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = 30.0
    # The reported frame count is only an estimate and may be 0; it is used
    # solely for the progress display.
    total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
    if max_frames:
        total_frames = min(total_frames, max_frames) if total_frames else max_frames

    # Reserve an output path and close the handle right away; OpenCV opens the
    # file itself by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
        output_path = tmp.name
    out = _open_video_writer(output_path, fps, (width, height))
    if out is None:
        cap.release()
        raise RuntimeError(f"Could not create a video writer for {output_path}")

    ball_positions = []
    player_positions = []
    frame_count = 0
    progress_bar = st.progress(0)
    status_text = st.empty()
    try:
        while not (max_frames and frame_count >= max_frames):
            ret, frame = cap.read()
            if not ret:
                break
            # 0 = person, 32 = sports ball
            result = model.predict(frame, conf=conf_thresh, classes=[0, 32], verbose=False)[0]
            out.write(result.plot())

            for box in result.boxes:
                cls = int(box.cls[0])
                # xywh is (x_center, y_center, width, height); keep the centre.
                x, y = box.xywh[0].tolist()[:2]
                if cls == 32:  # Ball
                    ball_positions.append({'frame': frame_count, 'x': x, 'y': y})
                elif cls == 0:  # Person
                    player_positions.append({'frame': frame_count, 'x': x, 'y': y})

            frame_count += 1
            if frame_count % 10 == 0:
                if total_frames:
                    # Clamp: the frame-count estimate can undercount, and
                    # st.progress rejects values above 1.0.
                    progress_bar.progress(min(frame_count / total_frames, 1.0))
                    status_text.text(f"Processing Frame {frame_count}/{total_frames}...")
                else:
                    status_text.text(f"Processing Frame {frame_count}...")
    finally:
        # Release native handles and clear the progress UI even on failure.
        cap.release()
        out.release()
        progress_bar.empty()
        status_text.empty()

    # Explicit columns so downstream code sees 'frame'/'x'/'y' even with zero detections.
    columns = ['frame', 'x', 'y']
    return (
        output_path,
        pd.DataFrame(ball_positions, columns=columns),
        pd.DataFrame(player_positions, columns=columns),
    )
def get_grok_advice(stats_summary, user_question, api_key, model="grok-beta"):
    """Ask Grok for coaching advice through xAI's OpenAI-compatible API.

    Args:
        stats_summary: Plain-text summary of the match analysis.
        user_question: The user's question for the coach.
        api_key: xAI API key. ``None``, empty or whitespace-only keys are
            rejected with a warning message instead of calling the API.
        model: Chat model name sent to the API. Defaults to ``"grok-beta"``
            for backward compatibility. NOTE(review): confirm that xAI still
            serves this name; pass a current model name here otherwise.

    Returns:
        The model's reply text, or a user-facing warning/error string. Client
        and API errors are returned as text rather than raised, so the
        Streamlit page never crashes here.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        return "⚠️ Please enter a Grok API key in the sidebar to use the AI Coach."

    prompt = f"""
You are a professional Table Tennis Coach.
Here is the data from the match analysis:
{stats_summary}
User Question: {user_question}
Provide concise, strategic, and professional advice.
"""
    try:
        client = OpenAI(
            api_key=api_key,
            base_url="https://api.x.ai/v1",
        )
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an expert Table Tennis Analyst."},
                {"role": "user", "content": prompt},
            ],
        )
        content = completion.choices[0].message.content
    except Exception as e:  # UI boundary: report any client/API failure to the user.
        return f"Error connecting to Grok: {e}"
    # The API may return a message with no text content.
    return content or "⚠️ Grok returned an empty response."
# --- Main App Logic ---
st.title("🏓 Pro Table Tennis Analyzer")
st.markdown("Upload your match footage for computer vision analysis and AI coaching.")

# Tabs for organization: process video, inspect stats, ask the coach.
tab1, tab2, tab3 = st.tabs(["📹 Video Analysis", "📊 Match Stats", "🤖 AI Coach"])

# Analysis results are kept in session state so they survive reruns and are
# readable from every tab; None means no video has been analysed yet.
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
with tab1:
    uploaded_file = st.file_uploader("Upload Match Video (MP4, MOV)", type=['mp4', 'mov'])
    if uploaded_file is not None:
        # OpenCV needs a real file path, so the upload is copied to a temp file.
        # Streamlit reruns this script on every widget interaction, so write the
        # file only once per distinct upload instead of on every rerun.
        upload_key = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get('_upload_key') != upload_key:
            # Keep the original extension so the container format is detected,
            # and let `with` close (and flush) the file before OpenCV reads it.
            suffix = os.path.splitext(uploaded_file.name)[1] or '.mp4'
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                tfile.write(uploaded_file.getvalue())
            st.session_state['_upload_key'] = upload_key
            st.session_state['_upload_path'] = tfile.name
            # Results from a previously uploaded video no longer apply.
            st.session_state.processed_data = None
        video_path = st.session_state['_upload_path']

        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("### Original Video")
            st.video(video_path)
        with col2:
            st.markdown("### Analysis Control")
            process_mode = st.radio("Processing Mode", ["Preview (First 10s)", "Full Match (Slower)"])
            start_btn = st.button("Start Analysis")

        if start_btn:
            model = load_model()
            max_f = None  # None -> process every frame
            if "Preview" in process_mode:
                # ~10 s of footage at the video's own frame rate; fall back to
                # 30 fps when the container does not report one.
                probe = cv2.VideoCapture(video_path)
                source_fps = probe.get(cv2.CAP_PROP_FPS)
                probe.release()
                max_f = int(round((source_fps if source_fps > 0 else 30) * 10))
            with st.spinner('Analyzing court dynamics...'):
                processed_video, ball_df, player_df = process_video(video_path, model, conf_threshold, max_f)
            # Store data in session state for the other tabs and later reruns.
            st.session_state.processed_data = {
                'video': processed_video,
                'ball': ball_df,
                'player': player_df,
            }
            st.success("Analysis Complete!")

        # Render from session state so the result does not vanish on the next
        # rerun (e.g. after changing the processing mode).
        if st.session_state.processed_data:
            st.markdown("### Analyzed Output")
            st.video(st.session_state.processed_data['video'])
with tab2:
    if st.session_state.processed_data:
        data = st.session_state.processed_data
        ball_df = data['ball']
        player_df = data['player']
        st.header("Match Statistics")

        # Derived metrics. Check .empty before touching columns: a DataFrame
        # built from zero detections may have no columns at all.
        # Count distinct frames, not detections: a frame can hold several ball boxes.
        ball_frames = 0 if ball_df.empty else ball_df['frame'].nunique()
        # Typical number of people on screen: mean person detections per frame,
        # over frames with at least one person; independent of ball tracking.
        players_visible = (
            0 if player_df.empty
            else int(round(player_df.groupby('frame').size().mean()))
        )

        # Metrics Row
        m1, m2, m3 = st.columns(3)
        with m1:
            st.markdown(f"<div class='stat-box'><h3>Frames Tracked</h3><h2 class='highlight'>{ball_frames}</h2></div>", unsafe_allow_html=True)
        with m2:
            if not ball_df.empty:
                # Image y grows downward: this is pixels from the top of the frame.
                avg_height = int(ball_df['y'].mean())
                st.markdown(f"<div class='stat-box'><h3>Avg Ball Height</h3><h2 class='highlight'>{avg_height} px</h2></div>", unsafe_allow_html=True)
            else:
                st.markdown("<div class='stat-box'><h3>Ball Detect</h3><h2 class='highlight'>N/A</h2></div>", unsafe_allow_html=True)
        with m3:
            st.markdown(f"<div class='stat-box'><h3>Players Visible</h3><h2 class='highlight'>{players_visible}</h2></div>", unsafe_allow_html=True)

        st.markdown("---")

        # Charts
        c1, c2 = st.columns(2)
        with c1:
            st.subheader("Ball Vertical Trajectory")
            if not ball_df.empty:
                fig_ball = px.line(ball_df, x='frame', y='y', title='Ball Height over Time (Inverted Y)')
                fig_ball.update_yaxes(autorange="reversed")  # Y is 0 at top in images
                st.plotly_chart(fig_ball, use_container_width=True)
            else:
                st.info("No ball detected in the processed frames.")
        with c2:
            st.subheader("Player Position Heatmap")
            if not player_df.empty:
                fig_heat = px.density_heatmap(player_df, x='x', y='y', nbinsx=20, nbinsy=20, title='Player Movement Density')
                fig_heat.update_yaxes(autorange="reversed")  # match image coordinates
                st.plotly_chart(fig_heat, use_container_width=True)
            else:
                st.info("No players detected.")
    else:
        st.info("Please process a video in the 'Video Analysis' tab first.")
with tab3:
    st.header("🤖 AI Coach (Powered by Grok)")
    st.write("Ask strategic questions about the match data.")
    if st.session_state.processed_data:
        data = st.session_state.processed_data
        coach_ball = data['ball']
        coach_players = data['player']

        # Build a compact numeric summary for the LLM. Check .empty first: a
        # DataFrame built from zero detections may have no columns.
        if not coach_ball.empty:
            summary = (f"Analysis Summary: Ball detected in {coach_ball['frame'].nunique()} frames. "
                       f"Average ball vertical position: {coach_ball['y'].mean():.2f} px "
                       f"from the top of the frame.")
        else:
            summary = "No specific ball tracking data available."
        if not coach_players.empty:
            per_frame = coach_players.groupby('frame').size()
            summary += (f" Players detected in {len(per_frame)} frames "
                        f"(about {per_frame.mean():.1f} per frame); average player position "
                        f"x={coach_players['x'].mean():.0f} px, y={coach_players['y'].mean():.0f} px.")

        user_input = st.text_area("Ask the Coach:", "Based on the ball trajectory, was the game aggressive or defensive?")
        if st.button("Get Advice"):
            with st.spinner("Contacting Coach Grok..."):
                advice = get_grok_advice(summary, user_input, grok_api_key)
            st.markdown("### 💡 Coach's Feedback")
            st.markdown(advice)
    else:
        st.warning("Please analyze a video first to give the Coach context.")