# personal_trainer/model.py
"""
model.py β€” DistilBERT multi-head model definition and loader
"""
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertTokenizer
import logging
logger = logging.getLogger(__name__)
class PostWorkoutDistilBERT(nn.Module):
    """
    Post-workout multi-head DistilBERT classifier.

    One shared DistilBERT backbone feeds five independent classification
    heads:
      - mood (8 classes)
      - exertion (3 classes)
      - soreness_region (7 classes — which muscle group)
      - soreness_severity (4 classes — how intense)
      - completion (2 classes)
    """

    @staticmethod
    def _mlp_head(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
        # Two-layer head: in→hidden→out with ReLU and dropout between.
        # Layer order matches the original Sequential so state_dict keys
        # (indices 0..3) stay compatible with saved checkpoints.
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, out_dim),
        )

    def __init__(
        self,
        num_moods: int = 8,
        num_exertion_levels: int = 3,
        num_soreness_region_classes: int = 7,
        num_soreness_severity_classes: int = 4,
        num_completion_statuses: int = 2,
    ):
        super().__init__()
        # Shared text encoder; its [CLS] representation feeds every head.
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        dim = self.bert.config.hidden_size  # 768 for distilbert-base
        # Two dropout layers applied back-to-back on the pooled output.
        self.dropout = nn.Dropout(0.3)
        self.head_dropout = nn.Dropout(0.1)
        # Plain linear heads for the easier tasks.
        self.mood_head = nn.Linear(dim, num_moods)
        self.completion_head = nn.Linear(dim, num_completion_statuses)
        # Deeper MLP heads for the harder tasks.
        self.exertion_head = self._mlp_head(dim, 128, num_exertion_levels)
        self.soreness_region_head = self._mlp_head(dim, 128, num_soreness_region_classes)
        self.soreness_severity_head = self._mlp_head(dim, 64, num_soreness_severity_classes)

    def forward(self, input_ids, attention_mask):
        """Return logits as (mood, exertion, region, severity, completion)."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # [CLS] token embedding, regularised by both dropout layers in turn.
        pooled = self.dropout(encoded.last_hidden_state[:, 0, :])
        feats = self.head_dropout(pooled)
        return (
            self.mood_head(feats),
            self.exertion_head(feats),
            self.soreness_region_head(feats),
            self.soreness_severity_head(feats),
            self.completion_head(feats),
        )
class PreWorkoutDistilBERT(nn.Module):
    """
    Multi-head DistilBERT classifier for pre-workout state analysis.

    One shared DistilBERT backbone feeds six independent classification
    heads:
      - mood (8 classes)
      - energy (3 classes — low / moderate / high)
      - motivation (3 classes — low / moderate / high)
      - stress (3 classes — low / moderate / high)
      - soreness_region (7 classes — which muscle group)
      - soreness_severity (4 classes — how intense)
    """

    @staticmethod
    def _mlp_head(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
        # Two-layer head: in→hidden→out with ReLU and dropout between.
        # Layer order matches the original Sequential so state_dict keys
        # (indices 0..3) stay compatible with saved checkpoints.
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, out_dim),
        )

    def __init__(
        self,
        num_moods: int = 8,
        num_energy_levels: int = 3,
        num_motivation_levels: int = 3,
        num_stress_levels: int = 3,
        num_soreness_region_classes: int = 7,
        num_soreness_severity_classes: int = 4,
    ):
        super().__init__()
        # Shared text encoder; its [CLS] representation feeds every head.
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        dim = self.bert.config.hidden_size  # 768 for distilbert-base
        # Two dropout layers applied back-to-back on the pooled output.
        self.dropout = nn.Dropout(0.3)
        self.head_dropout = nn.Dropout(0.1)
        # Mood has a strong linguistic signal — a single linear layer suffices.
        self.mood_head = nn.Linear(dim, num_moods)
        # MLP heads; wider (128) for energy and soreness region, narrower (64)
        # for motivation, stress, and soreness severity.
        self.energy_head = self._mlp_head(dim, 128, num_energy_levels)
        self.motivation_head = self._mlp_head(dim, 64, num_motivation_levels)
        self.stress_head = self._mlp_head(dim, 64, num_stress_levels)
        self.soreness_region_head = self._mlp_head(dim, 128, num_soreness_region_classes)
        self.soreness_severity_head = self._mlp_head(dim, 64, num_soreness_severity_classes)

    def forward(self, input_ids, attention_mask):
        """Return logits as (mood, energy, motivation, stress, region, severity)."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # [CLS] token embedding, regularised by both dropout layers in turn.
        pooled = self.dropout(encoded.last_hidden_state[:, 0, :])
        feats = self.head_dropout(pooled)
        return (
            self.mood_head(feats),
            self.energy_head(feats),
            self.motivation_head(feats),
            self.stress_head(feats),
            self.soreness_region_head(feats),
            self.soreness_severity_head(feats),
        )
