# Page header captured along with the file (not part of app.py itself):
#   jflo's picture
#   Update app.py
#   ad20b81 verified
import os
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertTokenizer
from supabase import create_client, Client
# ── Logging setup ─────────────────────────────────────────────────────────────
# Configure the root logger at INFO; this module logs under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Label Maps ────────────────────────────────────────────────────────────────
# Integer class index -> display label. A label's position in its tuple is
# its index, so the order below must not change.
workout_label_map = dict(enumerate((
    "Chest", "Back", "Legs", "Shoulders",
    "Arms", "Core", "Full Body", "Cardio",
)))
mood_label_map = dict(enumerate((
    "Energized", "Tired", "Stressed",
    "Motivated", "Neutral",
)))
soreness_label_map = dict(enumerate(("None", "Mild", "Severe")))
# ── Model Definition ──────────────────────────────────────────────────────────
class MultiHeadDistilBERT(nn.Module):
    """DistilBERT encoder followed by three independent linear heads.

    The hidden state of the first token is passed through dropout and then
    fed to one linear layer per task (workout type, mood, soreness).

    Args:
        num_workout_types: Output width of the workout head.
        num_moods: Output width of the mood head.
        num_soreness_levels: Output width of the soreness head.
        pretrained_name: Checkpoint name handed to
            ``DistilBertModel.from_pretrained``. Defaults to the value that
            used to be hard-coded here.
        dropout: Dropout probability applied before the heads. Defaults to
            the value that used to be hard-coded here.
    """

    def __init__(self, num_workout_types, num_moods, num_soreness_levels,
                 *, pretrained_name='distilbert-base-uncased', dropout=0.3):
        super().__init__()
        # Attribute names below become state_dict key prefixes, so they must
        # stay as they are for previously saved weights to load.
        self.bert = DistilBertModel.from_pretrained(
            pretrained_name,
            token=os.getenv('HF_TOKEN'),
        )
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.workout_head = nn.Linear(hidden_size, num_workout_types)
        self.mood_head = nn.Linear(hidden_size, num_moods)
        self.soreness_head = nn.Linear(hidden_size, num_soreness_levels)

    def forward(self, input_ids, attention_mask):
        """Return ``(workout_logits, mood_logits, soreness_logits)``."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 along the sequence axis is used as the sentence summary.
        cls_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return (
            self.workout_head(cls_output),
            self.mood_head(cls_output),
            self.soreness_head(cls_output),
        )
# ── App State — loaded once at startup ────────────────────────────────────────
class AppState:
    """Holder for the objects shared across requests."""

    # Each attribute defaults to None at class level; the startup hook
    # assigns the real objects on the module-level ``state`` instance.
    device: torch.device = None
    tokenizer: DistilBertTokenizer = None
    model: MultiHeadDistilBERT = None
    supabase: Client = None


state = AppState()
# ── Lifespan — runs once on startup and shutdown ──────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate ``state`` at startup and log at shutdown.

    The code before ``yield`` runs once when the application starts: it
    creates the tokenizer, the model (with weights read from
    ``best_DistilBERT_model.pt``) and the Supabase client, storing each on
    the module-level ``state``. The code after ``yield`` runs at shutdown.

    Raises:
        RuntimeError: If ``SUPA_URL`` or ``SUPA_KEY`` is unset or empty.
    """
    # ── Startup ───────────────────────────────────────────────────────────────
    logger.info("Loading model, tokenizer and Supabase client...")

    # Fail fast with a message naming the missing variable, before any
    # model loading is attempted.
    supabase_url = os.getenv('SUPA_URL')
    supabase_key = os.getenv('SUPA_KEY')
    missing = [
        name
        for name, value in (('SUPA_URL', supabase_url), ('SUPA_KEY', supabase_key))
        if not value
    ]
    if missing:
        raise RuntimeError(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )

    state.device = torch.device('cpu')

    # Load tokenizer once
    state.tokenizer = DistilBertTokenizer.from_pretrained(
        'distilbert-base-uncased',
        token=os.getenv('HF_TOKEN'),
    )
    logger.info("Tokenizer loaded")

    # Load model once
    state.model = MultiHeadDistilBERT(
        num_workout_types=8,
        num_moods=5,
        num_soreness_levels=3,
    )
    # NOTE(review): torch.load unpickles the file. If the checkpoint could
    # ever come from an untrusted source, pass weights_only=True (supported
    # by recent PyTorch releases) — confirm against the pinned torch version.
    state.model.load_state_dict(
        torch.load('best_DistilBERT_model.pt', map_location=state.device)
    )
    state.model.to(state.device)
    state.model.eval()  # inference mode: disables dropout
    logger.info("Model loaded")

    # Create Supabase client once
    state.supabase = create_client(supabase_url, supabase_key)
    logger.info("Supabase client created")

    logger.info("Startup complete — API is ready")
    yield  # ← API runs here

    # ── Shutdown ──────────────────────────────────────────────────────────────
    logger.info("Shutting down API")


