Spaces:
Sleeping
Sleeping
| import subprocess | |
| import sys | |
# Install the local package at /app into this interpreter's environment before
# the project imports further down run (presumably it provides ai_tennis_coach
# — confirm). check_call raises CalledProcessError if pip exits non-zero, which
# aborts start-up instead of failing later on a missing import.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "--no-cache-dir", "/app"]
)
| import asyncio | |
| import gradio as gr | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from openai import AsyncOpenAI | |
| from ai_tennis_coach import ( | |
| classify_shots, | |
| extract_all_shots, | |
| create_client, | |
| analyze_all_shots, | |
| calculate_mean_by_shot_type, | |
| config, | |
| ) | |
# Select the non-interactive Agg backend: figures are only handed to the web UI
# (gr.Plot), never shown in a desktop window.
matplotlib.use("Agg")

# Model file handed to classify_shots; resolved relative to the working directory.
MODEL_PATH = "models/tennis_rnn.h5"

# Metric keys looked up in each shot type's mean-score dict, mapped to the short
# labels drawn on the radar chart. Insertion order fixes the axis order.
METRIC_LABELS = dict(
    preparation="Preparation",
    contact_point="Contact",
    swing_followthrough="Follow-through",
    balance_stance="Balance",
)
def make_stats_plot(mean_scores: dict) -> plt.Figure:
    """Draw a radar chart of mean stroke-mechanics scores per shot type.

    Args:
        mean_scores: Mapping of shot type (e.g. ``"forehand"``) to a dict of
            metric key -> mean score. The metric keys read are those of
            ``METRIC_LABELS``; a missing metric is plotted as 0.

    Returns:
        A matplotlib ``Figure`` with one filled polygon per shot type on a
        0-10 radial scale. Colours cycle through three fixed hues.
    """
    from matplotlib.figure import Figure

    labels = list(METRIC_LABELS.values())
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    angles += angles[:1]  # repeat the first angle to close the polygon

    # Build the Figure directly rather than through pyplot: pyplot registers
    # every figure it creates and keeps it alive until plt.close(), so a
    # long-running server would accumulate one figure per analysis.
    fig = Figure(figsize=(6, 6))
    ax = fig.add_subplot(111, polar=True)

    colors = ["#4C72B0", "#DD8452", "#55A868"]
    for i, (shot_type, metrics_dict) in enumerate(mean_scores.items()):
        values = [metrics_dict.get(m, 0) for m in METRIC_LABELS]
        values += values[:1]  # close the polygon, matching `angles`
        color = colors[i % len(colors)]
        ax.plot(angles, values, "o-", linewidth=2, color=color,
                label=shot_type.capitalize())
        ax.fill(angles, values, alpha=0.2, color=color)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, size=10)
    ax.set_ylim(0, 10)
    ax.set_yticks([2, 4, 6, 8, 10])
    ax.set_yticklabels(["2", "4", "6", "8", "10"], size=8)
    ax.set_title("Stroke Mechanics", pad=20)
    if mean_scores:
        # Only draw a legend when something was plotted; an empty legend
        # triggers matplotlib's "No artists with labels" warning.
        ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.15))
    fig.tight_layout()
    return fig
def format_summary(results: list, analyses: list, mean_scores: dict) -> str:
    """Render the analysis output as a plain-text report.

    The report has three parts:
    - a header with the number of detected shots;
    - one section per analysed stroke, giving four scored metrics with their
      notes plus the coach notes;
    - the mean score of each metric, grouped by shot type.

    Args:
        results: Detected shots; only ``len(results)`` is used here.
        analyses: List of lists of per-stroke analysis objects. Each object
            exposes ``shot_type``, ``coach_notes`` and the attributes
            ``preparation``, ``contact_point``, ``swing_followthrough`` and
            ``balance_stance``, each carrying ``score`` and ``notes``.
        mean_scores: Mapping of shot type -> {metric key: mean score}.

    Returns:
        The report as one newline-joined string.
    """

    def describe(stroke) -> str:
        # One multi-line section per stroke; every line ends with "\n".
        heading = f"{stroke.shot_type.capitalize()} stroke\n"
        metric_rows = (
            ("Preparation", stroke.preparation),
            ("Contact point", stroke.contact_point),
            ("Follow-through", stroke.swing_followthrough),
            ("Balance", stroke.balance_stance),
        )
        body = "".join(
            f" {label}: {metric.score:.1f}/10 — {metric.notes}\n"
            for label, metric in metric_rows
        )
        return heading + body + f" Coach notes: {stroke.coach_notes}\n"

    report = [f"Detected {len(results)} shot(s).\n"]
    report.extend(describe(stroke) for group in analyses for stroke in group)

    report.append("--- Mean scores by shot type ---")
    for shot_type, per_metric in mean_scores.items():
        report.append(f"{shot_type.capitalize()}:")
        report.extend(
            f" {key.replace('_', ' ').capitalize()}: {score:.1f}/10"
            for key, score in per_metric.items()
        )
    return "\n".join(report)
