import subprocess
import sys

# Install the local ai_tennis_coach package from /app before anything below
# imports it; check=True aborts start-up if the install fails.
# NOTE(review): presumably a deployment bootstrap step — it runs on every
# import of this module, so keep it ahead of the third-party imports.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--no-cache-dir", "/app"], check=True
)

import asyncio
import functools

import gradio as gr
import matplotlib

# Select the headless Agg backend *before* pyplot is imported.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
from openai import AsyncOpenAI

from ai_tennis_coach import (
    classify_shots,
    extract_all_shots,
    create_client,
    analyze_all_shots,
    calculate_mean_by_shot_type,
    config,
)

MODEL_PATH = "models/tennis_rnn.h5"

# Only the first N detected shots are passed to analyze_all_shots, bounding
# the amount of analysis work done per uploaded video.
MAX_SHOTS_TO_ANALYSE = 5

# Display labels for the per-shot metrics; dict order fixes the radar axis order.
METRIC_LABELS = {
    "preparation": "Preparation",
    "contact_point": "Contact",
    "swing_followthrough": "Follow-through",
    "balance_stance": "Balance",
}

# Line/fill colours, cycled per shot type on the radar chart.
_SHOT_COLORS = ("#4C72B0", "#DD8452", "#55A868")


def make_stats_plot(mean_scores: dict) -> plt.Figure:
    """Build a radar chart of mean stroke-mechanics scores per shot type.

    Args:
        mean_scores: Mapping of shot type -> {metric key: mean score}. Metric
            keys are read in METRIC_LABELS order; a missing metric plots as 0.

    Returns:
        A standalone matplotlib Figure (0-10 radial scale) for ``gr.Plot``.
    """
    labels = list(METRIC_LABELS.values())
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    angles += angles[:1]  # repeat the first angle to close the polygon

    # Instantiate the Figure class directly rather than via plt.subplots():
    # pyplot-managed figures stay referenced by pyplot's global figure manager
    # until plt.close(), so every analysis run would leak one figure (and
    # matplotlib warns once more than 20 are open).
    fig = plt.Figure(figsize=(6, 6))
    ax = fig.add_subplot(polar=True)

    for i, (shot_type, metrics_dict) in enumerate(mean_scores.items()):
        values = [metrics_dict.get(m, 0) for m in METRIC_LABELS]
        values += values[:1]
        color = _SHOT_COLORS[i % len(_SHOT_COLORS)]
        ax.plot(angles, values, "o-", linewidth=2, color=color, label=shot_type.capitalize())
        ax.fill(angles, values, alpha=0.2, color=color)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, size=10)
    ax.set_ylim(0, 10)
    ax.set_yticks([2, 4, 6, 8, 10])
    ax.set_yticklabels(["2", "4", "6", "8", "10"], size=8)
    ax.set_title("Stroke Mechanics", pad=20)
    # With no series plotted, legend() only emits a "no artists" warning.
    if mean_scores:
        ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.15))
    fig.tight_layout()
    return fig


def format_summary(results: list, analyses: list, mean_scores: dict) -> str:
    """Render the analysis as plain text for the coach-chat context.

    Args:
        results: Detected shots from classify_shots (only the count is used).
        analyses: Iterable of lists of per-shot analyses, as returned by
            analyze_all_shots. Each item exposes ``shot_type``, ``coach_notes``
            and ``preparation`` / ``contact_point`` / ``swing_followthrough`` /
            ``balance_stance`` objects with ``score`` and ``notes``.
        mean_scores: Mapping of shot type -> {metric key: mean score}.

    Returns:
        A newline-joined, human-readable summary.
    """
    lines = [f"Detected {len(results)} shot(s).\n"]
    for shot_list in analyses:
        for s in shot_list:
            lines.append(
                f"{s.shot_type.capitalize()} stroke\n"
                f" Preparation: {s.preparation.score:.1f}/10 — {s.preparation.notes}\n"
                f" Contact point: {s.contact_point.score:.1f}/10 — {s.contact_point.notes}\n"
                f" Follow-through: {s.swing_followthrough.score:.1f}/10 — {s.swing_followthrough.notes}\n"
                f" Balance: {s.balance_stance.score:.1f}/10 — {s.balance_stance.notes}\n"
                f" Coach notes: {s.coach_notes}\n"
            )

    lines.append("--- Mean scores by shot type ---")
    for shot_type, metrics in mean_scores.items():
        lines.append(f"{shot_type.capitalize()}:")
        for metric, value in metrics.items():
            lines.append(f" {metric.replace('_', ' ').capitalize()}: {value:.1f}/10")
    return "\n".join(lines)


