Spaces:
Sleeping
Sleeping
File size: 5,323 Bytes
27b62aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | # models/model.py
# BERT model definitions + loader for post-workout physical and mental classifiers
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from models.label_config import (
PHYSICAL_LABEL_COLS, PHYSICAL_DECODERS,
MENTAL_LABEL_COLS, MENTAL_DECODERS,
)
# Use CUDA whenever torch can see a GPU; fall back to CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Token length every input is padded/truncated to before it reaches BERT.
MAX_LEN = 128
# ─────────────────────────────────────────────
# MODEL DEFINITIONS
# ─────────────────────────────────────────────
class PostPhysicalClassifier(nn.Module):
    """BERT encoder shared by four post-workout *physical* classification heads.

    Every head reads the dropout-regularised first-token (``[CLS]``) vector,
    i.e. position 0 of BERT's ``last_hidden_state``, and emits 3 logits.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        width = self.bert.config.hidden_size  # 768 for bert-base-uncased
        self.dropout = nn.Dropout(0.3)
        # These attribute names become the state_dict key prefixes that the
        # saved checkpoint is loaded against, so they must stay as they are.
        self.pain_head = nn.Linear(width, 3)
        self.completion_head = nn.Linear(width, 3)
        self.fatigue_head = nn.Linear(width, 3)
        self.recovery_need_head = nn.Linear(width, 3)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return ``{"<name>_label": logits}`` for every head."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_vec = self.dropout(outputs.last_hidden_state[:, 0, :])
        pain = self.pain_head(cls_vec)
        completion = self.completion_head(cls_vec)
        fatigue = self.fatigue_head(cls_vec)
        recovery_need = self.recovery_need_head(cls_vec)
        return {
            "pain_label": pain,
            "completion_label": completion,
            "fatigue_label": fatigue,
            "recovery_need_label": recovery_need,
        }
class PostMentalClassifier(nn.Module):
    """BERT encoder shared by four post-workout *mental* classification heads.

    Every head reads the dropout-regularised first-token (``[CLS]``) vector,
    i.e. position 0 of BERT's ``last_hidden_state``. The ``pr_achieved`` head
    is binary (2 logits); the other three emit 3 logits each.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        width = self.bert.config.hidden_size  # 768 for bert-base-uncased
        self.dropout = nn.Dropout(0.3)
        # These attribute names become the state_dict key prefixes that the
        # saved checkpoint is loaded against, so they must stay as they are.
        self.performance_head = nn.Linear(width, 3)
        self.satisfaction_head = nn.Linear(width, 3)
        self.pr_achieved_head = nn.Linear(width, 2)  # binary
        self.motivation_head = nn.Linear(width, 3)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return ``{"<name>_label": logits}`` for every head."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_vec = self.dropout(outputs.last_hidden_state[:, 0, :])
        performance = self.performance_head(cls_vec)
        satisfaction = self.satisfaction_head(cls_vec)
        pr_achieved = self.pr_achieved_head(cls_vec)
        motivation = self.motivation_head(cls_vec)
        return {
            "performance_label": performance,
            "satisfaction_label": satisfaction,
            "pr_achieved_label": pr_achieved,
            "motivation_label": motivation,
        }
# ─────────────────────────────────────────────
# LOADER (called once on app startup)
# ─────────────────────────────────────────────
def _load_classifier(model_cls, checkpoint_path):
    """Build ``model_cls`` on DEVICE, load its weights from ``checkpoint_path``, and set eval mode."""
    model = model_cls().to(DEVICE)
    # The checkpoint is only passed to load_state_dict, so it just needs
    # tensors; weights_only=True keeps torch.load from unpickling arbitrary
    # objects (this is also the default from torch 2.6 onward).
    state_dict = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"{checkpoint_path} loaded")
    return model


def load_models(physical_path="post_physical_bert.pt", mental_path="post_mental_bert.pt"):
    """Load the shared tokenizer and both fine-tuned classifiers for inference.

    Intended to be called once at app startup. Each classifier is moved to
    DEVICE, has its weights replaced from its checkpoint, and is switched to
    eval mode so dropout is inactive during inference.

    Args:
        physical_path: path to the PostPhysicalClassifier state dict. Relative
            paths resolve against the current working directory.
        mental_path: path to the PostMentalClassifier state dict.

    Returns:
        tuple ``(tokenizer, physical_model, mental_model)``.

    Raises:
        FileNotFoundError: if a checkpoint file does not exist.
        RuntimeError: if a checkpoint's keys or shapes do not match its model.
    """
    print(f"Loading models on device: {DEVICE}")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    physical_model = _load_classifier(PostPhysicalClassifier, physical_path)
    mental_model = _load_classifier(PostMentalClassifier, mental_path)
    return tokenizer, physical_model, mental_model
# ─────────────────────────────────────────────
# SHARED INFERENCE FUNCTION
# ─────────────────────────────────────────────
def run_inference(model, tokenizer, enriched_text, label_cols, decoders):
    """
    Run a single forward pass and return decoded labels with confidence scores.

    This function does not call ``model.eval()`` itself; pass a model that is
    already in eval mode (as returned by ``load_models``) so dropout is off.

    Args:
        model: one of PostPhysicalClassifier or PostMentalClassifier
        tokenizer: shared BertTokenizer
        enriched_text: user text already prepended with goal
        label_cols: list of label column names for this model; each must be a
            key of the dict returned by the model's forward pass
        decoders: dict mapping label col → {index: string}

    Returns:
        dict of { label_col: { label: str, confidence: float } }, where
        confidence is the softmax probability of the predicted class, rounded
        to 3 decimals.

    Raises:
        KeyError: if a column is missing from the model output or from
            ``decoders``, or a predicted index has no entry in its decoder.
    """
    encoding = tokenizer(
        enriched_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(DEVICE)
    attention_mask = encoding["attention_mask"].to(DEVICE)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    result = {}
    for col in label_cols:
        # A single text gives a batch of one. Take row 0 explicitly instead of
        # calling squeeze(): squeeze() would also collapse a size-1 class
        # dimension into a 0-d tensor, and indexing that tensor raises.
        probs = F.softmax(logits[col], dim=-1)[0].cpu()
        pred_idx = int(torch.argmax(probs))
        confidence = float(probs[pred_idx])
        result[col] = {
            "label": decoders[col][pred_idx],
            "confidence": round(confidence, 3),
        }
    return result
|