|
|
""" |
|
|
AI Fitness Coach - Hugging Face Spaces Demo |
|
|
Fine-tuned persona-based feedback system |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
import json |
|
|
import tempfile |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
import sys |
|
|
PROJECT_ROOT = Path(__file__).parent |
|
|
sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
|
|
|
from fitness_coach.video_processing import process_video |
|
|
from fitness_coach.scoring import calculate_overall_score |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository id used as the prefix for every persona model.
MODEL_REPO: str = "rlogh/fitness-coach-personas"

# Persona label shown in the UI -> directory name of that persona's model.
# Insertion order matters: the first key is the default radio selection.
PERSONAS: dict = {
    "Hype Beast π₯": "persona_hype_beast",
    "Data Scientist π": "persona_data_scientist",
    "No-Nonsense Pro πͺ": "persona_no-nonsense_pro",
    "Mindful Aligner π§": "persona_mindful_aligner",
}

# Module-level caches filled by load_models(); both are keyed by the persona
# label (the keys of PERSONAS), not by the directory name.
models: dict = {}
tokenizers: dict = {}
|
|
|
|
|
def load_models():
    """Load every persona's tokenizer and model into the module-level caches.

    Iterates over ``PERSONAS`` (label -> directory name) and stores the loaded
    objects in ``tokenizers`` and ``models`` under the persona label.

    A persona is only registered when BOTH its tokenizer and its model loaded,
    so the two caches always hold the same set of keys. A failure for one
    persona is reported on stdout (with traceback) and does not stop the
    remaining personas from loading.

    Returns:
        None. The results are the side effects on ``models`` / ``tokenizers``.
    """
    import os
    import traceback

    use_cuda = torch.cuda.is_available()

    for persona_name, model_dir in PERSONAS.items():
        print(f"Loading {persona_name} from Hugging Face Hub...")
        model_path = f"{MODEL_REPO}/{model_dir}"

        if os.path.isdir(model_path):
            # A local checkout with the same layout keeps working unchanged.
            source, location = model_path, {}
        else:
            # "<namespace>/<repo>/<dir>" is not a valid Hub repo id; a
            # directory inside a repo has to be addressed with `subfolder`.
            source, location = MODEL_REPO, {"subfolder": model_dir}

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                source,
                use_fast=False,
                **location,
            )
            model = AutoModelForCausalLM.from_pretrained(
                source,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                # NOTE(review): this executes Python code shipped in the model
                # repo. Kept for compatibility; confirm it is really required.
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                **location,
            )
        except Exception as e:
            # Best effort: one broken persona must not take the app down.
            print(f"β Failed to load {persona_name}: {e}")
            traceback.print_exc()
            continue

        tokenizers[persona_name] = tokenizer
        models[persona_name] = model
        print(f"β Loaded {persona_name}")
|
|
|
|
|
|
|
|
# Import-time side effect: the persona models are loaded as soon as this
# module is imported, before any UI component is created.
print("Loading fine-tuned models...")
load_models()
# Reports how many entries load_models() stored in the `models` cache.
print(f"Loaded {len(models)} persona models")
|
|
|
|
|
|
|
|
def generate_feedback(persona_name, input_report):
    """Generate coaching feedback for a report with one persona's model.

    Args:
        persona_name: Persona label; must be a key of the ``models`` cache.
        input_report: Plain-text analysis report to respond to.

    Returns:
        The generated feedback text with surrounding whitespace removed, or
        the message ``"Model for <persona_name> not loaded"`` when no model is
        cached for that persona.
    """
    if persona_name not in models:
        return f"Model for {persona_name} not loaded"

    model = models[persona_name]
    tokenizer = tokenizers[persona_name]

    prompt = f"<|persona|>{persona_name}<|input|>{input_report}<|output|>"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)

    # Run on whatever device the weights already live on. Calling model.cpu()
    # here would move the whole model off the GPU on every request.
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            temperature=0.9,
            top_p=0.95,
            top_k=50,
            do_sample=True,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # A causal LM returns prompt + continuation. Decode only the continuation
    # so the prompt can never leak into the feedback, even if the "<|output|>"
    # marker is dropped by skip_special_tokens.
    generated_ids = outputs[0][prompt_length:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # If the model emitted the marker again, keep what follows the last one.
    # split() returns the whole text when the marker is absent.
    return generated_text.split("<|output|>")[-1].strip()
|
|
|
|
|
|
|
|
def _to_jsonable(value):
    """Convert NumPy scalars/arrays to plain Python objects for ``json.dumps``."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def analyze_video(video_file, persona_choice, reference_video_choice):
    """Analyze an uploaded exercise video and produce report, feedback and JSON.

    Args:
        video_file: Path of the uploaded video, or ``None`` when nothing was
            uploaded.
        persona_choice: Persona label forwarded to ``generate_feedback``.
        reference_video_choice: ``"Squat"``, ``"Pushup"`` or ``"Pullup"``;
            any other value falls back to the push-up reference.

    Returns:
        A 3-tuple ``(report, feedback, analysis_json)`` of strings. On any
        failure the first element carries the message (including a traceback
        for unexpected errors) and the other two are empty strings.
    """
    if video_file is None:
        return "Please upload a video", "", ""

    try:
        video_path = video_file

        # Imported lazily so a missing package surfaces as an error message
        # in the UI instead of breaking the module import.
        from huggingface_hub import hf_hub_download

        exercise_map = {
            "Squat": "squat",
            "Pushup": "pushup",
            "Pullup": "pullup",
        }
        exercise_id = exercise_map.get(reference_video_choice, "pushup")

        # The reference is optional: scoring is still attempted without it.
        try:
            ref_3d_path = hf_hub_download(
                repo_id="rlogh/fitness-coach-references",
                filename=f"{exercise_id}/keypoints_3D.npz",
                repo_type="dataset",
            )
            # Close the archive right after reading the array; NpzFile keeps
            # a file handle open otherwise.
            with np.load(ref_3d_path) as archive:
                ref_results = {'keypoints_3d': archive['reconstruction']}
        except Exception as e:
            print(f"Warning: Could not load reference: {e}")
            ref_results = None

        print(f"Processing {video_path}...")
        user_results = process_video(video_path)

        if not user_results or 'keypoints_3d' not in user_results:
            return "Failed to process video - no pose detected", "", ""

        scores = calculate_overall_score(user_results, ref_results)

        # The report lines are written flush-left so the text carries no
        # indentation when shown in the UI or placed into the model prompt.
        report = f"""
Exercise Analysis Report

Overall Score: {scores.get('overall_score', 0):.1f}/100

Body Part Breakdown:
- Head/Neck: {scores.get('head_score', 0):.1f}/100
- Shoulders: {scores.get('shoulder_score', 0):.1f}/100
- Arms: {scores.get('arm_score', 0):.1f}/100
- Torso: {scores.get('torso_score', 0):.1f}/100
- Legs: {scores.get('leg_score', 0):.1f}/100

Key Issues:
{scores.get('issues', 'Form looks good!')}
"""

        feedback = generate_feedback(persona_choice, report)

        # default= only runs for objects json cannot encode itself, so plain
        # Python scores serialize exactly as before; NumPy values no longer
        # raise TypeError.
        analysis_json = json.dumps(
            {
                "scores": scores,
                "persona": persona_choice,
                "frames_analyzed": len(user_results.get('keypoints_3d', [])),
                "reference_used": reference_video_choice,
            },
            indent=2,
            default=_to_jsonable,
        )

        return report, feedback, analysis_json

    except Exception as e:
        # Top-level UI boundary: report the failure in the first output box
        # rather than letting the request crash.
        import traceback

        error_msg = f"Error analyzing video: {str(e)}\n{traceback.format_exc()}"
        return error_msg, "", ""
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI definition. Components are attached to the layout in the order
# (and nesting) in which they are created inside these context managers, so
# statement order here is the page layout.
# NOTE(review): the emoji in the labels/headings below look mis-encoded;
# confirm the file encoding before editing these strings. The persona labels
# must stay identical to the keys of PERSONAS.
with gr.Blocks(title="AI Fitness Coach", theme=gr.themes.Soft()) as demo:
    # Page header and feature list.
    gr.Markdown("""
    # ποΈ AI Fitness Coach with Fine-Tuned Personas

    Upload your exercise video and get personalized feedback from our AI coach!

    **Features:**
    - 3D pose estimation and analysis
    - 4 fine-tuned persona models (trained on GPT-4o synthetic data)
    - Real-time scoring and feedback
    """)

    with gr.Row():
        # First column: the three inputs of analyze_video plus the trigger button.
        with gr.Column():
            video_input = gr.Video(label="Upload Your Exercise Video")

            # Choices are the PERSONAS keys; the first key is preselected.
            persona_select = gr.Radio(
                choices=list(PERSONAS.keys()),
                value=list(PERSONAS.keys())[0],
                label="Choose Your Coach Persona",
                info="Each persona has a unique coaching style (fine-tuned model)"
            )

            # These values are the keys analyze_video looks up in its exercise map.
            reference_select = gr.Radio(
                choices=["Squat", "Pushup", "Pullup"],
                value="Squat",
                label="Exercise Type",
                info="Select reference exercise for comparison"
            )

            analyze_btn = gr.Button("Analyze My Form π―", variant="primary")

        # Second column: one output component per element of analyze_video's
        # returned 3-tuple.
        with gr.Column():
            report_output = gr.Textbox(
                label="π Analysis Report",
                lines=12,
                placeholder="Your detailed analysis will appear here..."
            )

            feedback_output = gr.Textbox(
                label="π¬ Persona Feedback (Fine-Tuned Model)",
                lines=10,
                placeholder="Personalized coaching feedback will appear here..."
            )

            json_output = gr.JSON(label="π Detailed Results (JSON)")

    # Example rows: [video file, persona label, exercise type], in the same
    # order as `inputs`.
    # NOTE(review): presumably the sample .mp4 files ship next to this script — confirm.
    gr.Markdown("### πΉ Example Videos")
    gr.Examples(
        examples=[
            ["sample_squat.mp4", "Hype Beast π₯", "Squat"],
            ["sample_pushup.mp4", "Data Scientist π", "Pushup"],
        ],
        inputs=[video_input, persona_select, reference_select],
        label="Try these examples"
    )

    # Wire the button: the three inputs map to analyze_video's parameters and
    # its three return values map to the three output components, in order.
    analyze_btn.click(
        fn=analyze_video,
        inputs=[video_input, persona_select, reference_select],
        outputs=[report_output, feedback_output, json_output]
    )

    # Static footer text describing the models and the stack.
    gr.Markdown("""
    ---
    ### π¬ About the Models

    This app uses **4 fine-tuned DistilGPT-2 models**, each trained on persona-specific synthetic data:
    - **Training Data**: 320 examples (80 per persona) generated with GPT-4o
    - **Base Model**: DistilGPT-2
    - **Training**: 1000 steps per persona with FP16 mixed precision
    - **Personas**: Hype Beast, Data Scientist, No-Nonsense Pro, Mindful Aligner

    ### π Technology Stack
    - **3D Pose Estimation**: PoseFormer (transformer-based)
    - **Video Processing**: MediaPipe + OpenCV
    - **Fine-Tuned Models**: Hugging Face Transformers
    - **Framework**: Gradio + PyTorch
    """)
|
|
|
|
|
|
|
|
# Launch the Gradio app only when this file is run as a script; importing the
# module builds `demo` but does not call launch().
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|