async def run_analysis(video_path: str, left_handed: bool = False):
    """Classify, clip and analyse the shots in an uploaded video.

    Args:
        video_path: Path of the uploaded video, or None when nothing was uploaded.
        left_handed: Forwarded unchanged to classify_shots.

    Returns:
        ``(figure, summary)``: the radar-chart Figure and the text summary
        (stored in session state as context for the coach chat).

    Raises:
        gr.Error: If no video was given or no shots were detected.
    """
    if not video_path:
        raise gr.Error("Please upload a video first.")

    # get_running_loop() is the supported way to reach the loop from inside a
    # coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    results = await loop.run_in_executor(
        None, functools.partial(classify_shots, video_path, MODEL_PATH, left_handed=left_handed)
    )
    if not results:
        raise gr.Error("No shots detected in the video.")

    # extract_all_shots is a synchronous video-I/O call; run it off the event
    # loop so other sessions (e.g. streaming chat replies) are not frozen.
    await loop.run_in_executor(
        None,
        functools.partial(extract_all_shots, video_path, results, output_dir=config.clips_dir),
    )

    client = create_client()
    analyses = await analyze_all_shots(results[:MAX_SHOTS_TO_ANALYSE], client)
    mean_scores = calculate_mean_by_shot_type(analyses)

    fig = make_stats_plot(mean_scores)
    summary = format_summary(results, analyses, mean_scores)
    return fig, summary


async def chat_fn(message: str, history: list, analysis_summary: str):
    """Stream a coach reply grounded in the latest stroke analysis.

    Args:
        message: The user's new chat message.
        history: Prior chat turns as ``{"role": ..., "content": ...}`` dicts
            (the same shape this generator yields back to the Chatbot).
        analysis_summary: Text from run_analysis, or "" if none has run yet.

    Yields:
        The full updated chat history, growing as reply tokens stream in.
    """
    history = list(history or [])
    if not (message or "").strip():
        # Nothing to send; leave the conversation unchanged.
        yield history
        return

    context = analysis_summary or "No analysis has been run yet."
    system = (
        "You are an expert tennis coach providing personalised advice, training recommendations, "
        "and improvement suggestions. Use the player's stroke analysis below to guide your responses "
        "— highlight strengths, flag weaknesses, and offer broad coaching guidance.\n\n"
        f"Player stroke analysis:\n{context}"
    )

    # History is expected in Gradio's "messages" format (role/content dicts);
    # only those two keys are forwarded because Gradio may attach extras.
    # NOTE(review): on Gradio 4/5 the Chatbot needs type="messages" to deliver
    # this format — confirm against the installed Gradio version.
    api_messages = [{"role": "system", "content": system}]
    api_messages += [{"role": msg["role"], "content": msg["content"]} for msg in history]
    api_messages.append({"role": "user", "content": message})

    # Show the user's message immediately, before the first token arrives.
    updated_history = history + [{"role": "user", "content": message}]
    yield updated_history

    # The context manager closes the client's HTTP connection pool after the
    # reply; previously one unclosed client was created per message.
    async with AsyncOpenAI(base_url=config.api_base_url, api_key=config.hf_token) as chat_client:
        stream = await chat_client.chat.completions.create(
            model=config.model_name,
            messages=api_messages,
            stream=True,
            max_tokens=1024,
            temperature=0.7,
        )

        reply = ""
        async for chunk in stream:
            # Some OpenAI-compatible backends emit chunks with no choices
            # (e.g. trailing usage chunks); indexing [0] would raise.
            if not chunk.choices:
                continue
            reply += chunk.choices[0].delta.content or ""
            yield updated_history + [{"role": "assistant", "content": reply}]


with gr.Blocks(title="AI Tennis Coach") as demo:
    gr.Markdown("# AI Tennis Coach — Stroke Analysis")
    # Latest summary text from run_analysis; passed to chat_fn as coach context.
    analysis_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.Video(
                label="Upload Tennis Video", sources=["upload"], format="mp4", height=320
            )
            # NOTE(review): the label reads "Front facing" but the value is
            # forwarded to classify_shots(left_handed=...) — confirm intended.
            left_handed = gr.Checkbox(label="Front facing", value=False)
            analyse_btn = gr.Button("Analyse", variant="primary")
            stats_plot = gr.Plot(label="Stroke Mechanics Summary")
            gr.Examples(
                examples=[["data/sample_tennis_video.mp4"]], inputs=video_input
            )

        with gr.Column(scale=1):
            chatbot = gr.Chatbot(
                label="Coach Chat",
                height=500,
                placeholder="Run the analysis first, then ask your coach anything.",
            )
            chat_input = gr.Textbox(
                placeholder="Ask your coach...",
                label="Your message",
                show_label=False,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                gr.ClearButton([chat_input, chatbot], value="Clear chat")

    analyse_btn.click(
        fn=run_analysis,
        inputs=[video_input, left_handed],
        outputs=[stats_plot, analysis_state],
    )

    # Clicking Send and pressing Enter behave identically; once the reply has
    # streamed, the textbox is cleared so the sent question does not linger.
    for trigger in (send_btn.click, chat_input.submit):
        trigger(
            fn=chat_fn,
            inputs=[chat_input, chatbot, analysis_state],
            outputs=chatbot,
        ).then(fn=lambda: "", inputs=None, outputs=chat_input)


if __name__ == "__main__":
    demo.launch()