app = FastAPI(lifespan=lifespan)
# ── Schemas ───────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    # Body of POST /predict.
    # (Documented with comments, not a docstring, so the generated API
    # schema stays exactly as it was.)
    user_input: str  # free-text message to classify
class ExerciseResponse(BaseModel):
    # One exercise as returned to the client; instances are built from rows
    # of the 'exerciseai' table via ExerciseResponse(**row).
    # (Comments rather than a docstring, so the generated API schema is
    # unchanged.)
    id: int
    name: str
    workout_type: int  # matched by equality against the predicted class index
    difficulty: str
    notes: str
    suitable_moods: List[int]     # array column; must contain the predicted mood index
    suitable_soreness: List[int]  # array column; must contain the predicted soreness index
class PredictResponse(BaseModel):
    # Body returned by POST /predict: one label per classifier head, each
    # with its confidence, plus the exercises that match the prediction.
    # The *_conf values are percentages rounded to one decimal place.
    # (Comments rather than a docstring, so the generated API schema is
    # unchanged.)
    workout: str
    workout_conf: float
    mood: str
    mood_conf: float
    soreness: str
    soreness_conf: float
    exercises: List[ExerciseResponse]
def format_pg_array(values: list) -> str:
    """Convert a Python list to a PostgreSQL array literal.

    Elements are rendered as follows:

    * ``None`` becomes the unquoted keyword ``NULL``.
    * A nested list or tuple becomes a nested array.
    * A value whose text is empty, spells NULL (any case), or contains a
      brace, comma, double quote, backslash or whitespace is wrapped in
      double quotes, with embedded double quotes and backslashes
      backslash-escaped.
    * Anything else is rendered with ``str()``, so lists of plain numbers
      or simple words give the same text as the previous implementation.

    >>> format_pg_array([1, 2, 3])
    '{1,2,3}'
    >>> format_pg_array(['a,b', None])
    '{"a,b",NULL}'
    """
    def _element(item) -> str:
        # Render a single array element.
        if item is None:
            return 'NULL'
        if isinstance(item, (list, tuple)):
            return format_pg_array(list(item))
        text = str(item)
        needs_quotes = (
            text == ''
            or text.upper() == 'NULL'
            or any(ch in '{},"\\' or ch.isspace() for ch in text)
        )
        if not needs_quotes:
            return text
        # Escape backslashes first so the ones added for quotes stay intact.
        escaped = text.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'

    return '{' + ','.join(_element(item) for item in values) + '}'
