Spaces:
Running
Running
| import os | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import List | |
| import torch | |
| import torch.nn as nn | |
| from transformers import DistilBertModel, DistilBertTokenizer | |
| from supabase import create_client, Client | |
| # ββ Logging setup βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # ββ Label Maps ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| workout_label_map = { | |
| 0: "Chest", 1: "Back", 2: "Legs", 3: "Shoulders", | |
| 4: "Arms", 5: "Core", 6: "Full Body", 7: "Cardio" | |
| } | |
| mood_label_map = { | |
| 0: "Energized", 1: "Tired", 2: "Stressed", | |
| 3: "Motivated", 4: "Neutral" | |
| } | |
| soreness_label_map = { | |
| 0: "None", 1: "Mild", 2: "Severe" | |
| } | |
| # ββ Model Definition ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class MultiHeadDistilBERT(nn.Module): | |
| def __init__(self, num_workout_types, num_moods, num_soreness_levels): | |
| super(MultiHeadDistilBERT, self).__init__() | |
| self.bert = DistilBertModel.from_pretrained( | |
| 'distilbert-base-uncased', | |
| token=os.getenv('HF_TOKEN') | |
| ) | |
| hidden_size = self.bert.config.hidden_size | |
| self.dropout = nn.Dropout(0.3) | |
| self.workout_head = nn.Linear(hidden_size, num_workout_types) | |
| self.mood_head = nn.Linear(hidden_size, num_moods) | |
| self.soreness_head = nn.Linear(hidden_size, num_soreness_levels) | |
| def forward(self, input_ids, attention_mask): | |
| outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
| cls_output = self.dropout(outputs.last_hidden_state[:, 0, :]) | |
| return ( | |
| self.workout_head(cls_output), | |
| self.mood_head(cls_output), | |
| self.soreness_head(cls_output) | |
| ) | |
| # ββ App State β loaded once at startup βββββββββββββββββββββββββββββββββββββββ | |
| class AppState: | |
| model: MultiHeadDistilBERT = None | |
| tokenizer: DistilBertTokenizer = None | |
| supabase: Client = None | |
| device: torch.device = None | |
| state = AppState() | |
| # ββ Lifespan β runs once on startup and shutdown ββββββββββββββββββββββββββββββ | |
| async def lifespan(app: FastAPI): | |
| # ββ Startup βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logger.info("Loading model, tokenizer and Supabase client...") | |
| state.device = torch.device('cpu') | |
| # Load tokenizer once | |
| state.tokenizer = DistilBertTokenizer.from_pretrained( | |
| 'distilbert-base-uncased', | |
| token=os.getenv('HF_TOKEN') | |
| ) | |
| logger.info("Tokenizer loaded") | |
| # Load model once | |
| state.model = MultiHeadDistilBERT( | |
| num_workout_types=8, | |
| num_moods=5, | |
| num_soreness_levels=3 | |
| ) | |
| state.model.load_state_dict( | |
| torch.load('best_DistilBERT_model.pt', map_location=state.device) | |
| ) | |
| state.model.to(state.device) | |
| state.model.eval() | |
| logger.info("Model loaded") | |
| # Create Supabase client once | |
| state.supabase = create_client( | |
| os.getenv('SUPA_URL'), | |
| os.getenv('SUPA_KEY') | |
| ) | |
| logger.info("Supabase client created") | |
| logger.info("Startup complete β API is ready") | |
| yield # β API runs here | |
| # ββ Shutdown ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logger.info("Shutting down API") | |
| app = FastAPI(lifespan=lifespan) | |
| # ββ Schemas βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class PredictRequest(BaseModel): | |
| user_input: str | |
| class ExerciseResponse(BaseModel): | |
| id: int | |
| name: str | |
| workout_type: int | |
| difficulty: str | |
| notes: str | |
| suitable_moods: List[int] | |
| suitable_soreness: List[int] | |
| class PredictResponse(BaseModel): | |
| workout: str | |
| workout_conf: float | |
| mood: str | |
| mood_conf: float | |
| soreness: str | |
| soreness_conf: float | |
| exercises: List[ExerciseResponse] | |
| def format_pg_array(values: list) -> str: | |
| """Convert a Python list to PostgreSQL array literal format""" | |
| return '{' + ','.join(str(v) for v in values) + '}' | |
| # ββ Supabase Helper βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_suitable_exercises(workout_type: int, mood: int, soreness: int) -> List[ExerciseResponse]: | |
| try: | |
| logger.info(f"Querying exercises β workout_type: {workout_type}, mood: {mood}, soreness: {soreness}") | |
| print("DATA") | |
| response = ( | |
| state.supabase.table('exerciseai') | |
| .select('*') | |
| .eq('workout_type', workout_type) | |
| .contains('suitable_moods', format_pg_array([mood])) | |
| .contains('suitable_soreness', format_pg_array([soreness])) | |
| .execute() | |
| ) | |
| logger.info(f"Supabase returned {len(response.data)} exercises") | |
| return [ExerciseResponse(**exercise) for exercise in response.data] | |
| except Exception as e: | |
| logger.error(f"Supabase query failed: {e}") | |
| raise HTTPException(status_code=503, detail="Failed to fetch exercises from database") | |
| # ββ Health Check ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def health_check(): | |
| return { | |
| "status": "ok", | |
| "model": "MultiHeadDistilBERT", | |
| "device": str(state.device) | |
| } | |
| # ββ Predict Endpoint ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def predict(request: PredictRequest): | |
| print("HERE") | |
| # ββ Input validation ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if not request.user_input.strip(): | |
| raise HTTPException(status_code=400, detail="user_input cannot be empty") | |
| try: | |
| # ββ Tokenize ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| encoding = state.tokenizer( | |
| request.user_input, | |
| max_length=128, | |
| padding='max_length', | |
| truncation=True, | |
| return_tensors='pt' | |
| ) | |
| input_ids = encoding['input_ids'].to(state.device) | |
| attention_mask = encoding['attention_mask'].to(state.device) | |
| # ββ Inference βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with torch.no_grad(): | |
| workout_logits, mood_logits, soreness_logits = state.model( | |
| input_ids, attention_mask | |
| ) | |
| print("HERE2") | |
| # ββ Softmax + confidence ββββββββββββββββββββββββββββββββββββββββββββββ | |
| workout_probs = torch.softmax(workout_logits, dim=1) | |
| mood_probs = torch.softmax(mood_logits, dim=1) | |
| soreness_probs = torch.softmax(soreness_logits, dim=1) | |
| workout_conf, workout_pred = workout_probs.max(dim=1) | |
| mood_conf, mood_pred = mood_probs.max(dim=1) | |
| soreness_conf, soreness_pred = soreness_probs.max(dim=1) | |
| # ββ Map to labels β reuse pred variables, no redundant argmax βββββββββ | |
| predicted_workout = workout_label_map[workout_pred.item()] | |
| predicted_mood = mood_label_map[mood_pred.item()] | |
| predicted_soreness = soreness_label_map[soreness_pred.item()] | |
| logger.info( | |
| f"Prediction β Workout: {predicted_workout} ({workout_conf.item()*100:.1f}%) | " | |
| f"Mood: {predicted_mood} ({mood_conf.item()*100:.1f}%) | " | |
| f"Soreness: {predicted_soreness} ({soreness_conf.item()*100:.1f}%)" | |
| ) | |
| # ββ Fetch exercises βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| suitable_exercises = get_suitable_exercises( | |
| workout_type = workout_pred.item(), | |
| mood = mood_pred.item(), | |
| soreness = soreness_pred.item() | |
| ) | |
| return PredictResponse( | |
| workout = predicted_workout, | |
| workout_conf = round(workout_conf.item() * 100, 1), | |
| mood = predicted_mood, | |
| mood_conf = round(mood_conf.item() * 100, 1), | |
| soreness = predicted_soreness, | |
| soreness_conf = round(soreness_conf.item() * 100, 1), | |
| exercises = suitable_exercises | |
| ) | |
| except HTTPException: | |
| raise # β re-raise HTTP exceptions from get_suitable_exercises | |
| except Exception as e: | |
| logger.error(f"Prediction failed: {e}") | |
| raise HTTPException(status_code=500, detail="Prediction failed. Please try again.") | |