StrokeSense / app.py
marcosavoia's picture
Update app.py
81b4528 verified
Raw
History Blame Contribute Delete
6.68 kB
import subprocess
import sys
# Install the local package at /app into the interpreter running this app,
# before the `ai_tennis_coach` import below executes (presumably /app is the
# project that provides it — confirm against the Space's Dockerfile).
# check=True makes startup fail with CalledProcessError if pip fails, instead
# of continuing and crashing later on the import.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--no-cache-dir", "/app"],
    check=True
)
import asyncio
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from openai import AsyncOpenAI
from ai_tennis_coach import (
classify_shots,
extract_all_shots,
create_client,
analyze_all_shots,
calculate_mean_by_shot_type,
config,
)
# Select matplotlib's non-interactive Agg backend; figures are only rendered
# to images for the web UI, never shown in a window.
matplotlib.use("Agg")

# Weights file handed to classify_shots() by run_analysis().
MODEL_PATH = "models/tennis_rnn.h5"

# Metric attribute/key name -> short label shown on the radar chart.
# Insertion order defines the order of the chart's spokes.
METRIC_LABELS = dict(
    preparation="Preparation",
    contact_point="Contact",
    swing_followthrough="Follow-through",
    balance_stance="Balance",
)
def make_stats_plot(mean_scores: dict) -> plt.Figure:
    """Draw a radar chart of mean stroke-mechanics scores, one polygon per shot type.

    Args:
        mean_scores: Mapping of shot type -> {metric key: mean score}. Metric
            keys are looked up in the order of ``METRIC_LABELS``; a metric
            missing from a shot type's dict is plotted as 0. The radial axis
            is fixed to 0-10.

    Returns:
        A standalone matplotlib ``Figure``. It is not registered with pyplot,
        so it is freed as soon as the caller drops its reference.
    """
    labels = list(METRIC_LABELS.values())
    # One evenly spaced spoke per metric.
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    angles += angles[:1]  # repeat the first angle to close the polygon

    # Build the Figure directly rather than via plt.subplots(): pyplot keeps
    # every figure it creates alive until plt.close(), so a long-running server
    # would otherwise accumulate one figure per analysis.
    fig = plt.Figure(figsize=(6, 6))
    ax = fig.add_subplot(polar=True)

    colors = ["#4C72B0", "#DD8452", "#55A868"]
    for i, (shot_type, metrics_dict) in enumerate(mean_scores.items()):
        values = [metrics_dict.get(m, 0) for m in METRIC_LABELS]
        values += values[:1]  # close the polygon, matching `angles`
        color = colors[i % len(colors)]  # cycle when there are >3 shot types
        ax.plot(angles, values, "o-", linewidth=2, color=color, label=shot_type.capitalize())
        ax.fill(angles, values, alpha=0.2, color=color)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, size=10)
    ax.set_ylim(0, 10)
    ax.set_yticks([2, 4, 6, 8, 10])
    ax.set_yticklabels(["2", "4", "6", "8", "10"], size=8)
    ax.set_title("Stroke Mechanics", pad=20)
    # Without any plotted series, legend() only emits a "no artists" warning.
    if mean_scores:
        ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.15))
    fig.tight_layout()
    return fig
def format_summary(results: list, analyses: list, mean_scores: dict) -> str:
    """Render the analysis as plain text for use as coach-chat context.

    Args:
        results: Detected shots; only their count is reported.
        analyses: Groups of per-shot analysis objects. Each shot exposes
            ``shot_type``, ``coach_notes`` and, for each of ``preparation``,
            ``contact_point``, ``swing_followthrough`` and ``balance_stance``,
            an object with ``score`` and ``notes``.
        mean_scores: Mapping of shot type -> {metric key: mean score}.

    Returns:
        A newline-joined report: the detection count, one paragraph per
        analysed shot, then the mean score of each metric per shot type.
    """
    header = f"Detected {len(results)} shot(s).\n"

    def _describe(shot) -> str:
        # One paragraph per shot; the trailing newline yields a blank line
        # between paragraphs once everything is joined.
        return (
            f"{shot.shot_type.capitalize()} stroke\n"
            f" Preparation: {shot.preparation.score:.1f}/10 — {shot.preparation.notes}\n"
            f" Contact point: {shot.contact_point.score:.1f}/10 — {shot.contact_point.notes}\n"
            f" Follow-through: {shot.swing_followthrough.score:.1f}/10 — {shot.swing_followthrough.notes}\n"
            f" Balance: {shot.balance_stance.score:.1f}/10 — {shot.balance_stance.notes}\n"
            f" Coach notes: {shot.coach_notes}\n"
        )

    shot_paragraphs = [_describe(shot) for group in analyses for shot in group]

    mean_section = ["--- Mean scores by shot type ---"]
    for kind, scores in mean_scores.items():
        mean_section.append(f"{kind.capitalize()}:")
        mean_section.extend(
            f" {name.replace('_', ' ').capitalize()}: {score:.1f}/10"
            for name, score in scores.items()
        )

    return "\n".join([header, *shot_paragraphs, *mean_section])
