# Uploaded by jflo ("Upload 6 files", commit 27b62aa)
# models/model.py
# BERT model definitions + loader for post-workout physical and mental classifiers
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from models.label_config import (
PHYSICAL_LABEL_COLS, PHYSICAL_DECODERS,
MENTAL_LABEL_COLS, MENTAL_DECODERS,
)
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Tokenizer sequence length: inputs are truncated / padded to this many tokens.
MAX_LEN = 128
# ─────────────────────────────────────────────
# MODEL DEFINITIONS
# ─────────────────────────────────────────────
class PostPhysicalClassifier(nn.Module):
    """BERT encoder with four linear heads for post-workout physical labels.

    The embedding at sequence position 0 of BERT's last hidden state is passed
    through dropout and then fed independently to each head.
    """

    def __init__(self):
        super().__init__()
        # Pretrained encoder; the head input width is read from its config.
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        width = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.3)
        # One 3-way head per label. The attribute names become state_dict keys,
        # which must match the saved checkpoint loaded via load_state_dict.
        self.pain_head = nn.Linear(width, 3)
        self.completion_head = nn.Linear(width, 3)
        self.fatigue_head = nn.Linear(width, 3)
        self.recovery_need_head = nn.Linear(width, 3)

    def forward(self, input_ids, attention_mask):
        """Return a dict mapping each label name to that head's raw logits."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = self.dropout(encoded.last_hidden_state[:, 0, :])
        named_heads = (
            ("pain_label", self.pain_head),
            ("completion_label", self.completion_head),
            ("fatigue_label", self.fatigue_head),
            ("recovery_need_label", self.recovery_need_head),
        )
        return {name: head(cls_embedding) for name, head in named_heads}
class PostMentalClassifier(nn.Module):
    """BERT encoder with four linear heads for post-workout mental labels.

    Three heads are 3-way; the PR-achieved head is binary (2 outputs). All of
    them read the dropout-regularised embedding at sequence position 0.
    """

    def __init__(self):
        super().__init__()
        # Pretrained encoder; the head input width is read from its config.
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        width = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.3)
        # Attribute names become state_dict keys and must match the checkpoint.
        self.performance_head = nn.Linear(width, 3)
        self.satisfaction_head = nn.Linear(width, 3)
        self.pr_achieved_head = nn.Linear(width, 2)  # binary
        self.motivation_head = nn.Linear(width, 3)

    def forward(self, input_ids, attention_mask):
        """Return a dict mapping each label name to that head's raw logits."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = self.dropout(encoded.last_hidden_state[:, 0, :])
        named_heads = (
            ("performance_label", self.performance_head),
            ("satisfaction_label", self.satisfaction_head),
            ("pr_achieved_label", self.pr_achieved_head),
            ("motivation_label", self.motivation_head),
        )
        return {name: head(cls_embedding) for name, head in named_heads}
# ─────────────────────────────────────────────
# LOADER (called once on app startup)
# ─────────────────────────────────────────────
def _load_classifier(model_cls, checkpoint_path):
    """Build *model_cls* on DEVICE, load its state_dict from *checkpoint_path*, switch to eval mode."""
    model = model_cls().to(DEVICE)
    # weights_only=True limits unpickling to tensors and plain containers, which
    # is all a state_dict needs, so a tampered checkpoint cannot execute code.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    # eval() disables the dropout layers for inference.
    model.eval()
    print(f"{checkpoint_path} loaded")
    return model


def load_models(physical_path="post_physical_bert.pt", mental_path="post_mental_bert.pt"):
    """Load the shared tokenizer and both fine-tuned classifiers.

    Intended to be called once on app startup.

    Args:
        physical_path: path to the PostPhysicalClassifier state_dict checkpoint.
        mental_path: path to the PostMentalClassifier state_dict checkpoint.
    Returns:
        (tokenizer, physical_model, mental_model); both models are on DEVICE
        and in eval mode.
    """
    print(f"Loading models on device: {DEVICE}")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    physical_model = _load_classifier(PostPhysicalClassifier, physical_path)
    mental_model = _load_classifier(PostMentalClassifier, mental_path)
    return tokenizer, physical_model, mental_model
# ─────────────────────────────────────────────
# SHARED INFERENCE FUNCTION
# ─────────────────────────────────────────────
def run_inference(model, tokenizer, enriched_text, label_cols, decoders):
    """
    Run a single forward pass and return decoded labels with confidence scores.

    Args:
        model: one of PostPhysicalClassifier or PostMentalClassifier. It should
            already be in eval mode; otherwise its dropout layers are active.
        tokenizer: shared BertTokenizer
        enriched_text: user text already prepended with goal (a single string)
        label_cols: list of label column names for this model; each must be a
            key of the dict the model returns
        decoders: dict mapping label col → {index: string}
    Returns:
        dict of { label_col: { label: str, confidence: float } }, where
        confidence is the softmax probability of the predicted class, rounded
        to 3 decimals.
    """
    encoding = tokenizer(
        enriched_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(DEVICE)
    attention_mask = encoding["attention_mask"].to(DEVICE)
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    result = {}
    for col in label_cols:
        # One input string means a batch of one. Take row 0 explicitly:
        # squeeze() would also collapse a size-1 class dimension.
        probs = F.softmax(logits[col][0], dim=-1).cpu()
        confidence, pred_idx = torch.max(probs, dim=-1)
        pred_idx = pred_idx.item()
        result[col] = {
            "label": decoders[col][pred_idx],
            "confidence": round(confidence.item(), 3),
        }
    return result