# models/model.py
# BERT model definitions + loader for post-workout physical and mental classifiers

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

from models.label_config import (
    PHYSICAL_LABEL_COLS,
    PHYSICAL_DECODERS,
    MENTAL_LABEL_COLS,
    MENTAL_DECODERS,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 128

# ─────────────────────────────────────────────
# MODEL DEFINITIONS
# ─────────────────────────────────────────────

class PostPhysicalClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden = self.bert.config.hidden_size  # 768
        self.dropout = nn.Dropout(0.3)
        self.pain_head = nn.Linear(hidden, 3)
        self.completion_head = nn.Linear(hidden, 3)
        self.fatigue_head = nn.Linear(hidden, 3)
        self.recovery_need_head = nn.Linear(hidden, 3)

    def forward(self, input_ids, attention_mask):
        cls = self.dropout(
            self.bert(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        )
        return {
            "pain_label": self.pain_head(cls),
            "completion_label": self.completion_head(cls),
            "fatigue_label": self.fatigue_head(cls),
            "recovery_need_label": self.recovery_need_head(cls),
        }


class PostMentalClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden = self.bert.config.hidden_size  # 768
        self.dropout = nn.Dropout(0.3)
        self.performance_head = nn.Linear(hidden, 3)
        self.satisfaction_head = nn.Linear(hidden, 3)
        self.pr_achieved_head = nn.Linear(hidden, 2)  # binary
        self.motivation_head = nn.Linear(hidden, 3)

    def forward(self, input_ids, attention_mask):
        cls = self.dropout(
            self.bert(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        )
        return {
            "performance_label": self.performance_head(cls),
            "satisfaction_label": self.satisfaction_head(cls),
            "pr_achieved_label": self.pr_achieved_head(cls),
            "motivation_label": self.motivation_head(cls),
        }


# ─────────────────────────────────────────────
# LOADER (called once on app startup)
# ─────────────────────────────────────────────

def load_models():
    print(f"Loading models on device: {DEVICE}")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    physical_model = PostPhysicalClassifier().to(DEVICE)
    physical_model.load_state_dict(
        torch.load("post_physical_bert.pt", map_location=DEVICE)
    )
    physical_model.eval()
    print("post_physical_bert.pt loaded")

    mental_model = PostMentalClassifier().to(DEVICE)
    mental_model.load_state_dict(
        torch.load("post_mental_bert.pt", map_location=DEVICE)
    )
    mental_model.eval()
    print("post_mental_bert.pt loaded")

    return tokenizer, physical_model, mental_model


# ─────────────────────────────────────────────
# SHARED INFERENCE FUNCTION
# ─────────────────────────────────────────────

def run_inference(model, tokenizer, enriched_text, label_cols, decoders):
    """
    Runs a single forward pass and returns decoded labels with confidence scores.

    Args:
        model: one of PostPhysicalClassifier or PostMentalClassifier
        tokenizer: shared BertTokenizer
        enriched_text: user text already prepended with goal
        label_cols: list of label column names for this model
        decoders: dict mapping label col → {index: string}

    Returns:
        dict of { label_col: { label: str, confidence: float } }
    """
    encoding = tokenizer(
        enriched_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(DEVICE)
    attention_mask = encoding["attention_mask"].to(DEVICE)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    result = {}
    for col in label_cols:
        probs = F.softmax(logits[col], dim=-1).cpu().squeeze()
        pred_idx = torch.argmax(probs).item()
        confidence = probs[pred_idx].item()
        result[col] = {
            "label": decoders[col][pred_idx],
            "confidence": round(confidence, 3),
        }
    return result
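

# ─────────────────────────────────────────────
# EXAMPLE USAGE (illustrative only)
# ─────────────────────────────────────────────
# A minimal sketch of how the loader and the shared inference function fit
# together, assuming the two checkpoint files referenced in load_models()
# exist in the working directory. The "[goal: ...]" prefix below is a
# hypothetical example of the goal-enriched text run_inference() expects;
# the exact enrichment format is defined by the calling code, not here.

if __name__ == "__main__":
    tokenizer, physical_model, mental_model = load_models()

    # Hypothetical post-workout note, prepended with the user's goal.
    enriched_text = "[goal: build strength] Hit every set, but my knees ache and I feel drained."

    physical_result = run_inference(
        physical_model, tokenizer, enriched_text,
        PHYSICAL_LABEL_COLS, PHYSICAL_DECODERS,
    )
    mental_result = run_inference(
        mental_model, tokenizer, enriched_text,
        MENTAL_LABEL_COLS, MENTAL_DECODERS,
    )

    print("physical:", physical_result)
    print("mental:", mental_result)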