"""
Workout Coach — FastAPI Inference App

Runs DistilBERT classification + Claude debrief generation.
Designed for Hugging Face Spaces with Docker.
"""

from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from typing import Callable, Dict, Iterator, List, Optional
import torch
import anthropic
import asyncio
import json
import os
import logging
import secrets

from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
from inference import (
    predict_post,
    decode_post_predictions,
    build_post_prompt,
    parse_debrief,
    predict_pre,
    decode_pre_predictions,
    build_pre_prompt,
    parse_workout_plan,
)

# ─────────────────────────────────────────────
# LOGGING
# ─────────────────────────────────────────────

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────
# CLAUDE SETTINGS
# ─────────────────────────────────────────────

# Single source of truth for every Claude call; overridable without a code change.
CLAUDE_MODEL = os.getenv("CLAUDE_MODEL", "claude-haiku-4-5-20251001")
POST_DEBRIEF_MAX_TOKENS = 400
PRE_PLAN_MAX_TOKENS = 800

# ─────────────────────────────────────────────
# API KEY AUTHENTICATION
# ─────────────────────────────────────────────

API_KEY = os.getenv("APP_API_KEY")


async def verify_key(x_api_key: str = Header(...)):
    """Reject the request with 401 unless the ``x-api-key`` header matches APP_API_KEY.

    Fails closed: when APP_API_KEY is unset or empty, every request is rejected.
    The comparison is constant-time so the key cannot be probed via response timing.

    Raises:
        HTTPException: 401 when the key is missing, unconfigured, or wrong.
    """
    if not API_KEY or not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")


# ─────────────────────────────────────────────
# POST-WORKOUT LABEL MAPS
# ─────────────────────────────────────────────

POST_MOOD_MAP = {
    0: "accomplished",
    1: "anxious",
    2: "distracted",
    3: "energized",
    4: "fatigued",
    5: "frustrated",
    6: "neutral",
    7: "positive",
}
POST_EXERTION_MAP = {0: "low", 1: "moderate", 2: "high"}
POST_COMPLETION_MAP = {0: "partial", 1: "full"}
POST_SORENESS_REGION_MAP = {
    0: "none",
    1: "back",
    2: "biceps",
    3: "chest",
    4: "legs",
    5: "shoulder",
    6: "triceps",
}
POST_SORENESS_SEVERITY_MAP = {
    0: "none",
    1: "mild",
    2: "moderate",
    3: "severe",
}

# ─────────────────────────────────────────────
# PRE-WORKOUT LABEL MAPS
# ─────────────────────────────────────────────

PRE_MOOD_MAP = {
    0: "accomplished",
    1: "anxious",
    2: "distracted",
    3: "energized",
    4: "fatigued",
    5: "frustrated",
    6: "neutral",
    7: "positive",
}
PRE_ENERGY_MAP = {0: "low", 1: "moderate", 2: "high"}
PRE_MOTIVATION_MAP = {0: "low", 1: "moderate", 2: "high"}
PRE_STRESS_MAP = {0: "low", 1: "moderate", 2: "high"}
PRE_SORENESS_REGION_MAP = {
    0: "none",
    1: "back",
    2: "biceps",
    3: "chest",
    4: "legs",
    5: "shoulder",
    6: "triceps",
}
PRE_SORENESS_SEVERITY_MAP = {
    0: "none",
    1: "mild",
    2: "moderate",
    3: "severe",
}

# ─────────────────────────────────────────────
# APP STATE — model loaded once at startup
# ─────────────────────────────────────────────

app_state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load both models and the Anthropic client once at startup; release them at shutdown."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    if not API_KEY:
        logger.warning("APP_API_KEY is not set; every authenticated endpoint will return 401.")

    # ── Post-workout model ────────────────────────────────────
    logger.info("Loading post-workout model...")
    # POST_MODEL_PATH mirrors PRE_MODEL_PATH; MODEL_PATH is still honoured so
    # existing deployments keep working.
    post_model, post_tokenizer = load_post_model(
        model_path=os.getenv("POST_MODEL_PATH")
        or os.getenv("MODEL_PATH", "post_best_overall_model.pt"),
        device=device,
    )
    app_state["post_model"] = post_model
    app_state["post_tokenizer"] = post_tokenizer

    # ── Pre-workout model ─────────────────────────────────────
    logger.info("Loading pre-workout model...")
    pre_model, pre_tokenizer = load_pre_model(
        model_path=os.getenv("PRE_MODEL_PATH", "pre_best_overall_model.pt"),
        device=device,
    )
    app_state["pre_model"] = pre_model
    app_state["pre_tokenizer"] = pre_tokenizer

    app_state["device"] = device

    # ── Anthropic client ──────────────────────────────────────
    app_state["anthropic_client"] = anthropic.Anthropic(
        api_key=os.getenv("ANTHROPIC_API_KEY")
    )

    logger.info("All models and clients loaded successfully.")
    yield

    # Release the client's HTTP connection pool before dropping references.
    client = app_state.get("anthropic_client")
    if client is not None:
        client.close()
    app_state.clear()
    logger.info("App shutdown complete.")


app = FastAPI(
    title="Workout Coach Inference API",
    description="Post-workout DistilBERT + Pre-workout DistilBERT + Claude generation",
    version="2.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ─────────────────────────────────────────────
# REQUEST / RESPONSE SCHEMAS
# ─────────────────────────────────────────────

# ─────────────────────────────────────────────
# POST-WORKOUT SCHEMAS
# ─────────────────────────────────────────────


class PostWorkoutRequest(BaseModel):
    """Body for the post-workout endpoints."""

    # Free-text input from the user — fed into PostWorkoutDistilBERT
    user_text: str = Field(
        ...,
        min_length=5,
        max_length=500,
        examples=["That was really tough, chest is killing me but I feel accomplished."],
    )

    # UI form fields — collected separately in the app
    duration_minutes: int = Field(..., ge=1, le=300, examples=[45])
    workout_type: str = Field(..., examples=["upper_body_push"])
    user_goal: str = Field(..., examples=["muscle_gain"])

    # Optional — whether to generate the Claude debrief
    generate_debrief: bool = Field(default=True)


class PostWorkoutBertLabels(BaseModel):
    """Decoded PostWorkoutDistilBERT labels (one string per classification head)."""

    mood: str
    exertion: str
    soreness_region: str
    soreness_severity: str
    completion: str


class PostWorkoutSessionResponse(BaseModel):
    """Labels plus the parsed Claude debrief sections (None when not generated or on failure)."""

    bert_labels: PostWorkoutBertLabels
    acknowledgement: Optional[str] = None
    highlights: Optional[str] = None
    next_session: Optional[str] = None
    raw_debrief: Optional[str] = None  # full unmodified response — fallback


class HealthResponse(BaseModel):
    """Payload of GET /health."""

    model_config = {"protected_namespaces": ()}

    status: str
    device: str
    post_model_loaded: bool
    pre_model_loaded: bool


# ─────────────────────────────────────────────
# PRE-WORKOUT SCHEMAS
# ─────────────────────────────────────────────


class PreWorkoutRequest(BaseModel):
    """Body for the pre-workout endpoints."""

    user_text: str = Field(
        ...,
        min_length=5,
        max_length=500,
        examples=["Feeling a bit tired but motivated. Back is slightly sore from Tuesday."],
    )
    workout_type: str = Field(..., examples=["upper_body_push"])
    duration_minutes: int = Field(..., ge=10, le=300, examples=[45])
    user_goal: str = Field(..., examples=["muscle_gain"])
    equipment: List[str] = Field(
        default_factory=list, examples=[["barbell", "dumbbells", "bench"]]
    )
    generate_plan: bool = Field(default=True)


class PreWorkoutBertLabels(BaseModel):
    """Decoded PreWorkoutDistilBERT labels (one string per classification head)."""

    mood: str
    energy: str
    motivation: str
    stress: str
    soreness_region: str
    soreness_severity: str


class PreWorkoutResponse(BaseModel):
    """Labels plus the parsed Claude workout-plan sections (None when not generated or on failure)."""

    bert_labels: PreWorkoutBertLabels
    warm_up: Optional[str] = None
    main_workout: Optional[str] = None
    cool_down: Optional[str] = None
    coaching_note: Optional[str] = None
    raw_plan: Optional[str] = None  # full unmodified response — fallback


# ─────────────────────────────────────────────
# SHARED HELPERS
# ─────────────────────────────────────────────


def _classify_post(user_text: str) -> dict:
    """Run PostWorkoutDistilBERT on ``user_text`` and decode integer outputs to label strings.

    Raises:
        HTTPException: 500 if inference or label decoding fails.
    """
    model = app_state["post_model"]
    tokenizer = app_state["post_tokenizer"]
    device = app_state["device"]
    try:
        raw_preds = predict_post(user_text, model, tokenizer, device)
        return decode_post_predictions(
            raw_preds,
            POST_MOOD_MAP,
            POST_EXERTION_MAP,
            POST_SORENESS_REGION_MAP,
            POST_SORENESS_SEVERITY_MAP,
            POST_COMPLETION_MAP,
        )
    except Exception as e:
        logger.exception("Post-workout inference error")
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}") from e


def _classify_pre(user_text: str) -> dict:
    """Run PreWorkoutDistilBERT on ``user_text`` and decode integer outputs to label strings.

    Raises:
        HTTPException: 500 if inference or label decoding fails.
    """
    model = app_state["pre_model"]
    tokenizer = app_state["pre_tokenizer"]
    device = app_state["device"]
    try:
        raw_preds = predict_pre(user_text, model, tokenizer, device)
        return decode_pre_predictions(
            raw_preds,
            PRE_MOOD_MAP,
            PRE_ENERGY_MAP,
            PRE_MOTIVATION_MAP,
            PRE_STRESS_MAP,
            PRE_SORENESS_REGION_MAP,
            PRE_SORENESS_SEVERITY_MAP,
        )
    except Exception as e:
        logger.exception("Pre-workout inference error")
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}") from e


def _post_prompt(req: PostWorkoutRequest, bert_labels: dict) -> str:
    """Build the Claude debrief prompt from the classified labels and the form fields."""
    return build_post_prompt(
        bert_labels=bert_labels,
        user_text=req.user_text,
        duration_minutes=req.duration_minutes,
        workout_type=req.workout_type,
        user_goal=req.user_goal,
    )


def _pre_prompt(req: PreWorkoutRequest, bert_labels: dict) -> str:
    """Build the Claude workout-plan prompt from the classified labels and the form fields."""
    return build_pre_prompt(
        bert_labels=bert_labels,
        user_text=req.user_text,
        workout_type=req.workout_type,
        duration_minutes=req.duration_minutes,
        user_goal=req.user_goal,
        equipment=req.equipment,
    )


