"""
model.py — DistilBERT multi-head model definition and loader
"""

import logging
from typing import Tuple

import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel, DistilBertTokenizer

logger = logging.getLogger(__name__)

# Hugging Face checkpoint name used for both the backbone and the tokenizer.
DEFAULT_BACKBONE = "distilbert-base-uncased"


def _build_backbone(backbone_name: str, pretrained: bool) -> DistilBertModel:
    """Create the DistilBERT backbone.

    pretrained=True  -> load the published weights (needed for training).
    pretrained=False -> fetch only the config and build a randomly initialised
                        backbone; appropriate when every parameter is about to
                        be overwritten by a fine-tuned state dict (as the
                        load_*_model functions in this module do).
    """
    if pretrained:
        return DistilBertModel.from_pretrained(backbone_name)
    return DistilBertModel(DistilBertConfig.from_pretrained(backbone_name))


def _mlp_head(in_features: int, hidden: int, out_features: int,
              dropout: float = 0.2) -> nn.Sequential:
    """Build a Linear -> ReLU -> Dropout -> Linear classification head.

    The layer order (Sequential indices 0..3) is fixed so that saved parameter
    names such as ``exertion_head.0.weight`` / ``exertion_head.3.weight`` stay
    compatible with existing checkpoints.
    """
    return nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, out_features),
    )