async def run_analysis(video_path: str, left_handed: bool = False):
    """Detect, clip and score the strokes in an uploaded video.

    Args:
        video_path: Path to the uploaded video, or None if nothing was uploaded.
        left_handed: Forwarded unchanged to ``classify_shots``.

    Returns:
        ``(fig, summary)``: the radar chart from ``make_stats_plot`` and the
        text report from ``format_summary``.

    Raises:
        gr.Error: if no video was given or no shots were detected.
    """
    if video_path is None:
        raise gr.Error("Please upload a video first.")

    loop = asyncio.get_running_loop()
    # classify_shots and extract_all_shots are synchronous (neither is awaited)
    # and presumably do heavy model/video work; run them in the default thread
    # pool so the event loop keeps serving other requests such as chat streams.
    results = await loop.run_in_executor(
        None, lambda: classify_shots(video_path, MODEL_PATH, left_handed=left_handed)
    )
    if not results:
        raise gr.Error("No shots detected in the video.")

    await loop.run_in_executor(
        None, lambda: extract_all_shots(video_path, results, output_dir=config.clips_dir)
    )

    client = create_client()
    # Only the first 5 detected shots are analysed; the summary still reports
    # the total number of detected shots.
    analyses = await analyze_all_shots(results[:5], client)
    mean_scores = calculate_mean_by_shot_type(analyses)
    fig = make_stats_plot(mean_scores)
    summary = format_summary(results, analyses, mean_scores)
    return fig, summary
async def chat_fn(message: str, history: list, analysis_summary: str):
    """Stream a coaching reply grounded in the latest stroke analysis.

    Args:
        message: The user's new chat message.
        history: Prior chat turns as ``{"role": ..., "content": ...}`` dicts.
        analysis_summary: Report text from the last analysis; empty before
            any analysis has run.

    Yields:
        The full chat history (prior turns, the new user turn, and the
        assistant reply accumulated so far), re-emitted as each streamed
        chunk arrives so the Chatbot can re-render it.
    """
    # Nothing to ask: leave the conversation unchanged rather than sending an
    # empty user turn to the model.
    if not (message or "").strip():
        yield history
        return

    context = analysis_summary or "No analysis has been run yet."
    system = (
        "You are an expert tennis coach providing personalised advice, training recommendations, "
        "and improvement suggestions. Use the player's stroke analysis below to guide your responses "
        "— highlight strengths, flag weaknesses, and offer broad coaching guidance.\n\n"
        f"Player stroke analysis:\n{context}"
    )

    # NOTE(review): assumes the Chatbot supplies history in the "messages"
    # format (role/content dicts). Only those two keys are forwarded, so any
    # extra Gradio fields (e.g. metadata) are dropped.
    api_messages = [{"role": "system", "content": system}]
    api_messages += [{"role": msg["role"], "content": msg["content"]} for msg in history]
    api_messages.append({"role": "user", "content": message})

    # Yield the full updated history so gr.Chatbot renders correctly.
    updated_history = list(history) + [{"role": "user", "content": message}]
    reply = ""
    # Use the client as an async context manager so its HTTP connection pool
    # is closed after each reply instead of leaking one client per message.
    async with AsyncOpenAI(base_url=config.api_base_url, api_key=config.hf_token) as chat_client:
        stream = await chat_client.chat.completions.create(
            model=config.model_name,
            messages=api_messages,
            stream=True,
            max_tokens=1024,
            temperature=0.7,
        )
        async for chunk in stream:
            # Some OpenAI-compatible servers send chunks with an empty
            # `choices` list (e.g. a trailing usage chunk); skip them.
            if not chunk.choices:
                continue
            reply += chunk.choices[0].delta.content or ""
            yield updated_history + [{"role": "assistant", "content": reply}]
# Gradio UI. The left column holds video upload, analysis trigger and radar
# chart; the right column holds the coach chat. On-screen layout follows the
# order in which components are created inside these context managers.
with gr.Blocks(title="AI Tennis Coach") as demo:
    gr.Markdown("# AI Tennis Coach — Stroke Analysis")
    # Receives the text summary returned by run_analysis. It is not displayed;
    # it is only passed to chat_fn as context for the coach.
    analysis_state = gr.State("")
    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.Video(
                label="Upload Tennis Video",
                sources=["upload"],
                format="mp4",
                height=320
            )
            # NOTE(review): labelled "Front facing" but wired to run_analysis's
            # `left_handed` argument — confirm this mapping is intentional
            # (e.g. a camera facing the player mirrors a right-hander).
            left_handed = gr.Checkbox(label="Front facing", value=False)
            analyse_btn = gr.Button("Analyse", variant="primary")
            stats_plot = gr.Plot(label="Stroke Mechanics Summary")
            gr.Examples(
                examples=[["data/sample_tennis_video.mp4"]],
                inputs=video_input
            )
        with gr.Column(scale=1):
            # NOTE(review): chat_fn reads and yields history as role/content
            # dicts, but no `type=` is set here; this relies on the installed
            # Gradio version using the "messages" format — confirm.
            chatbot = gr.Chatbot(
                label="Coach Chat",
                height=500,
                placeholder="Run the analysis first, then ask your coach anything.",
            )
            chat_input = gr.Textbox(
                placeholder="Ask your coach...",
                label="Your message",
                show_label=False,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                gr.ClearButton([chat_input, chatbot], value="Clear chat")
    # run_analysis returns (figure, summary): the figure goes to the plot and
    # the summary into the hidden state used by the chat.
    analyse_btn.click(
        fn=run_analysis,
        inputs=[video_input, left_handed],
        outputs=[stats_plot, analysis_state],
    )
    # Clicking Send and pressing Enter in the textbox run the same streaming
    # chat handler.
    send_btn.click(
        fn=chat_fn,
        inputs=[chat_input, chatbot, analysis_state],
        outputs=chatbot,
    )
    chat_input.submit(
        fn=chat_fn,
        inputs=[chat_input, chatbot, analysis_state],
        outputs=chatbot,
    )
# Start the server only when run as a script.
if __name__ == "__main__":
    demo.launch()