# ── Supabase Helper ───────────────────────────────────────────────────────────
def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[ExerciseResponse]:
    """Fetch the exercises that match a prediction from the 'exerciseai' table.

    A row matches when its ``workout_type`` equals *workout_type* and its
    ``suitable_moods`` / ``suitable_soreness`` arrays contain *mood* and
    *soreness* respectively.

    Args:
        workout_type: Predicted workout class index.
        mood: Predicted mood class index.
        soreness: Predicted soreness class index.

    Returns:
        One ``ExerciseResponse`` per matching row (possibly an empty list).

    Raises:
        HTTPException: 503 if the query, or converting a returned row into
            an ``ExerciseResponse``, fails for any reason.
    """
    logger.info(
        "Querying exercises — workout_type: %s, mood: %s, soreness: %s",
        workout_type, mood, soreness,
    )
    try:
        response = (
            state.supabase.table('exerciseai')
            .select('*')
            .eq('workout_type', workout_type)
            .contains('suitable_moods', format_pg_array([mood]))
            .contains('suitable_soreness', format_pg_array([soreness]))
            .execute()
        )
        logger.info("Supabase returned %d exercises", len(response.data))
        return [ExerciseResponse(**row) for row in response.data]
    except Exception as e:
        # logger.exception keeps the traceback, which the client never sees.
        logger.exception("Supabase query failed: %s", e)
        raise HTTPException(
            status_code=503,
            detail="Failed to fetch exercises from database",
        ) from e
# ── Health Check ──────────────────────────────────────────────────────────────
@app.get("/")
def health_check():
    # Liveness probe: a fixed status and model name plus the torch device
    # currently stored on ``state``. (Comment rather than docstring so the
    # generated API docs are unchanged.)
    payload = {"status": "ok", "model": "MultiHeadDistilBERT"}
    payload["device"] = str(state.device)
    return payload
# ── Predict Endpoint ──────────────────────────────────────────────────────────
def _top_class(logits):
    """Return ``(class_index, probability)`` of the most likely class.

    Softmax is taken over dim 1 and ``.item()`` is used on the result, so
    *logits* must hold exactly one row.
    """
    probs = torch.softmax(logits, dim=1)
    conf, pred = probs.max(dim=1)
    return pred.item(), conf.item()


@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Classify the message into workout type, mood and soreness, and return
    the predicted labels, their confidences (percent) and matching exercises.
    """
    # ── Input validation ──────────────────────────────────────────────────────
    if not request.user_input.strip():
        raise HTTPException(status_code=400, detail="user_input cannot be empty")

    try:
        # ── Tokenize ──────────────────────────────────────────────────────────
        encoding = state.tokenizer(
            request.user_input,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(state.device)
        attention_mask = encoding['attention_mask'].to(state.device)

        # ── Inference ─────────────────────────────────────────────────────────
        with torch.no_grad():
            workout_logits, mood_logits, soreness_logits = state.model(
                input_ids, attention_mask
            )

        # ── Softmax + confidence — each tensor is converted to Python once ────
        workout_idx, workout_conf = _top_class(workout_logits)
        mood_idx, mood_conf = _top_class(mood_logits)
        soreness_idx, soreness_conf = _top_class(soreness_logits)

        # ── Map to labels ─────────────────────────────────────────────────────
        predicted_workout = workout_label_map[workout_idx]
        predicted_mood = mood_label_map[mood_idx]
        predicted_soreness = soreness_label_map[soreness_idx]

        logger.info(
            "Prediction — Workout: %s (%.1f%%) | Mood: %s (%.1f%%) | "
            "Soreness: %s (%.1f%%)",
            predicted_workout, workout_conf * 100,
            predicted_mood, mood_conf * 100,
            predicted_soreness, soreness_conf * 100,
        )

        # ── Fetch exercises ───────────────────────────────────────────────────
        suitable_exercises = get_suitable_exercises(
            workout_type=workout_idx,
            mood=mood_idx,
            soreness=soreness_idx,
        )

        return PredictResponse(
            workout=predicted_workout,
            workout_conf=round(workout_conf * 100, 1),
            mood=predicted_mood,
            mood_conf=round(mood_conf * 100, 1),
            soreness=predicted_soreness,
            soreness_conf=round(soreness_conf * 100, 1),
            exercises=suitable_exercises,
        )
    except HTTPException:
        raise  # ← re-raise HTTP exceptions from get_suitable_exercises
    except Exception as e:
        # Keep the traceback in the server log; the client gets a generic 500.
        logger.exception("Prediction failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail="Prediction failed. Please try again.",
        ) from e