import logging
import os
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from supabase import Client, create_client
from transformers import DistilBertModel, DistilBertTokenizer

# ── Logging setup ─────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Label Maps ────────────────────────────────────────────────────────────────
# Class-index → human-readable label for each classification head. The index
# ranges must match the head sizes passed to MultiHeadDistilBERT at startup
# (8 workout types, 5 moods, 3 soreness levels).
workout_label_map = {
    0: "Chest",
    1: "Back",
    2: "Legs",
    3: "Shoulders",
    4: "Arms",
    5: "Core",
    6: "Full Body",
    7: "Cardio",
}

mood_label_map = {
    0: "Energized",
    1: "Tired",
    2: "Stressed",
    3: "Motivated",
    4: "Neutral",
}

soreness_label_map = {
    0: "None",
    1: "Mild",
    2: "Severe",
}


# ── Model Definition ──────────────────────────────────────────────────────────
class MultiHeadDistilBERT(nn.Module):
    """DistilBERT encoder with three independent classification heads.

    A single shared DistilBERT backbone produces the [CLS] representation,
    which is passed (after dropout) through three parallel linear heads that
    predict workout type, mood, and soreness level respectively.
    """

    def __init__(self, num_workout_types, num_moods, num_soreness_levels):
        super().__init__()
        # HF_TOKEN is only needed for gated/private repos; None is fine for
        # the public 'distilbert-base-uncased' checkpoint.
        self.bert = DistilBertModel.from_pretrained(
            'distilbert-base-uncased',
            token=os.getenv('HF_TOKEN')
        )
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.3)
        self.workout_head = nn.Linear(hidden_size, num_workout_types)
        self.mood_head = nn.Linear(hidden_size, num_moods)
        self.soreness_head = nn.Linear(hidden_size, num_soreness_levels)

    def forward(self, input_ids, attention_mask):
        """Return (workout_logits, mood_logits, soreness_logits).

        Each tensor has shape (batch, num_classes) for its head; all three
        heads read the same dropped-out [CLS] token embedding.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state is the [CLS] token.
        cls_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return (
            self.workout_head(cls_output),
            self.mood_head(cls_output),
            self.soreness_head(cls_output),
        )


# ── App State — loaded once at startup ────────────────────────────────────────
class AppState:
    """Process-wide singletons populated by the lifespan handler.

    Attributes are None until startup completes; request handlers assume
    they have been initialized.
    """
    model: Optional[MultiHeadDistilBERT] = None
    tokenizer: Optional[DistilBertTokenizer] = None
    supabase: Optional[Client] = None
    device: Optional[torch.device] = None


state = AppState()
# ── Lifespan — runs once on startup and shutdown ──────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer, model and Supabase client once before serving.

    Everything heavy lives here so each request only runs inference; the
    loaded objects are stashed on the module-level `state` singleton.
    """
    # ── Startup ───────────────────────────────────────────────────────────────
    logger.info("Loading model, tokenizer and Supabase client...")

    state.device = torch.device('cpu')

    # Load tokenizer once
    state.tokenizer = DistilBertTokenizer.from_pretrained(
        'distilbert-base-uncased',
        token=os.getenv('HF_TOKEN')
    )
    logger.info("Tokenizer loaded")

    # Load model once — head sizes must match the label maps above.
    state.model = MultiHeadDistilBERT(
        num_workout_types=8,
        num_moods=5,
        num_soreness_levels=3
    )
    # NOTE(review): torch.load unpickles arbitrary objects by default. This is
    # acceptable for a trusted local checkpoint; consider weights_only=True
    # (torch >= 1.13) to harden against a tampered file.
    state.model.load_state_dict(
        torch.load('best_DistilBERT_model.pt', map_location=state.device)
    )
    state.model.to(state.device)
    state.model.eval()
    logger.info("Model loaded")

    # Create Supabase client once — fail fast with a clear message instead of
    # letting create_client raise an opaque error on None credentials.
    supa_url = os.getenv('SUPA_URL')
    supa_key = os.getenv('SUPA_KEY')
    if not supa_url or not supa_key:
        raise RuntimeError("SUPA_URL and SUPA_KEY environment variables must be set")
    state.supabase = create_client(supa_url, supa_key)
    logger.info("Supabase client created")

    logger.info("Startup complete — API is ready")
    yield  # ← API runs here

    # ── Shutdown ──────────────────────────────────────────────────────────────
    logger.info("Shutting down API")


app = FastAPI(lifespan=lifespan)


# ── Schemas ───────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    # Free-text description of how the user feels / what they want to train.
    user_input: str