def load_pre_model(
    model_path: str,
    device: torch.device,
    num_moods: int = 8,
    num_energy_levels: int = 3,
    num_motivation_levels: int = 3,
    num_stress_levels: int = 3,
    num_soreness_region_classes: int = 7,
    num_soreness_severity_classes: int = 4,
):
    """
    Instantiate the pre-workout model, load saved weights, set to eval mode.

    Args:
        model_path: Path to a ``state_dict`` checkpoint produced by
            ``torch.save``.
        device: Device the weights are mapped to and the model moved onto.
        num_*: Head output sizes; must match the architecture the
            checkpoint was trained with, or ``load_state_dict`` will raise.

    Returns:
        Tuple of ``(model, tokenizer)`` ready for inference.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist (from ``torch.load``).
        RuntimeError: If the checkpoint's keys/shapes do not match the model.
    """
    # Lazy %-style logging args: formatting only happens if INFO is enabled.
    logger.info("Loading pre-workout model weights from: %s", model_path)
    model = PreWorkoutDistilBERT(
        num_moods=num_moods,
        num_energy_levels=num_energy_levels,
        num_motivation_levels=num_motivation_levels,
        num_stress_levels=num_stress_levels,
        num_soreness_region_classes=num_soreness_region_classes,
        num_soreness_severity_classes=num_soreness_severity_classes,
    )
    # weights_only=True refuses arbitrary pickled objects — safe loading.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    tokenizer = DistilBertTokenizer.from_pretrained(
        "distilbert-base-uncased",
        clean_up_tokenization_spaces=True,
    )
    logger.info("Pre-workout model loaded and set to eval mode.")
    return model, tokenizer
def load_post_model(
    model_path: str,
    device: torch.device,
    num_moods: int = 8,
    num_exertion_levels: int = 3,
    num_soreness_region_classes: int = 7,
    num_soreness_severity_classes: int = 4,
    num_completion_statuses: int = 2,
):
    """
    Instantiate the post-workout model, load saved weights, set to eval mode.

    Args:
        model_path: Path to a ``state_dict`` checkpoint produced by
            ``torch.save``.
        device: Device the weights are mapped to and the model moved onto.
        num_*: Head output sizes; must match the architecture the
            checkpoint was trained with, or ``load_state_dict`` will raise.

    Returns:
        Tuple of ``(model, tokenizer)`` ready for inference.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist (from ``torch.load``).
        RuntimeError: If the checkpoint's keys/shapes do not match the model.
    """
    # Lazy %-style logging args: formatting only happens if INFO is enabled.
    logger.info("Loading post-workout model weights from: %s", model_path)
    model = PostWorkoutDistilBERT(
        num_moods=num_moods,
        num_exertion_levels=num_exertion_levels,
        num_soreness_region_classes=num_soreness_region_classes,
        num_soreness_severity_classes=num_soreness_severity_classes,
        num_completion_statuses=num_completion_statuses,
    )
    # weights_only=True refuses arbitrary pickled objects — safe loading.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    tokenizer = DistilBertTokenizer.from_pretrained(
        "distilbert-base-uncased",
        clean_up_tokenization_spaces=True,
    )
    logger.info("Post-workout model loaded and set to eval mode.")
    return model, tokenizer