# Spaces: Sleeping
import os
import tempfile
import time

import cv2
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from openai import OpenAI
from PIL import Image
from ultralytics import YOLO
# --- Page Configuration ---
# Streamlit requires set_page_config to be the first Streamlit command executed.
st.set_page_config(
    page_title="Pro Table Tennis Analyzer",
    # NOTE(review): the captured glyph was mojibake; restored as the ping-pong emoji.
    page_icon="🏓",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Custom CSS for Professional UI ---
# Injected as raw HTML; `.stat-box` and `.highlight` are used by the metric
# cards rendered on the Match Stats tab.
st.markdown("""
<style>
.main {
    background-color: #0e1117;
}
.stButton>button {
    width: 100%;
    background-color: #ff4b4b;
    color: white;
    border-radius: 8px;
    height: 3em;
}
.stat-box {
    background-color: #262730;
    padding: 20px;
    border-radius: 10px;
    border: 1px solid #41444e;
    text-align: center;
}
.highlight {
    color: #ff4b4b;
    font-weight: bold;
}
</style>
""", unsafe_allow_html=True)
# --- Sidebar Controls ---
# Defines two module-level values read later in the script:
#   conf_threshold - minimum YOLO detection confidence passed to process_video
#   grok_api_key   - optional key forwarded to get_grok_advice
with st.sidebar:
    st.title("🏓 Analyzer Config")

    st.subheader("Model Settings")
    conf_threshold = st.slider(
        "Detection Confidence",
        0.1,
        1.0,
        0.25,
        help="Lower values detect more objects but may include errors.",
    )

    st.markdown("---")
    st.subheader("AI Coach (Grok)")
    grok_api_key = st.text_input(
        "xAI / Grok API Key",
        type="password",
        help="Required only for the AI Coach tab. The video analyzer works without this.",
    )

    st.markdown("---")
    st.info("💡 **Tip:** Use a stable video with a clear view of the table for best results.")
# --- Helper Functions ---
@st.cache_resource
def load_model():
    """
    Load the YOLOv8 Nano detector (small and fast for CPU environments).

    The function is wrapped in ``st.cache_resource`` so the weights are read
    once and the same model object is reused across script reruns, instead of
    being reloaded on every "Start Analysis" click.

    Returns:
        The ``YOLO`` model built from ``yolov8n.pt``.
    """
    return YOLO('yolov8n.pt')
def process_video(video_path, model, conf_thresh, max_frames=None):
    """
    Track the ball (class 32) and persons (class 0) through a video.

    Each frame is run through ``model.predict``, the annotated frame is written
    to a new temporary ``.mp4`` file, and the centre of every detected box is
    recorded. Progress is reported through a Streamlit progress bar and status
    line, which are removed again before returning.

    Args:
        video_path: Path of the video file to read.
        model: Detector exposing ``predict(frame, conf=..., classes=..., verbose=...)``.
        conf_thresh: Minimum detection confidence passed to the model.
        max_frames: Optional cap on the number of frames to process; ``None``
            (or 0) processes the whole video.

    Returns:
        A tuple ``(output_path, ball_df, player_df)`` where ``output_path`` is
        the annotated video and both DataFrames have the columns
        ``frame``, ``x``, ``y`` (one row per detection, box-centre coordinates).
        The DataFrames are empty, but still carry those columns, when nothing
        was detected or the video could not be read.
    """
    person_class_id = 0
    ball_class_id = 32
    position_columns = ['frame', 'x', 'y']
    fallback_fps = 30.0

    ball_positions = []
    player_positions = []
    frame_count = 0

    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Keep the fractional rate (e.g. 29.97) and fall back when the container
    # reports 0 or NaN, which would otherwise produce an unusable writer.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = fallback_fps

    # The reported frame count is only a hint; it can be 0 or negative.
    reported_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    total_frames = int(reported_frames) if reported_frames > 0 else 0
    if max_frames:
        total_frames = min(total_frames, max_frames) if total_frames else max_frames

    # Output setup: reserve a path and close the handle right away so the
    # writer can open the file itself and no descriptor is leaked.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_out:
        output_path = tmp_out.name

    # Using 'mp4v' codec for compatibility.
    # NOTE(review): browsers generally cannot play 'mp4v' streams, so st.video
    # may show a blank player for this file - confirm on the target deployment.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    progress_bar = st.progress(0)
    status_text = st.empty()

    try:
        while cap.isOpened():
            if max_frames and frame_count >= max_frames:
                break
            ret, frame = cap.read()
            if not ret:
                break

            # Run YOLO inference (0=person, 32=sports ball)
            results = model.predict(
                frame,
                conf=conf_thresh,
                classes=[person_class_id, ball_class_id],
                verbose=False,
            )

            # Draw annotations
            annotated_frame = results[0].plot()

            # Data extraction for stats
            for box in results[0].boxes:
                cls = int(box.cls[0])
                x, y, _w, _h = box.xywh[0].tolist()
                record = {'frame': frame_count, 'x': x, 'y': y}
                if cls == ball_class_id:
                    ball_positions.append(record)
                elif cls == person_class_id:
                    player_positions.append(record)

            out.write(annotated_frame)
            frame_count += 1

            # Update progress
            if frame_count % 10 == 0:
                if total_frames > 0:
                    # Clamp: st.progress rejects values above 1.0, which can
                    # happen when the container under-reports its frame count.
                    progress_bar.progress(min(frame_count / total_frames, 1.0))
                    status_text.text(f"Processing Frame {frame_count}/{total_frames}...")
                else:
                    status_text.text(f"Processing Frame {frame_count}...")
    finally:
        # Always free the capture/writer and clear the widgets, even if
        # inference raises part-way through the video.
        cap.release()
        out.release()
        progress_bar.empty()
        status_text.empty()

    return (
        output_path,
        pd.DataFrame(ball_positions, columns=position_columns),
        pd.DataFrame(player_positions, columns=position_columns),
    )
def get_grok_advice(stats_summary, user_question, api_key, model="grok-beta"):
    """
    Send game stats to Grok via the OpenAI-compatible API for strategic advice.

    Args:
        stats_summary: Text summary of the match analysis to give the model.
        user_question: The question the user wants answered.
        api_key: xAI / Grok API key. Empty, ``None`` or whitespace-only keys
            short-circuit with a warning message and no request is made.
        model: Name of the chat model to request. Defaults to ``"grok-beta"``
            (or e.g. ``"grok-2"``, depending on availability).

    Returns:
        The model's reply text, a warning string when no key was supplied, or
        an ``"Error connecting to Grok: ..."`` string if the request fails.
        The function never raises; failures are reported in the returned text
        so the caller can render them directly.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        return "⚠️ Please enter a Grok API key in the sidebar to use the AI Coach."

    # Assembled line by line so no source-code indentation leaks into the prompt.
    prompt = (
        "You are a professional Table Tennis Coach.\n"
        "Here is the data from the match analysis:\n"
        f"{stats_summary}\n"
        f"User Question: {user_question}\n"
        "Provide concise, strategic, and professional advice.\n"
    )

    try:
        client = OpenAI(
            api_key=api_key,
            base_url="https://api.x.ai/v1",
        )
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an expert Table Tennis Analyst."},
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content
    except Exception as e:
        # UI boundary: surface any client/network failure as text instead of
        # crashing the Streamlit script.
        return f"Error connecting to Grok: {str(e)}"
# --- Main App Logic ---
st.title("🏓 Pro Table Tennis Analyzer")
st.markdown("Upload your match footage for computer vision analysis and AI coaching.")

# Tabs for organization
# NOTE(review): the tab-label emoji were mojibake in the captured source; the
# "Match Stats" glyph was ambiguous and restored as a bar chart - confirm.
tab1, tab2, tab3 = st.tabs(["📹 Video Analysis", "📊 Match Stats", "🤖 AI Coach"])

# Session-state slot that carries the analysis results across reruns so the
# stats and coach tabs can read what the video tab produced.
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
# Video Analysis tab: accept an upload, preview it, and run the detector on demand.
with tab1:
    uploaded_file = st.file_uploader("Upload Match Video (MP4, MOV)", type=['mp4', 'mov'])

    if uploaded_file is not None:
        # Streamlit reruns the whole script on every widget interaction, so the
        # upload is written to disk only once per distinct file and the path is
        # remembered in session state.
        upload_key = (uploaded_file.name, uploaded_file.size)
        cached_upload = st.session_state.get('uploaded_video')
        if cached_upload is None or cached_upload['key'] != upload_key:
            if cached_upload is not None:
                try:
                    os.remove(cached_upload['path'])
                except OSError:
                    # Best-effort cleanup of the previous upload's temp file.
                    pass
            # Keep the original extension so players/decoders can identify the
            # container, and close the file so the data is flushed before
            # st.video and OpenCV open it.
            suffix = os.path.splitext(uploaded_file.name)[1] or '.mp4'
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                # getvalue() returns the full content regardless of the
                # stream's current read position.
                tfile.write(uploaded_file.getvalue())
            st.session_state['uploaded_video'] = {'key': upload_key, 'path': tfile.name}
        video_path = st.session_state['uploaded_video']['path']

        col1, col2 = st.columns([1, 1])

        with col1:
            st.markdown("### Original Video")
            st.video(video_path)

        # NOTE(review): indentation was lost in the captured source; the
        # analysis output is assumed to render inside the right-hand column.
        with col2:
            st.markdown("### Analysis Control")
            process_mode = st.radio("Processing Mode", ["Preview (First 10s)", "Full Match (Slower)"])
            start_btn = st.button("Start Analysis")

            if start_btn:
                model = load_model()
                # Approx 10s at 30fps; the cap is a frame count, so the actual
                # preview duration depends on the video's frame rate.
                max_f = 300 if "Preview" in process_mode else None

                with st.spinner('Analyzing court dynamics...'):
                    processed_video, ball_df, player_df = process_video(
                        video_path, model, conf_threshold, max_f
                    )

                # Store data in session state
                st.session_state.processed_data = {
                    'video': processed_video,
                    'ball': ball_df,
                    'player': player_df,
                }

                st.success("Analysis Complete!")
                st.markdown("### Analyzed Output")
                st.video(processed_video)
def _stat_box(title, value):
    """Return the HTML for one metric card styled by the `stat-box` / `highlight` CSS classes."""
    return f"<div class='stat-box'><h3>{title}</h3><h2 class='highlight'>{value}</h2></div>"


# Match Stats tab: headline metrics plus ball-trajectory and player-density charts.
with tab2:
    if st.session_state.processed_data:
        data = st.session_state.processed_data
        ball_df = data['ball']
        player_df = data['player']

        st.header("Match Statistics")

        # Metrics Row
        m1, m2, m3 = st.columns(3)
        with m1:
            # Count distinct frames, not rows: a frame can hold several ball
            # detections and each one is a separate row.
            frames_tracked = ball_df['frame'].nunique() if not ball_df.empty else 0
            st.markdown(_stat_box("Frames Tracked", frames_tracked), unsafe_allow_html=True)
        with m2:
            if not ball_df.empty:
                avg_height = int(ball_df['y'].mean())
                st.markdown(_stat_box("Avg Ball Height", f"{avg_height} px"), unsafe_allow_html=True)
            else:
                st.markdown(_stat_box("Ball Detect", "N/A"), unsafe_allow_html=True)
        with m3:
            # Average number of person detections per frame, taken over the
            # frames that contain at least one person. This no longer depends
            # on how many ball detections there were.
            if not player_df.empty:
                players_visible = round(player_df.groupby('frame').size().mean())
            else:
                players_visible = 0
            st.markdown(_stat_box("Players Visible", players_visible), unsafe_allow_html=True)

        st.markdown("---")

        # Charts
        c1, c2 = st.columns(2)

        with c1:
            st.subheader("Ball Vertical Trajectory")
            if not ball_df.empty:
                fig_ball = px.line(ball_df, x='frame', y='y', title='Ball Height over Time (Inverted Y)')
                fig_ball.update_yaxes(autorange="reversed")  # Y is 0 at top in images
                st.plotly_chart(fig_ball, use_container_width=True)
            else:
                st.info("No ball detected in the processed frames.")

        with c2:
            st.subheader("Player Position Heatmap")
            if not player_df.empty:
                fig_heat = px.density_heatmap(
                    player_df, x='x', y='y', nbinsx=20, nbinsy=20, title='Player Movement Density'
                )
                fig_heat.update_yaxes(autorange="reversed")  # match image orientation
                st.plotly_chart(fig_heat, use_container_width=True)
            else:
                st.info("No players detected.")
    else:
        st.info("Please process a video in the 'Video Analysis' tab first.")
# AI Coach tab: summarise the stored analysis and forward the user's question to Grok.
with tab3:
    st.header("🤖 AI Coach (Powered by Grok)")
    st.write("Ask strategic questions about the match data.")

    if st.session_state.processed_data:
        data = st.session_state.processed_data

        # Create a summary string for the LLM
        if not data['ball'].empty:
            # Distinct frames, since one frame may contribute several detections.
            tracked_frames = data['ball']['frame'].nunique()
            summary = (f"Analysis Summary: Tracked {tracked_frames} frames. "
                       f"Average ball vertical position: {data['ball']['y'].mean():.2f}. "
                       f"Player movement density loaded in heatmap.")
        else:
            summary = "No specific ball tracking data available."
            # Still tell the coach when players were seen even though the ball was not.
            if not data['player'].empty:
                summary += f" Players were detected in {data['player']['frame'].nunique()} frames."

        user_input = st.text_area(
            "Ask the Coach:",
            "Based on the ball trajectory, was the game aggressive or defensive?",
        )

        if st.button("Get Advice"):
            if not user_input.strip():
                # Avoid spending an API call on an empty question.
                st.warning("Please enter a question for the Coach.")
            else:
                with st.spinner("Contacting Coach Grok..."):
                    advice = get_grok_advice(summary, user_input, grok_api_key)
                st.markdown("### 💡 Coach's Feedback")
                st.markdown(advice)
    else:
        st.warning("Please analyze a video first to give the Coach context.")