async def run_analysis(video_path: str, left_handed: bool = False):
    """Detect shots in an uploaded video and score the first five of them.

    Args:
        video_path: Path of the uploaded video, or ``None`` when nothing was
            uploaded.
        left_handed: Forwarded unchanged to ``classify_shots``.

    Returns:
        ``(figure, summary)``: the radar chart from ``make_stats_plot`` and the
        plain-text report from ``format_summary``.

    Raises:
        gr.Error: If no video was given or no shots were detected.
    """
    if video_path is None:
        raise gr.Error("Please upload a video first.")

    # Inside a coroutine, use the loop that is actually running;
    # get_event_loop() is deprecated for this purpose.
    loop = asyncio.get_running_loop()

    # classify_shots and extract_all_shots are synchronous. Run both in the
    # default thread pool so the shared Gradio event loop keeps serving other
    # requests while this video is processed.
    results = await loop.run_in_executor(
        None, lambda: classify_shots(video_path, MODEL_PATH, left_handed=left_handed)
    )
    if not results:
        raise gr.Error("No shots detected in the video.")
    await loop.run_in_executor(
        None,
        lambda: extract_all_shots(video_path, results, output_dir=config.clips_dir),
    )

    client = create_client()
    # Only the first five detected shots are analysed; the summary still
    # reports the total number detected.
    analyses = await analyze_all_shots(results[:5], client)
    mean_scores = calculate_mean_by_shot_type(analyses)
    fig = make_stats_plot(mean_scores)
    summary = format_summary(results, analyses, mean_scores)
    return fig, summary
async def chat_fn(message: str, history: list, analysis_summary: str):
    """Stream a coaching reply grounded in the latest stroke analysis.

    Args:
        message: The user's new chat message.
        history: Prior chat turns as ``{"role": ..., "content": ...}`` dicts.
            NOTE(review): this requires the Chatbot to use the "messages"
            format; on Gradio versions whose default is "tuples", the Chatbot
            needs ``type="messages"``. Confirm against the installed version.
        analysis_summary: Text report from the last analysis run; empty when
            no analysis has been run.

    Yields:
        The full chat history including the new user message and, once
        streaming starts, the assistant reply accumulated so far.
    """
    if not message or not message.strip():
        # Nothing to send: leave the chat as it is rather than asking the
        # model to answer an empty prompt.
        yield list(history)
        return

    context = analysis_summary or "No analysis has been run yet."
    system = (
        "You are an expert tennis coach providing personalised advice, training recommendations, "
        "and improvement suggestions. Use the player's stroke analysis below to guide your responses "
        "— highlight strengths, flag weaknesses, and offer broad coaching guidance.\n\n"
        f"Player stroke analysis:\n{context}"
    )

    api_messages = [{"role": "system", "content": system}]
    # Forward only the fields the chat-completions API needs.
    api_messages.extend(
        {"role": msg["role"], "content": msg["content"]} for msg in history
    )
    api_messages.append({"role": "user", "content": message})

    # Yield the full updated history so gr.Chatbot renders correctly. Show the
    # user's message right away, before the model has produced anything.
    updated_history = list(history) + [{"role": "user", "content": message}]
    yield updated_history

    reply = ""
    # The async context manager closes the client's HTTP resources as soon as
    # the reply is streamed, instead of relying on garbage collection.
    async with AsyncOpenAI(
        base_url=config.api_base_url, api_key=config.hf_token
    ) as chat_client:
        stream = await chat_client.chat.completions.create(
            model=config.model_name,
            messages=api_messages,
            stream=True,
            max_tokens=1024,
            temperature=0.7,
        )
        async for chunk in stream:
            # Some OpenAI-compatible servers emit chunks with an empty
            # `choices` list (e.g. a trailing usage chunk); skip them rather
            # than raising IndexError mid-reply.
            if not chunk.choices:
                continue
            reply += chunk.choices[0].delta.content or ""
            yield updated_history + [{"role": "assistant", "content": reply}]
with gr.Blocks(title="AI Tennis Coach") as demo:
    gr.Markdown("# AI Tennis Coach — Stroke Analysis")
    # Text report of the latest analysis; passed to chat_fn as coaching context.
    analysis_state = gr.State("")
    # The message currently being answered. Stashing it here lets the textbox
    # be cleared as soon as the message is sent.
    pending_message = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.Video(
                label="Upload Tennis Video",
                sources=["upload"],
                format="mp4",
                height=320,
            )
            # NOTE(review): the label reads "Front facing" but the value is
            # forwarded to classify_shots as `left_handed` — confirm intent.
            left_handed = gr.Checkbox(label="Front facing", value=False)
            analyse_btn = gr.Button("Analyse", variant="primary")
            stats_plot = gr.Plot(label="Stroke Mechanics Summary")
            gr.Examples(
                examples=[["data/sample_tennis_video.mp4"]],
                inputs=video_input,
            )

        with gr.Column(scale=1):
            # NOTE(review): chat_fn reads and yields role/content dicts, so this
            # Chatbot must be in "messages" format; on Gradio versions that
            # default to "tuples", pass type="messages".
            chatbot = gr.Chatbot(
                label="Coach Chat",
                height=500,
                placeholder="Run the analysis first, then ask your coach anything.",
            )
            chat_input = gr.Textbox(
                placeholder="Ask your coach...",
                label="Your message",
                show_label=False,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                gr.ClearButton([chat_input, chatbot], value="Clear chat")

    # The text summary goes into analysis_state, where chat_fn picks it up.
    analyse_btn.click(
        fn=run_analysis,
        inputs=[video_input, left_handed],
        outputs=[stats_plot, analysis_state],
    )

    # The Send button and the Enter key share one pipeline:
    # 1. move the typed text into pending_message and clear the textbox, so the
    #    same message cannot be re-sent by accident;
    # 2. stream the coach's reply into the chatbot.
    gr.on(
        triggers=[send_btn.click, chat_input.submit],
        fn=lambda text: ("", text),
        inputs=chat_input,
        outputs=[chat_input, pending_message],
    ).then(
        fn=chat_fn,
        inputs=[pending_message, chatbot, analysis_state],
        outputs=chatbot,
    )

if __name__ == "__main__":
    demo.launch()