Spaces:
Sleeping
Sleeping
| """ | |
Workout Coach — FastAPI Inference App
Runs pre- and post-workout DistilBERT classification + Claude plan/debrief generation
Designed for Hugging Face Spaces with Docker
| """ | |
| from fastapi import FastAPI, HTTPException, Header, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from contextlib import asynccontextmanager | |
| from typing import Optional, Dict, List | |
| import torch | |
| import anthropic | |
| import asyncio | |
| import json | |
| import os | |
| import logging | |
| from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model | |
| from inference import ( | |
| predict_post, decode_post_predictions, build_post_prompt, parse_debrief, | |
| predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan, | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # LOGGING | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# Module-level logger used by every handler in this file; the root logger is
# configured at INFO so the startup and error messages below are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # API KEY AUTHENTICATION | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# Shared secret expected in the request header; read once at import time.
API_KEY = os.getenv("APP_API_KEY")


async def verify_key(x_api_key: str = Header(...)):
    """FastAPI dependency that rejects requests without the expected API key.

    Args:
        x_api_key: Required header value supplied by the client.

    Raises:
        HTTPException: 401 when the supplied key does not match
            ``APP_API_KEY``, or when ``APP_API_KEY`` is unset/empty
            (fail closed instead of letting an empty header authenticate).
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    # An unset or empty server-side key must never authenticate anyone.
    # Without this guard an empty APP_API_KEY would match an empty header.
    if not API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Constant-time comparison so response timing does not reveal how much of
    # the key was correct. Compare bytes: compare_digest rejects non-ASCII str.
    supplied = x_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid API key")
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # POST-WORKOUT LABEL MAPS | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# Index -> label lookups, one per post-workout classifier output.
# NOTE(review): the index order presumably mirrors the label encoding used at
# training time — confirm before reordering or inserting entries.
POST_MOOD_MAP = dict(enumerate((
    "accomplished",
    "anxious",
    "distracted",
    "energized",
    "fatigued",
    "frustrated",
    "neutral",
    "positive",
)))
POST_EXERTION_MAP = dict(enumerate(("low", "moderate", "high")))
POST_COMPLETION_MAP = dict(enumerate(("partial", "full")))
POST_SORENESS_REGION_MAP = dict(enumerate((
    "none",
    "back",
    "biceps",
    "chest",
    "legs",
    "shoulder",
    "triceps",
)))
POST_SORENESS_SEVERITY_MAP = dict(enumerate(("none", "mild", "moderate", "severe")))
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # PRE-WORKOUT LABEL MAPS | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# Index -> label lookups, one per pre-workout classifier output.
# NOTE(review): the index order presumably mirrors the label encoding used at
# training time — confirm before reordering or inserting entries.
PRE_MOOD_MAP = dict(enumerate((
    "accomplished",
    "anxious",
    "distracted",
    "energized",
    "fatigued",
    "frustrated",
    "neutral",
    "positive",
)))
# Energy, motivation and stress share the same three-level scale but are kept
# as separate dict objects, exactly as before.
PRE_ENERGY_MAP = dict(enumerate(("low", "moderate", "high")))
PRE_MOTIVATION_MAP = dict(enumerate(("low", "moderate", "high")))
PRE_STRESS_MAP = dict(enumerate(("low", "moderate", "high")))
PRE_SORENESS_REGION_MAP = dict(enumerate((
    "none",
    "back",
    "biceps",
    "chest",
    "legs",
    "shoulder",
    "triceps",
)))
PRE_SORENESS_SEVERITY_MAP = dict(enumerate(("none", "mild", "moderate", "severe")))
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # APP STATE β model loaded once at startup | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# Process-wide registry of loaded models, tokenizers, device and API client.
# Populated by ``lifespan`` at startup and read by every request handler.
app_state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load both models once at startup, clean up at shutdown.

    Populates ``app_state`` with the post- and pre-workout model/tokenizer
    pairs, the torch device, and the Anthropic client, then yields control to
    the running application. On shutdown the registry is emptied.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. All shared state is exposed through ``app_state``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # -- Post-workout model ------------------------------------------------
    logger.info("Loading post-workout model...")
    post_model, post_tokenizer = load_post_model(
        model_path=os.getenv("MODEL_PATH", "post_best_overall_model.pt"),
        device=device,
    )
    app_state["post_model"] = post_model
    app_state["post_tokenizer"] = post_tokenizer

    # -- Pre-workout model -------------------------------------------------
    logger.info("Loading pre-workout model...")
    pre_model, pre_tokenizer = load_pre_model(
        model_path=os.getenv("PRE_MODEL_PATH", "pre_best_overall_model.pt"),
        device=device,
    )
    app_state["pre_model"] = pre_model
    app_state["pre_tokenizer"] = pre_tokenizer
    app_state["device"] = device

    # -- Anthropic client --------------------------------------------------
    anthropic_key = os.getenv("ANTHROPIC_API_KEY")
    if not anthropic_key:
        # Classification still works without the key; only the Claude
        # generation steps would fail, so warn instead of aborting startup.
        logger.warning("ANTHROPIC_API_KEY is not set; Claude generation will fail.")
    app_state["anthropic_client"] = anthropic.Anthropic(api_key=anthropic_key)

    logger.info("All models and clients loaded successfully.")
    try:
        yield
    finally:
        # Runs even if the server exits with an error, so model references
        # are always released.
        app_state.clear()
        logger.info("App shutdown complete.")
# ASGI application object; ``lifespan`` loads the models before the first
# request is served.
app = FastAPI(
    lifespan=lifespan,
    title="Workout Coach Inference API",
    description="Post-workout DistilBERT + Pre-workout DistilBERT + Claude generation",
    version="2.0.0",
)

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): consider restricting allow_origins once the set of client
# origins is known.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # REQUEST / RESPONSE SCHEMAS | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # POST-WORKOUT SCHEMAS | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
class PostWorkoutRequest(BaseModel):
    """Request body accepted by the post-workout handlers.

    Combines the user's free-text reflection (classified by the post-workout
    DistilBERT model) with structured form fields that are forwarded to the
    Claude prompt builder.
    """

    # Free-text input from the user — fed into PostWorkoutDistilBERT.
    # ``examples=[...]`` replaces the ``example=`` extra keyword, which
    # Pydantic v2 deprecates on ``Field``.
    user_text: str = Field(
        ...,
        min_length=5,
        max_length=500,
        examples=["That was really tough, chest is killing me but I feel accomplished."],
    )

    # UI form fields — collected separately in the app.
    duration_minutes: int = Field(..., ge=1, le=300, examples=[45])
    workout_type: str = Field(..., examples=["upper_body_push"])
    user_goal: str = Field(..., examples=["muscle_gain"])

    # Optional — whether to generate the Claude debrief.
    generate_debrief: bool = Field(default=True)
class PostWorkoutBertLabels(BaseModel):
    """Decoded (string) labels produced by the post-workout classifier."""

    mood: str
    exertion: str
    soreness_region: str
    soreness_severity: str
    completion: str
class PostWorkoutSessionResponse(BaseModel):
    """Response body for a classified post-workout session.

    ``bert_labels`` is always present; the debrief fields stay ``None`` when
    no debrief was requested or when generation failed.
    """

    bert_labels: PostWorkoutBertLabels

    # Parsed sections of the Claude debrief.
    acknowledgement: Optional[str] = None
    highlights: Optional[str] = None
    next_session: Optional[str] = None

    # Full unmodified response — fallback.
    raw_debrief: Optional[str] = None
class HealthResponse(BaseModel):
    """Shape of the payload returned by the health check."""

    # An empty ``protected_namespaces`` switches off Pydantic's reserved-prefix
    # check for field names.
    model_config = {"protected_namespaces": ()}

    status: str
    device: str
    post_model_loaded: bool
    pre_model_loaded: bool
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # ROUTES | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on ``app``.
def health():
    """Health check — confirms both models are loaded and ready.

    Returns:
        A dict with a fixed ``"ok"`` status, the device as a string
        (``"unknown"`` before startup has stored one), and one boolean per
        model indicating whether it is present in ``app_state``.
    """
    current_device = app_state.get("device", "unknown")
    return dict(
        status="ok",
        device=str(current_device),
        post_model_loaded="post_model" in app_state,
        pre_model_loaded="pre_model" in app_state,
    )
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on ``app``.
def post_classify_session(req: PostWorkoutRequest):
    """Classify a post-workout reflection and optionally generate a debrief.

    Runs PostWorkoutDistilBERT inference on ``req.user_text``, decodes the
    integer predictions to string labels, and, when ``req.generate_debrief``
    is true, asks Claude for a debrief built from those labels plus the
    session form data.

    Args:
        req: Validated post-workout request body.

    Returns:
        PostWorkoutSessionResponse: always carries ``bert_labels``. Debrief
        sections are ``None`` when generation was skipped or failed. When
        Claude answered but parsing failed, ``raw_debrief`` still carries the
        unparsed text.

    Raises:
        HTTPException: 500 when inference or label decoding fails.
    """
    model = app_state["post_model"]
    tokenizer = app_state["post_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # -- Steps 1-2: inference, then decode integer labels to strings --------
    # Decoding sits inside the same guard as inference so a bad prediction
    # index gives the same structured 500 as an inference failure.
    try:
        raw_preds = predict_post(req.user_text, model, tokenizer, device)
        bert_labels = decode_post_predictions(
            raw_preds,
            POST_MOOD_MAP,
            POST_EXERTION_MAP,
            POST_SORENESS_REGION_MAP,
            POST_SORENESS_SEVERITY_MAP,
            POST_COMPLETION_MAP,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Post-workout inference error: %s", e)
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # -- Step 3: optionally generate the Claude debrief ---------------------
    parsed = {}
    if req.generate_debrief:
        prompt = build_post_prompt(
            bert_labels=bert_labels,
            user_text=req.user_text,
            duration_minutes=req.duration_minutes,
            workout_type=req.workout_type,
            user_goal=req.user_goal,
        )
        raw_debrief = None
        try:
            message = client.messages.create(
                model="claude-haiku-4-5-20251001",
                max_tokens=400,
                messages=[{"role": "user", "content": prompt}],
            )
            raw_debrief = message.content[0].text
            parsed = parse_debrief(raw_debrief)
        except Exception as e:
            # Best effort: a generation failure must not fail the request,
            # because the labels are still useful on their own.
            logger.exception("Claude API error (post-workout): %s", e)
            # If Claude did answer and only parsing failed, keep the raw text
            # so the client can fall back to it.
            parsed = {"raw": raw_debrief} if raw_debrief is not None else {}

    return PostWorkoutSessionResponse(
        bert_labels=PostWorkoutBertLabels(**bert_labels),
        acknowledgement=parsed.get("acknowledgement"),
        highlights=parsed.get("highlights"),
        next_session=parsed.get("next_session"),
        raw_debrief=parsed.get("raw"),
    )
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on ``app``.
def post_classify_labels_only(req: PostWorkoutRequest):
    """Classify a post-workout reflection without calling Claude.

    Useful for storing labels to a database when no debrief is needed yet.

    Args:
        req: Validated post-workout request body; only ``user_text`` is read.

    Returns:
        PostWorkoutBertLabels: the decoded string labels.

    Raises:
        HTTPException: 500 when inference, decoding, or building the
            response model fails.
    """
    classifier = app_state["post_model"]
    tok = app_state["post_tokenizer"]
    target_device = app_state["device"]

    # Positional order expected by decode_post_predictions.
    label_maps = (
        POST_MOOD_MAP,
        POST_EXERTION_MAP,
        POST_SORENESS_REGION_MAP,
        POST_SORENESS_SEVERITY_MAP,
        POST_COMPLETION_MAP,
    )

    try:
        predictions = predict_post(req.user_text, classifier, tok, target_device)
        decoded = decode_post_predictions(predictions, *label_maps)
        return PostWorkoutBertLabels(**decoded)
    except Exception as e:
        logger.error(f"Post-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {e!s}")
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # PRE-WORKOUT SCHEMAS | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
class PreWorkoutRequest(BaseModel):
    """Request body accepted by the pre-workout handlers.

    Combines the user's free-text check-in (classified by the pre-workout
    DistilBERT model) with structured form selections that are forwarded to
    the Claude prompt builder.
    """

    # Free-text check-in fed to the pre-workout classifier.
    # ``examples=[...]`` replaces the ``example=`` extra keyword, which
    # Pydantic v2 deprecates on ``Field``.
    user_text: str = Field(
        ...,
        min_length=5,
        max_length=500,
        examples=["Feeling a bit tired but motivated. Back is slightly sore from Tuesday."],
    )

    # Form selections for the planned session.
    workout_type: str = Field(..., examples=["upper_body_push"])
    duration_minutes: int = Field(..., ge=10, le=300, examples=[45])
    user_goal: str = Field(..., examples=["muscle_gain"])

    # default_factory gives each request its own list instead of declaring a
    # mutable literal as the default.
    equipment: List[str] = Field(
        default_factory=list,
        examples=[["barbell", "dumbbells", "bench"]],
    )

    # Whether to call Claude for a workout plan after classification.
    generate_plan: bool = Field(default=True)
class PreWorkoutBertLabels(BaseModel):
    """Decoded (string) labels produced by the pre-workout classifier."""

    mood: str
    energy: str
    motivation: str
    stress: str
    soreness_region: str
    soreness_severity: str
class PreWorkoutResponse(BaseModel):
    """Response body for a classified pre-workout check-in.

    ``bert_labels`` is always present; the plan fields stay ``None`` when no
    plan was requested or when generation failed.
    """

    bert_labels: PreWorkoutBertLabels

    # Parsed sections of the Claude workout plan.
    warm_up: Optional[str] = None
    main_workout: Optional[str] = None
    cool_down: Optional[str] = None
    coaching_note: Optional[str] = None

    # Full unmodified response — fallback.
    raw_plan: Optional[str] = None
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # PRE-WORKOUT ROUTES | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on ``app``.
def pre_classify_session(req: PreWorkoutRequest):
    """Classify a pre-workout check-in and optionally generate a plan.

    Runs PreWorkoutDistilBERT inference on ``req.user_text``, decodes the
    integer predictions to string labels, and, when ``req.generate_plan`` is
    true, asks Claude for a structured workout plan built from those labels
    plus the form selections.

    Args:
        req: Validated pre-workout request body.

    Returns:
        PreWorkoutResponse: always carries ``bert_labels``. Plan sections are
        ``None`` when generation was skipped or failed. When Claude answered
        but parsing failed, ``raw_plan`` still carries the unparsed text.

    Raises:
        HTTPException: 500 when inference or label decoding fails.
    """
    model = app_state["pre_model"]
    tokenizer = app_state["pre_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # -- Steps 1-2: inference, then decode integer labels to strings --------
    # Decoding sits inside the same guard as inference so a bad prediction
    # index gives the same structured 500 as an inference failure.
    try:
        raw_preds = predict_pre(req.user_text, model, tokenizer, device)
        bert_labels = decode_pre_predictions(
            raw_preds,
            PRE_MOOD_MAP,
            PRE_ENERGY_MAP,
            PRE_MOTIVATION_MAP,
            PRE_STRESS_MAP,
            PRE_SORENESS_REGION_MAP,
            PRE_SORENESS_SEVERITY_MAP,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Pre-workout inference error: %s", e)
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # -- Step 3: optionally generate the Claude workout plan ----------------
    parsed = {}
    if req.generate_plan:
        prompt = build_pre_prompt(
            bert_labels=bert_labels,
            user_text=req.user_text,
            workout_type=req.workout_type,
            duration_minutes=req.duration_minutes,
            user_goal=req.user_goal,
            equipment=req.equipment,
        )
        raw_plan = None
        try:
            message = client.messages.create(
                model="claude-haiku-4-5-20251001",
                max_tokens=800,
                messages=[{"role": "user", "content": prompt}],
            )
            raw_plan = message.content[0].text
            parsed = parse_workout_plan(raw_plan)
        except Exception as e:
            # Best effort: a generation failure must not fail the request,
            # because the labels are still useful on their own.
            logger.exception("Claude API error (pre-workout): %s", e)
            # If Claude did answer and only parsing failed, keep the raw text
            # so the client can fall back to it.
            parsed = {"raw": raw_plan} if raw_plan is not None else {}

    return PreWorkoutResponse(
        bert_labels=PreWorkoutBertLabels(**bert_labels),
        warm_up=parsed.get("warm_up"),
        main_workout=parsed.get("main_workout"),
        cool_down=parsed.get("cool_down"),
        coaching_note=parsed.get("coaching_note"),
        raw_plan=parsed.get("raw"),
    )
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on ``app``.
def pre_classify_labels_only(req: PreWorkoutRequest):
    """Classify a pre-workout check-in without calling Claude.

    Useful for storing labels to a database when no plan is needed yet.

    Args:
        req: Validated pre-workout request body; only ``user_text`` is read.

    Returns:
        PreWorkoutBertLabels: the decoded string labels.

    Raises:
        HTTPException: 500 when inference, decoding, or building the
            response model fails.
    """
    classifier = app_state["pre_model"]
    tok = app_state["pre_tokenizer"]
    target_device = app_state["device"]

    # Positional order expected by decode_pre_predictions.
    label_maps = (
        PRE_MOOD_MAP,
        PRE_ENERGY_MAP,
        PRE_MOTIVATION_MAP,
        PRE_STRESS_MAP,
        PRE_SORENESS_REGION_MAP,
        PRE_SORENESS_SEVERITY_MAP,
    )

    try:
        predictions = predict_pre(req.user_text, classifier, tok, target_device)
        decoded = decode_pre_predictions(predictions, *label_maps)
        return PreWorkoutBertLabels(**decoded)
    except Exception as e:
        logger.error(f"Pre-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {e!s}")
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # STREAMING ROUTES | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
# These endpoints use Server-Sent Events (SSE).
# Each event is a line starting with "data: "
# followed by a JSON payload, terminated by "\n\n".
#
# Event types:
#   {"type": "labels", "bert_labels": {...}}  -> emitted immediately after
#                                                DistilBERT inference (~0.5s)
#   {"type": "token", "text": "..."}          -> one chunk of Claude output text
#   {"type": "done"}                          -> stream finished cleanly
#   {"type": "error", "detail": "..."}        -> Claude generation failed; stream ends
| def _sse(payload: dict) -> str: | |
| """Format a dict as a single SSE data line.""" | |
| return f"data: {json.dumps(payload)}\n\n" | |
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on ``app``.
async def post_classify_stream(req: PostWorkoutRequest):
    """Streaming version of /classify.

    Emits the BERT labels as the first SSE event, then streams the Claude
    debrief text chunk by chunk, and finishes with a ``done`` event (or an
    ``error`` event if generation fails mid-stream).

    Args:
        req: Validated post-workout request body.

    Returns:
        StreamingResponse: a ``text/event-stream`` response.

    Raises:
        HTTPException: 500 when inference or label decoding fails; this
            happens before the stream is opened.
    """
    model = app_state["post_model"]
    tokenizer = app_state["post_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # -- Step 1: DistilBERT inference (runs before the stream opens) --------
    # The forward pass is synchronous, so run it on the default executor
    # rather than stalling the event loop for every other connection.
    loop = asyncio.get_running_loop()
    try:
        raw_preds = await loop.run_in_executor(
            None, predict_post, req.user_text, model, tokenizer, device
        )
        bert_labels = decode_post_predictions(
            raw_preds,
            POST_MOOD_MAP,
            POST_EXERTION_MAP,
            POST_SORENESS_REGION_MAP,
            POST_SORENESS_SEVERITY_MAP,
            POST_COMPLETION_MAP,
        )
    except Exception as e:
        logger.exception("Post-workout inference error: %s", e)
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # -- Step 2: build the Claude prompt ------------------------------------
    prompt = build_post_prompt(
        bert_labels=bert_labels,
        user_text=req.user_text,
        duration_minutes=req.duration_minutes,
        workout_type=req.workout_type,
        user_goal=req.user_goal,
    )

    def event_stream():
        # Deliberately a *sync* generator: the Anthropic client here is the
        # blocking one, and StreamingResponse iterates sync iterables in a
        # worker thread. Iterating it inside an async generator would block
        # the event loop while waiting for each chunk.

        # Emit labels immediately — iOS renders chips before Claude starts.
        yield _sse({"type": "labels", "bert_labels": bert_labels})

        if not req.generate_debrief:
            yield _sse({"type": "done"})
            return

        # Stream Claude text chunks as they arrive.
        try:
            with client.messages.stream(
                model="claude-haiku-4-5-20251001",
                max_tokens=400,
                messages=[{"role": "user", "content": prompt}],
            ) as stream:
                for text_chunk in stream.text_stream:
                    yield _sse({"type": "token", "text": text_chunk})
        except Exception as e:
            logger.exception("Claude streaming error (post-workout): %s", e)
            yield _sse({"type": "error", "detail": str(e)})
            return

        yield _sse({"type": "done"})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disables nginx buffering on HF Spaces
        },
    )
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on ``app``.
async def pre_classify_stream(req: PreWorkoutRequest):
    """Streaming version of /pre-classify.

    Emits the BERT labels as the first SSE event, then streams the Claude
    workout plan text chunk by chunk, and finishes with a ``done`` event (or
    an ``error`` event if generation fails mid-stream).

    Args:
        req: Validated pre-workout request body.

    Returns:
        StreamingResponse: a ``text/event-stream`` response.

    Raises:
        HTTPException: 500 when inference or label decoding fails; this
            happens before the stream is opened.
    """
    model = app_state["pre_model"]
    tokenizer = app_state["pre_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # -- Step 1: DistilBERT inference (runs before the stream opens) --------
    # The forward pass is synchronous, so run it on the default executor
    # rather than stalling the event loop for every other connection.
    loop = asyncio.get_running_loop()
    try:
        raw_preds = await loop.run_in_executor(
            None, predict_pre, req.user_text, model, tokenizer, device
        )
        bert_labels = decode_pre_predictions(
            raw_preds,
            PRE_MOOD_MAP,
            PRE_ENERGY_MAP,
            PRE_MOTIVATION_MAP,
            PRE_STRESS_MAP,
            PRE_SORENESS_REGION_MAP,
            PRE_SORENESS_SEVERITY_MAP,
        )
    except Exception as e:
        logger.exception("Pre-workout inference error: %s", e)
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # -- Step 2: build the Claude prompt ------------------------------------
    prompt = build_pre_prompt(
        bert_labels=bert_labels,
        user_text=req.user_text,
        workout_type=req.workout_type,
        duration_minutes=req.duration_minutes,
        user_goal=req.user_goal,
        equipment=req.equipment,
    )

    def event_stream():
        # Deliberately a *sync* generator: the Anthropic client here is the
        # blocking one, and StreamingResponse iterates sync iterables in a
        # worker thread. Iterating it inside an async generator would block
        # the event loop while waiting for each chunk.

        # Emit labels immediately.
        yield _sse({"type": "labels", "bert_labels": bert_labels})

        if not req.generate_plan:
            yield _sse({"type": "done"})
            return

        # Stream Claude text chunks as they arrive.
        try:
            with client.messages.stream(
                model="claude-haiku-4-5-20251001",
                max_tokens=800,
                messages=[{"role": "user", "content": prompt}],
            ) as stream:
                for text_chunk in stream.text_stream:
                    yield _sse({"type": "token", "text": text_chunk})
        except Exception as e:
            logger.exception("Claude streaming error (pre-workout): %s", e)
            yield _sse({"type": "error", "detail": str(e)})
            return

        yield _sse({"type": "done"})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )