jflo's picture
Update app.py
d35a46f verified
"""
Workout Coach — FastAPI Inference App
Runs DistilBERT classification + Claude debrief generation
Designed for Hugging Face Spaces with Docker
"""
from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from typing import Optional, Dict, List
import torch
import anthropic
import asyncio
import json
import os
import logging
from model import PostWorkoutDistilBERT, load_post_model, PreWorkoutDistilBERT, load_pre_model
from inference import (
predict_post, decode_post_predictions, build_post_prompt, parse_debrief,
predict_pre, decode_pre_predictions, build_pre_prompt, parse_workout_plan,
)
# ─────────────────────────────────────────────
# LOGGING
# ─────────────────────────────────────────────
# Configure the root logger at INFO level so the info/error messages logged
# by this module are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the lifespan hook and the route handlers.
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────
# API KEY AUTHENTICATION
# ─────────────────────────────────────────────
# Shared secret that clients must send in the X-API-Key header.
# Read once at import time from the APP_API_KEY environment variable.
API_KEY = os.getenv("APP_API_KEY")


async def verify_key(x_api_key: str = Header(...)):
    """FastAPI dependency that rejects requests without the configured API key.

    Args:
        x_api_key: Value of the required ``X-API-Key`` request header.

    Raises:
        HTTPException: 401 when the header does not match ``API_KEY``, or
            when ``API_KEY`` is unset/empty (the service fails closed rather
            than accepting an empty header against an empty key).
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    # Compare as bytes: compare_digest() only accepts ASCII-only str values,
    # and a header may legally carry non-ASCII characters.
    # compare_digest() runs in time independent of where the inputs differ,
    # so response timing does not leak how much of the key was guessed.
    if not API_KEY or not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
# ─────────────────────────────────────────────
# POST-WORKOUT LABEL MAPS
# ─────────────────────────────────────────────
# Integer class index -> label string, one table per post-workout output.
# The post-workout routes pass these tables to decode_post_predictions().
POST_MOOD_MAP = dict(enumerate((
    "accomplished", "anxious", "distracted",
    "energized", "fatigued", "frustrated",
    "neutral", "positive",
)))
POST_EXERTION_MAP = dict(enumerate(("low", "moderate", "high")))
POST_COMPLETION_MAP = dict(enumerate(("partial", "full")))
POST_SORENESS_REGION_MAP = dict(enumerate((
    "none", "back", "biceps",
    "chest", "legs", "shoulder", "triceps",
)))
POST_SORENESS_SEVERITY_MAP = dict(enumerate((
    "none", "mild", "moderate", "severe",
)))
# ─────────────────────────────────────────────
# PRE-WORKOUT LABEL MAPS
# ─────────────────────────────────────────────
# Integer class index -> label string, one table per pre-workout output.
# The pre-workout routes pass these tables to decode_pre_predictions().
PRE_MOOD_MAP = dict(enumerate((
    "accomplished", "anxious", "distracted",
    "energized", "fatigued", "frustrated",
    "neutral", "positive",
)))
PRE_ENERGY_MAP = dict(enumerate(("low", "moderate", "high")))
PRE_MOTIVATION_MAP = dict(enumerate(("low", "moderate", "high")))
PRE_STRESS_MAP = dict(enumerate(("low", "moderate", "high")))
PRE_SORENESS_REGION_MAP = dict(enumerate((
    "none", "back", "biceps",
    "chest", "legs", "shoulder", "triceps",
)))
PRE_SORENESS_SEVERITY_MAP = dict(enumerate((
    "none", "mild", "moderate", "severe",
)))
# ─────────────────────────────────────────────
# APP STATE — models loaded once at startup
# ─────────────────────────────────────────────
# Process-wide registry holding the models, tokenizers, device and API client.
app_state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fill ``app_state`` before requests are served; empty it on shutdown.

    Startup picks CUDA when available (CPU otherwise), loads the post- and
    pre-workout models with their tokenizers, and creates the Anthropic
    client. Everything is stored in the module-level ``app_state`` dict.
    """
    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")
    logger.info(f"Using device: {device}")

    # ── Post-workout model ────────────────────────────────────
    logger.info("Loading post-workout model...")
    post_weights = os.getenv("MODEL_PATH", "post_best_overall_model.pt")
    app_state["post_model"], app_state["post_tokenizer"] = load_post_model(
        model_path=post_weights,
        device=device,
    )

    # ── Pre-workout model ─────────────────────────────────────
    logger.info("Loading pre-workout model...")
    pre_weights = os.getenv("PRE_MODEL_PATH", "pre_best_overall_model.pt")
    app_state["pre_model"], app_state["pre_tokenizer"] = load_pre_model(
        model_path=pre_weights,
        device=device,
    )
    app_state["device"] = device

    # ── Anthropic client ──────────────────────────────────────
    anthropic_key = os.getenv("ANTHROPIC_API_KEY")
    app_state["anthropic_client"] = anthropic.Anthropic(api_key=anthropic_key)
    logger.info("All models and clients loaded successfully.")

    yield  # the application serves requests while suspended here

    app_state.clear()
    logger.info("App shutdown complete.")
# The FastAPI application. `lifespan` populates app_state at startup and
# clears it at shutdown.
app = FastAPI(
    title="Workout Coach Inference API",
    description="Post-workout DistilBERT + Pre-workout DistilBERT + Claude generation",
    version="2.0.0",
    lifespan=lifespan,
)
# CORS is fully open: any origin, method and header is allowed ("*").
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ─────────────────────────────────────────────
# REQUEST / RESPONSE SCHEMAS
# ─────────────────────────────────────────────
# ─────────────────────────────────────────────
# POST-WORKOUT SCHEMAS
# ─────────────────────────────────────────────
class PostWorkoutRequest(BaseModel):
    # Request body accepted by the post-workout routes.
    # Schema examples use `examples=[...]`: passing `example=` to Field() is
    # an undeclared extra keyword, which Pydantic v2 deprecates.

    # Free-text input from the user — fed into PostWorkoutDistilBERT
    user_text: str = Field(
        ..., min_length=5, max_length=500,
        examples=["That was really tough, chest is killing me but I feel accomplished."],
    )
    # UI form fields — collected separately in the app
    duration_minutes: int = Field(..., ge=1, le=300, examples=[45])
    workout_type: str = Field(..., examples=["upper_body_push"])
    user_goal: str = Field(..., examples=["muscle_gain"])
    # Optional — whether to generate the Claude debrief
    generate_debrief: bool = Field(default=True)
class PostWorkoutBertLabels(BaseModel):
    # Decoded (string) classifier labels for one post-workout note.
    # The post-workout routes build this from the dict returned by
    # decode_post_predictions(), so that dict must carry exactly these keys.
    mood: str
    exertion: str
    soreness_region: str
    soreness_severity: str
    completion: str
class PostWorkoutSessionResponse(BaseModel):
    # Response body of /post-classify: the classifier labels plus the parsed
    # Claude debrief. The debrief fields stay None when no debrief was
    # requested or when generating it failed.
    bert_labels: PostWorkoutBertLabels
    acknowledgement: Optional[str] = None
    highlights: Optional[str] = None
    next_session: Optional[str] = None
    raw_debrief: Optional[str] = None  # full unmodified response — fallback
class HealthResponse(BaseModel):
    # Response body of GET /health.
    # NOTE(review): the empty protected_namespaces presumably silences
    # Pydantic's warning about field names in its reserved "model_"
    # namespace — confirm it is still needed, since no field below starts
    # with "model_".
    model_config = {"protected_namespaces": ()}
    status: str
    device: str
    # True when the corresponding model key is present in app_state.
    post_model_loaded: bool
    pre_model_loaded: bool
# ─────────────────────────────────────────────
# ROUTES
# ─────────────────────────────────────────────
@app.get("/health", response_model=HealthResponse)
def health():
    """Report service status, the active device, and which models are loaded."""
    device_name = str(app_state.get("device", "unknown"))
    post_ready = "post_model" in app_state
    pre_ready = "pre_model" in app_state
    return dict(
        status="ok",
        device=device_name,
        post_model_loaded=post_ready,
        pre_model_loaded=pre_ready,
    )
@app.post("/post-classify", response_model=PostWorkoutSessionResponse,
          dependencies=[Depends(verify_key)])
def post_classify_session(req: PostWorkoutRequest):
    """
    Classify a post-workout note and optionally generate a Claude debrief.

    Runs PostWorkoutDistilBERT inference on ``req.user_text``, decodes the
    predictions to string labels and, when ``req.generate_debrief`` is set,
    asks Claude for a debrief built from those labels plus the session form
    data.

    Returns:
        PostWorkoutSessionResponse with the labels. Debrief fields are None
        when no debrief was requested or the Claude call failed. If Claude
        answered but the answer could not be parsed, only ``raw_debrief`` is
        filled so the client can still show the text.

    Raises:
        HTTPException: 500 when inference or label decoding fails.
    """
    model = app_state["post_model"]
    tokenizer = app_state["post_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # ── Steps 1-2: inference, then decode integer labels → strings ──
    # Decoding shares the try block (as in the labels-only and streaming
    # routes) so a decode failure yields the same 500 payload as an
    # inference failure.
    try:
        raw_preds = predict_post(req.user_text, model, tokenizer, device)
        bert_labels = decode_post_predictions(
            raw_preds,
            POST_MOOD_MAP,
            POST_EXERTION_MAP,
            POST_SORENESS_REGION_MAP,
            POST_SORENESS_SEVERITY_MAP,
            POST_COMPLETION_MAP,
        )
    except Exception as e:
        logger.error(f"Post-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # ── Step 3: Optionally generate Claude debrief (best effort) ──
    parsed = {}
    if req.generate_debrief:
        prompt = build_post_prompt(
            bert_labels=bert_labels,
            user_text=req.user_text,
            duration_minutes=req.duration_minutes,
            workout_type=req.workout_type,
            user_goal=req.user_goal,
        )
        raw_debrief = None
        try:
            message = client.messages.create(
                model="claude-haiku-4-5-20251001",
                max_tokens=400,
                messages=[{"role": "user", "content": prompt}],
            )
            raw_debrief = message.content[0].text
        except Exception as e:
            # Labels are still useful on their own, so degrade gracefully.
            logger.error(f"Claude API error (post-workout): {e}")
        if raw_debrief is not None:
            # Parse separately from the API call: a parser failure must not
            # throw away text Claude already produced.
            try:
                parsed = parse_debrief(raw_debrief)
            except Exception as e:
                logger.error(f"Debrief parse error (post-workout): {e}")
                parsed = {"raw": raw_debrief}

    return PostWorkoutSessionResponse(
        bert_labels=PostWorkoutBertLabels(**bert_labels),
        acknowledgement=parsed.get("acknowledgement"),
        highlights=parsed.get("highlights"),
        next_session=parsed.get("next_session"),
        raw_debrief=parsed.get("raw"),
    )
@app.post("/post-classify/labels-only", response_model=PostWorkoutBertLabels,
          dependencies=[Depends(verify_key)])
def post_classify_labels_only(req: PostWorkoutRequest):
    """
    Classify a post-workout note with PostWorkoutDistilBERT only; Claude is
    never called. Handy for persisting labels before any debrief exists.
    """
    # Looked up outside the try block so a missing entry is not reported as
    # an inference failure.
    model, tokenizer, device = (
        app_state[key] for key in ("post_model", "post_tokenizer", "device")
    )
    label_maps = (
        POST_MOOD_MAP,
        POST_EXERTION_MAP,
        POST_SORENESS_REGION_MAP,
        POST_SORENESS_SEVERITY_MAP,
        POST_COMPLETION_MAP,
    )
    try:
        predictions = predict_post(req.user_text, model, tokenizer, device)
        decoded = decode_post_predictions(predictions, *label_maps)
        return PostWorkoutBertLabels(**decoded)
    except Exception as e:
        logger.error(f"Post-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
# ─────────────────────────────────────────────
# PRE-WORKOUT SCHEMAS
# ─────────────────────────────────────────────
class PreWorkoutRequest(BaseModel):
    # Request body accepted by the pre-workout routes.
    # Schema examples use `examples=[...]`: passing `example=` to Field() is
    # an undeclared extra keyword, which Pydantic v2 deprecates.

    # Free-text check-in from the user — fed into the pre-workout classifier
    user_text: str = Field(
        ..., min_length=5, max_length=500,
        examples=["Feeling a bit tired but motivated. Back is slightly sore from Tuesday."],
    )
    # Form selections forwarded to build_pre_prompt()
    workout_type: str = Field(..., examples=["upper_body_push"])
    duration_minutes: int = Field(..., ge=10, le=300, examples=[45])
    user_goal: str = Field(..., examples=["muscle_gain"])
    equipment: List[str] = Field(default=[], examples=[["barbell", "dumbbells", "bench"]])
    # Whether to ask Claude for a workout plan after classification
    generate_plan: bool = Field(default=True)
class PreWorkoutBertLabels(BaseModel):
    # Decoded (string) classifier labels for one pre-workout check-in.
    # The pre-workout routes build this from the dict returned by
    # decode_pre_predictions(), so that dict must carry exactly these keys.
    mood: str
    energy: str
    motivation: str
    stress: str
    soreness_region: str
    soreness_severity: str
class PreWorkoutResponse(BaseModel):
    # Response body of /pre-classify: the classifier labels plus the parsed
    # Claude workout plan. The plan fields stay None when no plan was
    # requested or when generating it failed.
    bert_labels: PreWorkoutBertLabels
    warm_up: Optional[str] = None
    main_workout: Optional[str] = None
    cool_down: Optional[str] = None
    coaching_note: Optional[str] = None
    raw_plan: Optional[str] = None  # full unmodified response — fallback
# ─────────────────────────────────────────────
# PRE-WORKOUT ROUTES
# ─────────────────────────────────────────────
@app.post("/pre-classify", response_model=PreWorkoutResponse,
          dependencies=[Depends(verify_key)])
def pre_classify_session(req: PreWorkoutRequest):
    """
    Classify a pre-workout check-in and optionally generate a workout plan.

    Runs PreWorkoutDistilBERT inference on ``req.user_text``, decodes the
    predictions to string labels and, when ``req.generate_plan`` is set,
    asks Claude for a structured plan built from those labels plus the form
    selections.

    Returns:
        PreWorkoutResponse with the labels. Plan fields are None when no
        plan was requested or the Claude call failed. If Claude answered but
        the answer could not be parsed, only ``raw_plan`` is filled so the
        client can still show the text.

    Raises:
        HTTPException: 500 when inference or label decoding fails.
    """
    model = app_state["pre_model"]
    tokenizer = app_state["pre_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # ── Steps 1-2: inference, then decode integer labels → strings ──
    # Decoding shares the try block (as in the labels-only and streaming
    # routes) so a decode failure yields the same 500 payload as an
    # inference failure.
    try:
        raw_preds = predict_pre(req.user_text, model, tokenizer, device)
        bert_labels = decode_pre_predictions(
            raw_preds,
            PRE_MOOD_MAP,
            PRE_ENERGY_MAP,
            PRE_MOTIVATION_MAP,
            PRE_STRESS_MAP,
            PRE_SORENESS_REGION_MAP,
            PRE_SORENESS_SEVERITY_MAP,
        )
    except Exception as e:
        logger.error(f"Pre-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # ── Step 3: Optionally generate Claude workout plan (best effort) ──
    parsed = {}
    if req.generate_plan:
        prompt = build_pre_prompt(
            bert_labels=bert_labels,
            user_text=req.user_text,
            workout_type=req.workout_type,
            duration_minutes=req.duration_minutes,
            user_goal=req.user_goal,
            equipment=req.equipment,
        )
        raw_plan = None
        try:
            message = client.messages.create(
                model="claude-haiku-4-5-20251001",
                max_tokens=800,
                messages=[{"role": "user", "content": prompt}],
            )
            raw_plan = message.content[0].text
        except Exception as e:
            # Labels are still useful on their own, so degrade gracefully.
            logger.error(f"Claude API error (pre-workout): {e}")
        if raw_plan is not None:
            # Parse separately from the API call: a parser failure must not
            # throw away text Claude already produced.
            try:
                parsed = parse_workout_plan(raw_plan)
            except Exception as e:
                logger.error(f"Workout plan parse error (pre-workout): {e}")
                parsed = {"raw": raw_plan}

    return PreWorkoutResponse(
        bert_labels=PreWorkoutBertLabels(**bert_labels),
        warm_up=parsed.get("warm_up"),
        main_workout=parsed.get("main_workout"),
        cool_down=parsed.get("cool_down"),
        coaching_note=parsed.get("coaching_note"),
        raw_plan=parsed.get("raw"),
    )
@app.post("/pre-classify/labels-only", response_model=PreWorkoutBertLabels,
          dependencies=[Depends(verify_key)])
def pre_classify_labels_only(req: PreWorkoutRequest):
    """
    Classify a pre-workout check-in with the pre-workout DistilBERT only;
    Claude is never called. Handy for persisting labels before any plan exists.
    """
    # Looked up outside the try block so a missing entry is not reported as
    # an inference failure.
    model, tokenizer, device = (
        app_state[key] for key in ("pre_model", "pre_tokenizer", "device")
    )
    label_maps = (
        PRE_MOOD_MAP,
        PRE_ENERGY_MAP,
        PRE_MOTIVATION_MAP,
        PRE_STRESS_MAP,
        PRE_SORENESS_REGION_MAP,
        PRE_SORENESS_SEVERITY_MAP,
    )
    try:
        predictions = predict_pre(req.user_text, model, tokenizer, device)
        decoded = decode_pre_predictions(predictions, *label_maps)
        return PreWorkoutBertLabels(**decoded)
    except Exception as e:
        logger.error(f"Pre-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
# ─────────────────────────────────────────────
# STREAMING ROUTES
# ─────────────────────────────────────────────
# These endpoints use Server-Sent Events (SSE).
# Each event is a line starting with "data: "
# followed by a JSON payload, terminated by "\n\n".
#
# Event types:
# {"type": "labels", "bert_labels": {...}} — emitted immediately after
# DistilBERT inference (~0.5s)
# {"type": "token", "text": "..."} — one chunk of Claude output text
# {"type": "done"} — stream finished cleanly
# {"type": "error", "detail": "..."} — Claude call failed; the stream ends without a "done" event
def _sse(payload: dict) -> str:
"""Format a dict as a single SSE data line."""
return f"data: {json.dumps(payload)}\n\n"
@app.post("/post-classify/stream", dependencies=[Depends(verify_key)])
async def post_classify_stream(req: PostWorkoutRequest):
    """
    Streaming version of /post-classify.

    Emits the BERT labels as the first SSE event, then streams the Claude
    debrief as "token" events and finishes with "done". If the Claude call
    fails, a single "error" event is emitted and the stream ends.

    Raises:
        HTTPException: 500 when inference or label decoding fails (raised
            before the stream opens, so the client gets a normal error).
    """
    model = app_state["post_model"]
    tokenizer = app_state["post_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # ── Step 1: DistilBERT inference (runs before the stream opens) ──
    # NOTE(review): this still executes on the event loop; consider
    # offloading it to a worker thread if inference latency grows.
    try:
        raw_preds = predict_post(req.user_text, model, tokenizer, device)
        bert_labels = decode_post_predictions(
            raw_preds,
            POST_MOOD_MAP,
            POST_EXERTION_MAP,
            POST_SORENESS_REGION_MAP,
            POST_SORENESS_SEVERITY_MAP,
            POST_COMPLETION_MAP,
        )
    except Exception as e:
        logger.error(f"Post-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # ── Step 2: Build Claude prompt ──
    prompt = build_post_prompt(
        bert_labels=bert_labels,
        user_text=req.user_text,
        duration_minutes=req.duration_minutes,
        workout_type=req.workout_type,
        user_goal=req.user_goal,
    )

    def event_stream():
        # Deliberately a plain (non-async) generator: the Anthropic client
        # stored in app_state is synchronous, so iterating its stream blocks
        # the calling thread. StreamingResponse iterates plain generators in
        # a worker thread, which keeps the event loop free for other
        # requests while waiting on Claude.

        # Emit labels immediately — iOS renders chips before Claude starts
        yield _sse({"type": "labels", "bert_labels": bert_labels})
        if not req.generate_debrief:
            yield _sse({"type": "done"})
            return

        # Stream Claude text chunks as they arrive
        try:
            with client.messages.stream(
                model="claude-haiku-4-5-20251001",
                max_tokens=400,
                messages=[{"role": "user", "content": prompt}],
            ) as stream:
                for text_chunk in stream.text_stream:
                    yield _sse({"type": "token", "text": text_chunk})
        except Exception as e:
            logger.error(f"Claude streaming error (post-workout): {e}")
            yield _sse({"type": "error", "detail": str(e)})
            return
        yield _sse({"type": "done"})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disables nginx buffering on HF Spaces
        },
    )
@app.post("/pre-classify/stream", dependencies=[Depends(verify_key)])
async def pre_classify_stream(req: PreWorkoutRequest):
    """
    Streaming version of /pre-classify.

    Emits the BERT labels as the first SSE event, then streams the Claude
    workout plan as "token" events and finishes with "done". If the Claude
    call fails, a single "error" event is emitted and the stream ends.

    Raises:
        HTTPException: 500 when inference or label decoding fails (raised
            before the stream opens, so the client gets a normal error).
    """
    model = app_state["pre_model"]
    tokenizer = app_state["pre_tokenizer"]
    device = app_state["device"]
    client = app_state["anthropic_client"]

    # ── Step 1: DistilBERT inference (runs before the stream opens) ──
    # NOTE(review): this still executes on the event loop; consider
    # offloading it to a worker thread if inference latency grows.
    try:
        raw_preds = predict_pre(req.user_text, model, tokenizer, device)
        bert_labels = decode_pre_predictions(
            raw_preds,
            PRE_MOOD_MAP,
            PRE_ENERGY_MAP,
            PRE_MOTIVATION_MAP,
            PRE_STRESS_MAP,
            PRE_SORENESS_REGION_MAP,
            PRE_SORENESS_SEVERITY_MAP,
        )
    except Exception as e:
        logger.error(f"Pre-workout inference error: {e}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    # ── Step 2: Build Claude prompt ──
    prompt = build_pre_prompt(
        bert_labels=bert_labels,
        user_text=req.user_text,
        workout_type=req.workout_type,
        duration_minutes=req.duration_minutes,
        user_goal=req.user_goal,
        equipment=req.equipment,
    )

    def event_stream():
        # Deliberately a plain (non-async) generator: the Anthropic client
        # stored in app_state is synchronous, so iterating its stream blocks
        # the calling thread. StreamingResponse iterates plain generators in
        # a worker thread, which keeps the event loop free for other
        # requests while waiting on Claude.

        # Emit labels immediately
        yield _sse({"type": "labels", "bert_labels": bert_labels})
        if not req.generate_plan:
            yield _sse({"type": "done"})
            return

        # Stream Claude text chunks as they arrive
        try:
            with client.messages.stream(
                model="claude-haiku-4-5-20251001",
                max_tokens=800,
                messages=[{"role": "user", "content": prompt}],
            ) as stream:
                for text_chunk in stream.text_stream:
                    yield _sse({"type": "token", "text": text_chunk})
        except Exception as e:
            logger.error(f"Claude streaming error (pre-workout): {e}")
            yield _sse({"type": "error", "detail": str(e)})
            return
        yield _sse({"type": "done"})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disables proxy response buffering
        },
    )