def _claude_generate(
    prompt: str, max_tokens: int, parser: Callable[[str], dict], context: str
) -> dict:
    """Send ``prompt`` to Claude once and return ``parser`` applied to the reply text.

    Best-effort: any API or parsing failure is logged and ``{}`` is returned,
    so callers can still respond with the BERT labels alone.
    """
    client = app_state["anthropic_client"]
    try:
        message = client.messages.create(
            model=CLAUDE_MODEL,
            max_tokens=max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        return parser(message.content[0].text)
    except Exception:
        logger.exception("Claude API error (%s)", context)
        return {}


# ─────────────────────────────────────────────
# ROUTES
# ─────────────────────────────────────────────


@app.get("/health", response_model=HealthResponse)
def health():
    """Health check — confirms both models are loaded and ready."""
    return {
        "status": "ok",
        "device": str(app_state.get("device", "unknown")),
        "post_model_loaded": "post_model" in app_state,
        "pre_model_loaded": "pre_model" in app_state,
    }


@app.post("/post-classify", response_model=PostWorkoutSessionResponse, dependencies=[Depends(verify_key)])
def post_classify_session(req: PostWorkoutRequest):
    """
    Runs PostWorkoutDistilBERT inference on user_text and optionally generates
    a Claude debrief using the classified labels combined with the session form data.
    """
    # ── Step 1+2: inference and label decoding ────────────────
    bert_labels = _classify_post(req.user_text)

    # ── Step 3: Optionally generate Claude debrief ────────────
    parsed: dict = {}
    if req.generate_debrief:
        parsed = _claude_generate(
            _post_prompt(req, bert_labels),
            POST_DEBRIEF_MAX_TOKENS,
            parse_debrief,
            "post-workout",
        )

    return PostWorkoutSessionResponse(
        bert_labels=PostWorkoutBertLabels(**bert_labels),
        acknowledgement=parsed.get("acknowledgement"),
        highlights=parsed.get("highlights"),
        next_session=parsed.get("next_session"),
        raw_debrief=parsed.get("raw"),
    )


@app.post("/post-classify/labels-only", response_model=PostWorkoutBertLabels, dependencies=[Depends(verify_key)])
def post_classify_labels_only(req: PostWorkoutRequest):
    """
    Runs only PostWorkoutDistilBERT inference. Skips Claude.
    Useful for storing labels to DB without generating a debrief yet.
    """
    return PostWorkoutBertLabels(**_classify_post(req.user_text))


# ─────────────────────────────────────────────
# PRE-WORKOUT ROUTES
# ─────────────────────────────────────────────


@app.post("/pre-classify", response_model=PreWorkoutResponse, dependencies=[Depends(verify_key)])
def pre_classify_session(req: PreWorkoutRequest):
    """
    Runs PreWorkoutDistilBERT inference on user_text and optionally calls Claude
    to generate a structured workout plan using the classified labels combined
    with the form selections.
    """
    # ── Step 1+2: inference and label decoding ────────────────
    bert_labels = _classify_pre(req.user_text)

    # ── Step 3: Optionally generate Claude workout plan ───────
    parsed: dict = {}
    if req.generate_plan:
        parsed = _claude_generate(
            _pre_prompt(req, bert_labels),
            PRE_PLAN_MAX_TOKENS,
            parse_workout_plan,
            "pre-workout",
        )

    return PreWorkoutResponse(
        bert_labels=PreWorkoutBertLabels(**bert_labels),
        warm_up=parsed.get("warm_up"),
        main_workout=parsed.get("main_workout"),
        cool_down=parsed.get("cool_down"),
        coaching_note=parsed.get("coaching_note"),
        raw_plan=parsed.get("raw"),
    )


@app.post("/pre-classify/labels-only", response_model=PreWorkoutBertLabels, dependencies=[Depends(verify_key)])
def pre_classify_labels_only(req: PreWorkoutRequest):
    """
    Runs only pre-workout DistilBERT inference. Skips Claude.
    Useful for storing labels to DB without generating a plan yet.
    """
    return PreWorkoutBertLabels(**_classify_pre(req.user_text))


# ─────────────────────────────────────────────
# STREAMING ROUTES
# ─────────────────────────────────────────────
# These endpoints use Server-Sent Events (SSE).
# Each event is a line starting with "data: "
# followed by a JSON payload, terminated by "\n\n".
#
# Event types:
#   {"type": "labels", "bert_labels": {...}}  — emitted immediately after
#                                               DistilBERT inference (~0.5s)
#   {"type": "token",  "text": "..."}         — one Claude output token
#   {"type": "done"}                          — stream finished cleanly
#   {"type": "error",  "detail": "..."}       — non-fatal error
#
# The routes and the event generator are deliberately synchronous: both the
# DistilBERT forward pass and the sync Anthropic stream block, so running them
# inside `async def` would stall the event loop. FastAPI runs sync routes in
# its threadpool and Starlette iterates sync generators in the threadpool too.


def _sse(payload: dict) -> str:
    """Format a dict as a single SSE data line."""
    return f"data: {json.dumps(payload)}\n\n"


def _claude_event_stream(
    client, bert_labels: dict, prompt: Optional[str], max_tokens: int, context: str
) -> Iterator[str]:
    """Yield SSE events: labels first, then Claude tokens (unless ``prompt`` is None), then done.

    A Claude failure is reported to the client as an ``error`` event and ends the stream.
    """
    # Emit labels immediately — the client can render them before Claude starts.
    yield _sse({"type": "labels", "bert_labels": bert_labels})

    if prompt is None:
        yield _sse({"type": "done"})
        return

    try:
        with client.messages.stream(
            model=CLAUDE_MODEL,
            max_tokens=max_tokens,
            messages=[{"role": "user", "content": prompt}],
        ) as stream:
            for text_chunk in stream.text_stream:
                yield _sse({"type": "token", "text": text_chunk})
    except Exception as e:
        logger.exception("Claude streaming error (%s)", context)
        yield _sse({"type": "error", "detail": str(e)})
        return

    yield _sse({"type": "done"})


def _sse_response(events: Iterator[str]) -> StreamingResponse:
    """Wrap an SSE event iterator in a non-cached, non-buffered streaming response."""
    return StreamingResponse(
        events,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disables nginx buffering on HF Spaces
        },
    )


@app.post("/post-classify/stream", dependencies=[Depends(verify_key)])
def post_classify_stream(req: PostWorkoutRequest):
    """
    Streaming version of /post-classify.
    Emits BERT labels immediately, then streams Claude debrief tokens.
    """
    # Inference runs before the stream opens so failures still return a 500.
    bert_labels = _classify_post(req.user_text)
    prompt = _post_prompt(req, bert_labels) if req.generate_debrief else None
    return _sse_response(
        _claude_event_stream(
            app_state["anthropic_client"],
            bert_labels,
            prompt,
            POST_DEBRIEF_MAX_TOKENS,
            "post-workout",
        )
    )


@app.post("/pre-classify/stream", dependencies=[Depends(verify_key)])
def pre_classify_stream(req: PreWorkoutRequest):
    """
    Streaming version of /pre-classify.
    Emits BERT labels immediately, then streams Claude workout plan tokens.
    """
    # Inference runs before the stream opens so failures still return a 500.
    bert_labels = _classify_pre(req.user_text)
    prompt = _pre_prompt(req, bert_labels) if req.generate_plan else None
    return _sse_response(
        _claude_event_stream(
            app_state["anthropic_client"],
            bert_labels,
            prompt,
            PRE_PLAN_MAX_TOKENS,
            "pre-workout",
        )
    )