class _MultiHeadDistilBERT(nn.Module):
    """Shared DistilBERT backbone + dropout stack for the multi-head classifiers."""

    def __init__(self, backbone_name: str, pretrained_backbone: bool):
        super().__init__()
        self.bert = _build_backbone(backbone_name, pretrained_backbone)
        # Two dropouts are applied back-to-back to the pooled vector during
        # training (combined drop rate ~ 1 - 0.7 * 0.9 = 0.37); both are
        # no-ops in eval mode.
        self.dropout = nn.Dropout(0.3)
        self.head_dropout = nn.Dropout(0.1)

    def _encode(self, input_ids, attention_mask):
        """Return the dropped-out hidden state of the first token ([CLS])."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state is indexed as (batch, seq, hidden); position 0 is
        # the [CLS] token prepended by the DistilBERT tokenizer.
        cls_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return self.head_dropout(cls_output)


class PostWorkoutDistilBERT(_MultiHeadDistilBERT):
    """
    Post-workout multi-head DistilBERT classifier.

    5 independent classification heads sharing one BERT backbone
    (class counts are the constructor defaults):
      - mood              (8 classes)
      - exertion          (3 classes)
      - soreness_region   (7 classes — which muscle group)
      - soreness_severity (4 classes — how intense)
      - completion        (2 classes)

    Keyword-only options:
      pretrained_backbone: load published DistilBERT weights (default True).
          Pass False when a fine-tuned state dict will be loaded right after.
      backbone_name: Hugging Face checkpoint name for the backbone.
    """

    def __init__(
        self,
        num_moods: int = 8,
        num_exertion_levels: int = 3,
        num_soreness_region_classes: int = 7,
        num_soreness_severity_classes: int = 4,
        num_completion_statuses: int = 2,
        *,
        pretrained_backbone: bool = True,
        backbone_name: str = DEFAULT_BACKBONE,
    ):
        super().__init__(backbone_name, pretrained_backbone)
        hidden_size = self.bert.config.hidden_size  # 768 for distilbert-base-uncased

        # Simple linear heads for the easier tasks
        self.mood_head = nn.Linear(hidden_size, num_moods)
        self.completion_head = nn.Linear(hidden_size, num_completion_statuses)

        # Deeper heads: hidden_size -> bottleneck -> classes
        self.exertion_head = _mlp_head(hidden_size, 128, num_exertion_levels)
        self.soreness_region_head = _mlp_head(hidden_size, 128, num_soreness_region_classes)
        self.soreness_severity_head = _mlp_head(hidden_size, 64, num_soreness_severity_classes)

    def forward(self, input_ids, attention_mask):
        """Encode once and score every head.

        Args:
            input_ids: token ids from the tokenizer (presumably
                (batch, seq_len) — TODO confirm against callers).
            attention_mask: padding mask matching ``input_ids``.

        Returns:
            Tuple of raw logits (no softmax applied), in this order:
            (mood, exertion, soreness_region, soreness_severity, completion).
        """
        x = self._encode(input_ids, attention_mask)
        return (
            self.mood_head(x),
            self.exertion_head(x),
            self.soreness_region_head(x),
            self.soreness_severity_head(x),
            self.completion_head(x),
        )


class PreWorkoutDistilBERT(_MultiHeadDistilBERT):
    """
    Multi-head DistilBERT classifier for pre-workout state analysis.

    6 independent classification heads sharing one BERT backbone
    (class counts are the constructor defaults):
      - mood              (8 classes)
      - energy            (3 classes — low / moderate / high)
      - motivation        (3 classes — low / moderate / high)
      - stress            (3 classes — low / moderate / high)
      - soreness_region   (7 classes — which muscle group)
      - soreness_severity (4 classes — how intense)

    Keyword-only options:
      pretrained_backbone: load published DistilBERT weights (default True).
          Pass False when a fine-tuned state dict will be loaded right after.
      backbone_name: Hugging Face checkpoint name for the backbone.
    """

    def __init__(
        self,
        num_moods: int = 8,
        num_energy_levels: int = 3,
        num_motivation_levels: int = 3,
        num_stress_levels: int = 3,
        num_soreness_region_classes: int = 7,
        num_soreness_severity_classes: int = 4,
        *,
        pretrained_backbone: bool = True,
        backbone_name: str = DEFAULT_BACKBONE,
    ):
        super().__init__(backbone_name, pretrained_backbone)
        hidden_size = self.bert.config.hidden_size  # 768 for distilbert-base-uncased

        # Simple head — mood has strong linguistic signal
        self.mood_head = nn.Linear(hidden_size, num_moods)

        # Deeper heads: hidden_size -> bottleneck -> classes
        self.energy_head = _mlp_head(hidden_size, 128, num_energy_levels)
        self.motivation_head = _mlp_head(hidden_size, 64, num_motivation_levels)
        self.stress_head = _mlp_head(hidden_size, 64, num_stress_levels)
        self.soreness_region_head = _mlp_head(hidden_size, 128, num_soreness_region_classes)
        self.soreness_severity_head = _mlp_head(hidden_size, 64, num_soreness_severity_classes)

    def forward(self, input_ids, attention_mask):
        """Encode once and score every head.

        Args:
            input_ids: token ids from the tokenizer (presumably
                (batch, seq_len) — TODO confirm against callers).
            attention_mask: padding mask matching ``input_ids``.

        Returns:
            Tuple of raw logits (no softmax applied), in this order:
            (mood, energy, motivation, stress, soreness_region,
            soreness_severity).
        """
        x = self._encode(input_ids, attention_mask)
        return (
            self.mood_head(x),
            self.energy_head(x),
            self.motivation_head(x),
            self.stress_head(x),
            self.soreness_region_head(x),
            self.soreness_severity_head(x),
        )


def _load_weights(model: nn.Module, model_path: str, device: torch.device) -> nn.Module:
    """Load a saved state dict into *model*, move it to *device*, set eval mode.

    ``weights_only=True`` restricts unpickling to tensors and plain containers
    instead of arbitrary Python objects. ``load_state_dict`` is strict, so a
    missing or unexpected key raises — which also guarantees that a backbone
    built without pretrained weights is fully overwritten.
    """
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _load_tokenizer(backbone_name: str = DEFAULT_BACKBONE) -> DistilBertTokenizer:
    """Return the tokenizer matching the backbone checkpoint."""
    return DistilBertTokenizer.from_pretrained(
        backbone_name,
        clean_up_tokenization_spaces=True,
    )


def load_pre_model(
    model_path: str,
    device: torch.device,
    num_moods: int = 8,
    num_energy_levels: int = 3,
    num_motivation_levels: int = 3,
    num_stress_levels: int = 3,
    num_soreness_region_classes: int = 7,
    num_soreness_severity_classes: int = 4,
    *,
    backbone_name: str = DEFAULT_BACKBONE,
) -> Tuple[PreWorkoutDistilBERT, DistilBertTokenizer]:
    """
    Instantiate the pre-workout model, load saved weights, set to eval mode.

    The head sizes must match those used when the checkpoint was saved, or
    the strict state-dict load raises. The published backbone weights are not
    downloaded because the checkpoint replaces them anyway.

    Returns (model, tokenizer).
    """
    logger.info("Loading pre-workout model weights from: %s", model_path)
    model = PreWorkoutDistilBERT(
        num_moods=num_moods,
        num_energy_levels=num_energy_levels,
        num_motivation_levels=num_motivation_levels,
        num_stress_levels=num_stress_levels,
        num_soreness_region_classes=num_soreness_region_classes,
        num_soreness_severity_classes=num_soreness_severity_classes,
        pretrained_backbone=False,
        backbone_name=backbone_name,
    )
    _load_weights(model, model_path, device)
    tokenizer = _load_tokenizer(backbone_name)
    logger.info("Pre-workout model loaded and set to eval mode.")
    return model, tokenizer


def load_post_model(
    model_path: str,
    device: torch.device,
    num_moods: int = 8,
    num_exertion_levels: int = 3,
    num_soreness_region_classes: int = 7,
    num_soreness_severity_classes: int = 4,
    num_completion_statuses: int = 2,
    *,
    backbone_name: str = DEFAULT_BACKBONE,
) -> Tuple[PostWorkoutDistilBERT, DistilBertTokenizer]:
    """
    Instantiate the post-workout model, load saved weights, set to eval mode.

    The head sizes must match those used when the checkpoint was saved, or
    the strict state-dict load raises. The published backbone weights are not
    downloaded because the checkpoint replaces them anyway.

    Returns (model, tokenizer).
    """
    logger.info("Loading post-workout model weights from: %s", model_path)
    model = PostWorkoutDistilBERT(
        num_moods=num_moods,
        num_exertion_levels=num_exertion_levels,
        num_soreness_region_classes=num_soreness_region_classes,
        num_soreness_severity_classes=num_soreness_severity_classes,
        num_completion_statuses=num_completion_statuses,
        pretrained_backbone=False,
        backbone_name=backbone_name,
    )
    _load_weights(model, model_path, device)
    tokenizer = _load_tokenizer(backbone_name)
    logger.info("Post-workout model loaded and set to eval mode.")
    return model, tokenizer