# models/model.py
# BERT model definitions + loader for post-workout physical and mental classifiers

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

from models.label_config import (
    PHYSICAL_LABEL_COLS, PHYSICAL_DECODERS,
    MENTAL_LABEL_COLS, MENTAL_DECODERS,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 128

# ─────────────────────────────────────────────
# MODEL DEFINITIONS
# ─────────────────────────────────────────────
class PostPhysicalClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden = self.bert.config.hidden_size  # 768
        self.dropout = nn.Dropout(0.3)
        # One 3-class head per physical label
        self.pain_head = nn.Linear(hidden, 3)
        self.completion_head = nn.Linear(hidden, 3)
        self.fatigue_head = nn.Linear(hidden, 3)
        self.recovery_need_head = nn.Linear(hidden, 3)

    def forward(self, input_ids, attention_mask):
        # [CLS] token embedding, regularized with dropout
        cls = self.dropout(
            self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        )
        return {
            "pain_label": self.pain_head(cls),
            "completion_label": self.completion_head(cls),
            "fatigue_label": self.fatigue_head(cls),
            "recovery_need_label": self.recovery_need_head(cls),
        }

class PostMentalClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden = self.bert.config.hidden_size  # 768
        self.dropout = nn.Dropout(0.3)
        # Three 3-class heads plus one binary head
        self.performance_head = nn.Linear(hidden, 3)
        self.satisfaction_head = nn.Linear(hidden, 3)
        self.pr_achieved_head = nn.Linear(hidden, 2)  # binary
        self.motivation_head = nn.Linear(hidden, 3)

    def forward(self, input_ids, attention_mask):
        # [CLS] token embedding, regularized with dropout
        cls = self.dropout(
            self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        )
        return {
            "performance_label": self.performance_head(cls),
            "satisfaction_label": self.satisfaction_head(cls),
            "pr_achieved_label": self.pr_achieved_head(cls),
            "motivation_label": self.motivation_head(cls),
        }
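
# A quick shape sanity check for the classifiers (a sketch, not part of the
# original module: the function name `_check_output_shapes` and the dummy
# batch are assumptions added for illustration).
def _check_output_shapes():
    model = PostPhysicalClassifier().to(DEVICE)
    model.eval()
    # All-[PAD] dummy batch of size 1, padded to MAX_LEN like real inputs
    dummy_ids = torch.zeros((1, MAX_LEN), dtype=torch.long, device=DEVICE)
    dummy_mask = torch.ones((1, MAX_LEN), dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        logits = model(dummy_ids, dummy_mask)
    # Each head returns (batch_size, num_classes) logits, here (1, 3) for
    # every physical head
    assert all(t.shape == (1, 3) for t in logits.values())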

# ─────────────────────────────────────────────
# LOADER (called once on app startup)
# ─────────────────────────────────────────────
def load_models():
    print(f"Loading models on device: {DEVICE}")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    physical_model = PostPhysicalClassifier().to(DEVICE)
    physical_model.load_state_dict(
        torch.load("post_physical_bert.pt", map_location=DEVICE)
    )
    physical_model.eval()
    print("post_physical_bert.pt loaded")

    mental_model = PostMentalClassifier().to(DEVICE)
    mental_model.load_state_dict(
        torch.load("post_mental_bert.pt", map_location=DEVICE)
    )
    mental_model.eval()
    print("post_mental_bert.pt loaded")

    return tokenizer, physical_model, mental_model
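
# Typical startup wiring (a sketch; how the returned objects are kept alive,
# e.g. as module-level globals or app state, is an assumption not shown here):
#
#     tokenizer, physical_model, mental_model = load_models()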

# ─────────────────────────────────────────────
# SHARED INFERENCE FUNCTION
# ─────────────────────────────────────────────
def run_inference(model, tokenizer, enriched_text, label_cols, decoders):
    """
    Runs a single forward pass and returns decoded labels with confidence scores.

    Args:
        model: one of PostPhysicalClassifier or PostMentalClassifier
        tokenizer: shared BertTokenizer
        enriched_text: user text with the goal already prepended
        label_cols: list of label column names for this model
        decoders: dict mapping label col → {index: string}

    Returns:
        dict of { label_col: { label: str, confidence: float } }
    """
    encoding = tokenizer(
        enriched_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    input_ids = encoding["input_ids"].to(DEVICE)
    attention_mask = encoding["attention_mask"].to(DEVICE)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    result = {}
    for col in label_cols:
        # Softmax over class logits; squeeze the batch dim (batch size is 1)
        probs = F.softmax(logits[col], dim=-1).cpu().squeeze()
        pred_idx = torch.argmax(probs).item()
        confidence = probs[pred_idx].item()
        result[col] = {
            "label": decoders[col][pred_idx],
            "confidence": round(confidence, 3)
        }
    return result
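
# ─────────────────────────────────────────────
# EXAMPLE (manual smoke test)
# ─────────────────────────────────────────────
# A minimal end-to-end sketch, runnable only where the two .pt checkpoints
# exist on disk. The sample text below is invented for illustration.
if __name__ == "__main__":
    tokenizer, physical_model, mental_model = load_models()
    sample = "Goal: build strength. Finished all sets but my knees ached afterwards."
    print(run_inference(physical_model, tokenizer, sample,
                        PHYSICAL_LABEL_COLS, PHYSICAL_DECODERS))
    print(run_inference(mental_model, tokenizer, sample,
                        MENTAL_LABEL_COLS, MENTAL_DECODERS))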