class ExerciseResponse(BaseModel):
    # Mirrors one row of the 'exerciseai' Supabase table.
    id: int
    name: str
    workout_type: int
    difficulty: str
    notes: str
    suitable_moods: List[int]
    suitable_soreness: List[int]


class PredictResponse(BaseModel):
    # Predicted labels plus confidences (percent, one decimal place) and the
    # matching exercises fetched from Supabase.
    workout: str
    workout_conf: float
    mood: str
    mood_conf: float
    soreness: str
    soreness_conf: float
    exercises: List[ExerciseResponse]


def format_pg_array(values: list) -> str:
    """Convert a Python list to PostgreSQL array literal format.

    E.g. [1, 2] -> '{1,2}'. Elements are not quoted or escaped, so this is
    only safe for numeric values (as used here with label indices).
    """
    return '{' + ','.join(str(v) for v in values) + '}'


# ── Supabase Helper ───────────────────────────────────────────────────────────
def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[ExerciseResponse]:
    """Fetch exercises matching the predicted workout type, mood and soreness.

    Raises HTTPException(503) if the Supabase query fails for any reason.
    """
    try:
        logger.info(f"Querying exercises — workout_type: {workout_type}, mood: {mood}, soreness: {soreness}")
        response = (
            state.supabase.table('exerciseai')
            .select('*')
            .eq('workout_type', workout_type)
            # contains() = Postgres '@>' — the row's array must include the
            # predicted mood/soreness index.
            .contains('suitable_moods', format_pg_array([mood]))
            .contains('suitable_soreness', format_pg_array([soreness]))
            .execute()
        )
        logger.info(f"Supabase returned {len(response.data)} exercises")
        return [ExerciseResponse(**exercise) for exercise in response.data]
    except Exception:
        # logger.exception keeps the traceback, unlike logger.error(f"...{e}").
        logger.exception("Supabase query failed")
        raise HTTPException(status_code=503, detail="Failed to fetch exercises from database")


# ── Health Check ──────────────────────────────────────────────────────────────
@app.get("/")
def health_check():
    return {
        "status": "ok",
        "model": "MultiHeadDistilBERT",
        "device": str(state.device)
    }


# ── Predict Endpoint ──────────────────────────────────────────────────────────
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Classify the user's free text and return labels, confidences and exercises."""
    # ── Input validation ──────────────────────────────────────────────────────
    if not request.user_input.strip():
        raise HTTPException(status_code=400, detail="user_input cannot be empty")

    try:
        # ── Tokenize ──────────────────────────────────────────────────────────
        encoding = state.tokenizer(
            request.user_input,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(state.device)
        attention_mask = encoding['attention_mask'].to(state.device)

        # ── Inference ─────────────────────────────────────────────────────────
        with torch.no_grad():
            workout_logits, mood_logits, soreness_logits = state.model(
                input_ids, attention_mask
            )

        # ── Softmax + confidence ──────────────────────────────────────────────
        workout_probs = torch.softmax(workout_logits, dim=1)
        mood_probs = torch.softmax(mood_logits, dim=1)
        soreness_probs = torch.softmax(soreness_logits, dim=1)

        # max() over dim=1 yields (confidence, predicted index) in one pass —
        # no separate argmax needed.
        workout_conf, workout_pred = workout_probs.max(dim=1)
        mood_conf, mood_pred = mood_probs.max(dim=1)
        soreness_conf, soreness_pred = soreness_probs.max(dim=1)

        # ── Map to labels ─────────────────────────────────────────────────────
        predicted_workout = workout_label_map[workout_pred.item()]
        predicted_mood = mood_label_map[mood_pred.item()]
        predicted_soreness = soreness_label_map[soreness_pred.item()]

        logger.info(
            f"Prediction — Workout: {predicted_workout} ({workout_conf.item()*100:.1f}%) | "
            f"Mood: {predicted_mood} ({mood_conf.item()*100:.1f}%) | "
            f"Soreness: {predicted_soreness} ({soreness_conf.item()*100:.1f}%)"
        )

        # ── Fetch exercises ───────────────────────────────────────────────────
        suitable_exercises = get_suitable_exercises(
            workout_type=workout_pred.item(),
            mood=mood_pred.item(),
            soreness=soreness_pred.item()
        )

        return PredictResponse(
            workout=predicted_workout,
            workout_conf=round(workout_conf.item() * 100, 1),
            mood=predicted_mood,
            mood_conf=round(mood_conf.item() * 100, 1),
            soreness=predicted_soreness,
            soreness_conf=round(soreness_conf.item() * 100, 1),
            exercises=suitable_exercises
        )

    except HTTPException:
        raise  # ← re-raise HTTP exceptions from get_suitable_exercises
    except Exception:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed. Please try again.")