Spaces:
Sleeping
Sleeping
"""
model.py — DistilBERT multi-head model definitions (pre- and post-workout) and their loaders.
"""
import logging

import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertTokenizer
# Module-scoped logger; load_pre_model / load_post_model report progress on it.
logger = logging.getLogger(name=__name__)
class PostWorkoutDistilBERT(nn.Module):
    """
    Post-workout multi-head DistilBERT classifier.

    Five independent classification heads share one DistilBERT backbone and
    all read the same pooled feature (the backbone's hidden state at sequence
    position 0):
      - mood               (8 classes)
      - exertion           (3 classes)
      - soreness_region    (7 classes — which muscle group)
      - soreness_severity  (4 classes — how intense)
      - completion         (2 classes)

    The class counts above are the constructor defaults; each is configurable.

    Args:
        num_moods: output size of the mood head.
        num_exertion_levels: output size of the exertion head.
        num_soreness_region_classes: output size of the soreness-region head.
        num_soreness_severity_classes: output size of the soreness-severity head.
        num_completion_statuses: output size of the completion head.
        pretrained_model_name_or_path: keyword-only. Hub id or local directory
            passed to ``DistilBertModel.from_pretrained`` for the backbone.
            Defaults to ``"distilbert-base-uncased"``, the checkpoint this class
            always used before the parameter existed.

    Note:
        Submodule attribute names and the layer layout inside each MLP head
        define the ``state_dict`` keys (e.g. ``exertion_head.0.weight``,
        ``exertion_head.3.weight``). Renaming or reordering them breaks
        loading of previously saved checkpoints.
    """

    def __init__(
        self,
        num_moods: int = 8,
        num_exertion_levels: int = 3,
        num_soreness_region_classes: int = 7,
        num_soreness_severity_classes: int = 4,
        num_completion_statuses: int = 2,
        *,
        pretrained_model_name_or_path: str = "distilbert-base-uncased",
    ):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(pretrained_model_name_or_path)
        hidden_size = self.bert.config.hidden_size  # 768 for distilbert-base-*

        # forward() applies these two dropouts in series. While training, an
        # element of the pooled feature survives with probability
        # 0.7 * 0.9 = 0.63 (effective p = 0.37). Both are identity in eval
        # mode. NOTE(review): confirm the stacking is intentional; it is kept
        # as-is so training behaviour does not change.
        self.dropout = nn.Dropout(0.3)
        self.head_dropout = nn.Dropout(0.1)

        # Single linear heads for the simpler targets.
        self.mood_head = nn.Linear(hidden_size, num_moods)
        self.completion_head = nn.Linear(hidden_size, num_completion_statuses)

        # Deeper MLP heads. Creation order matches the original inline
        # definitions, so seeded initialisation is unchanged.
        # exertion: hidden → 128 → num_exertion_levels
        self.exertion_head = self._mlp_head(hidden_size, 128, num_exertion_levels)
        # soreness region (which muscle group): hidden → 128 → classes
        self.soreness_region_head = self._mlp_head(
            hidden_size, 128, num_soreness_region_classes
        )
        # soreness severity (how intense): hidden → 64 → classes
        self.soreness_severity_head = self._mlp_head(
            hidden_size, 64, num_soreness_severity_classes
        )

    @staticmethod
    def _mlp_head(
        in_features: int,
        hidden_features: int,
        out_features: int,
        dropout: float = 0.2,
    ) -> nn.Sequential:
        """
        Return a Linear → ReLU → Dropout → Linear classification head.

        The fixed layer order yields ``state_dict`` keys ``0.*`` and ``3.*``
        (ReLU and Dropout have no parameters). These match checkpoints saved
        from the original inline ``nn.Sequential`` definitions.
        """
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """
        Encode once with the backbone and apply every head to the pooled feature.

        Args:
            input_ids: token ids for the backbone, presumably shaped
                (batch, seq_len). TODO confirm against callers.
            attention_mask: mask aligned with ``input_ids``.

        Returns:
            A tuple of raw logits (no softmax applied), in this fixed order:
            (mood, exertion, soreness_region, soreness_severity, completion).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The [:, 0, :] indexing takes the position-0 hidden state from a
        # (batch, seq_len, hidden) tensor. That position is the [CLS] token
        # when inputs come from the DistilBERT tokenizer with special tokens.
        pooled = self.dropout(outputs.last_hidden_state[:, 0, :])
        features = self.head_dropout(pooled)
        return (
            self.mood_head(features),
            self.exertion_head(features),
            self.soreness_region_head(features),
            self.soreness_severity_head(features),
            self.completion_head(features),
        )
class PreWorkoutDistilBERT(nn.Module):
    """
    Multi-head DistilBERT classifier for pre-workout state analysis.

    Six independent classification heads share one DistilBERT backbone and
    all read the same pooled feature (the backbone's hidden state at sequence
    position 0):
      - mood               (8 classes)
      - energy             (3 classes — low / moderate / high)
      - motivation         (3 classes — low / moderate / high)
      - stress             (3 classes — low / moderate / high)
      - soreness_region    (7 classes — which muscle group)
      - soreness_severity  (4 classes — how intense)

    The class counts above are the constructor defaults; each is configurable.

    Args:
        num_moods: output size of the mood head.
        num_energy_levels: output size of the energy head.
        num_motivation_levels: output size of the motivation head.
        num_stress_levels: output size of the stress head.
        num_soreness_region_classes: output size of the soreness-region head.
        num_soreness_severity_classes: output size of the soreness-severity head.
        pretrained_model_name_or_path: keyword-only. Hub id or local directory
            passed to ``DistilBertModel.from_pretrained`` for the backbone.
            Defaults to ``"distilbert-base-uncased"``, the checkpoint this class
            always used before the parameter existed.

    Note:
        Submodule attribute names and the layer layout inside each MLP head
        define the ``state_dict`` keys (e.g. ``energy_head.0.weight``,
        ``energy_head.3.weight``). Renaming or reordering them breaks loading
        of previously saved checkpoints.
    """

    def __init__(
        self,
        num_moods: int = 8,
        num_energy_levels: int = 3,
        num_motivation_levels: int = 3,
        num_stress_levels: int = 3,
        num_soreness_region_classes: int = 7,
        num_soreness_severity_classes: int = 4,
        *,
        pretrained_model_name_or_path: str = "distilbert-base-uncased",
    ):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(pretrained_model_name_or_path)
        hidden_size = self.bert.config.hidden_size  # 768 for distilbert-base-*

        # forward() applies these two dropouts in series. While training, an
        # element of the pooled feature survives with probability
        # 0.7 * 0.9 = 0.63 (effective p = 0.37). Both are identity in eval
        # mode. NOTE(review): confirm the stacking is intentional; it is kept
        # as-is so training behaviour does not change.
        self.dropout = nn.Dropout(0.3)
        self.head_dropout = nn.Dropout(0.1)

        # Single linear head; mood has a strong linguistic signal.
        self.mood_head = nn.Linear(hidden_size, num_moods)

        # Deeper MLP heads. Creation order matches the original inline
        # definitions, so seeded initialisation is unchanged.
        # energy: hidden → 128 → num_energy_levels
        self.energy_head = self._mlp_head(hidden_size, 128, num_energy_levels)
        # motivation: hidden → 64 → num_motivation_levels
        self.motivation_head = self._mlp_head(hidden_size, 64, num_motivation_levels)
        # stress: hidden → 64 → num_stress_levels
        self.stress_head = self._mlp_head(hidden_size, 64, num_stress_levels)
        # soreness region (which muscle group): hidden → 128 → classes
        self.soreness_region_head = self._mlp_head(
            hidden_size, 128, num_soreness_region_classes
        )
        # soreness severity (how intense): hidden → 64 → classes
        self.soreness_severity_head = self._mlp_head(
            hidden_size, 64, num_soreness_severity_classes
        )

    @staticmethod
    def _mlp_head(
        in_features: int,
        hidden_features: int,
        out_features: int,
        dropout: float = 0.2,
    ) -> nn.Sequential:
        """
        Return a Linear → ReLU → Dropout → Linear classification head.

        The fixed layer order yields ``state_dict`` keys ``0.*`` and ``3.*``
        (ReLU and Dropout have no parameters). These match checkpoints saved
        from the original inline ``nn.Sequential`` definitions.
        """
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """
        Encode once with the backbone and apply every head to the pooled feature.

        Args:
            input_ids: token ids for the backbone, presumably shaped
                (batch, seq_len). TODO confirm against callers.
            attention_mask: mask aligned with ``input_ids``.

        Returns:
            A tuple of raw logits (no softmax applied), in this fixed order:
            (mood, energy, motivation, stress, soreness_region,
            soreness_severity).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The [:, 0, :] indexing takes the position-0 hidden state from a
        # (batch, seq_len, hidden) tensor. That position is the [CLS] token
        # when inputs come from the DistilBERT tokenizer with special tokens.
        pooled = self.dropout(outputs.last_hidden_state[:, 0, :])
        features = self.head_dropout(pooled)
        return (
            self.mood_head(features),
            self.energy_head(features),
            self.motivation_head(features),
            self.stress_head(features),
            self.soreness_region_head(features),
            self.soreness_severity_head(features),
        )
def load_pre_model(
    model_path: str,
    device: torch.device,
    num_moods: int = 8,
    num_energy_levels: int = 3,
    num_motivation_levels: int = 3,
    num_stress_levels: int = 3,
    num_soreness_region_classes: int = 7,
    num_soreness_severity_classes: int = 4,
):
    """
    Build a PreWorkoutDistilBERT, restore its weights from disk, and ready it for inference.

    Constructing the model loads the pretrained backbone. ``load_state_dict``
    is strict by default, so every parameter (backbone included) is then
    overwritten from ``model_path``. A head size that disagrees with the
    checkpoint makes that call raise.

    Args:
        model_path: path to a ``state_dict`` saved from PreWorkoutDistilBERT.
        device: device the checkpoint tensors are mapped to and the model is
            moved to.
        num_moods, num_energy_levels, num_motivation_levels, num_stress_levels,
        num_soreness_region_classes, num_soreness_severity_classes:
            head output sizes forwarded to the model constructor.

    Returns:
        (model, tokenizer): the model on ``device`` in eval mode, and the
        ``distilbert-base-uncased`` tokenizer.
    """
    logger.info(f"Loading pre-workout model weights from: {model_path}")

    head_sizes = {
        "num_moods": num_moods,
        "num_energy_levels": num_energy_levels,
        "num_motivation_levels": num_motivation_levels,
        "num_stress_levels": num_stress_levels,
        "num_soreness_region_classes": num_soreness_region_classes,
        "num_soreness_severity_classes": num_soreness_severity_classes,
    }
    net = PreWorkoutDistilBERT(**head_sizes)

    # weights_only=True limits unpickling to tensors and plain containers.
    net.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=True)
    )
    net.to(device).eval()

    tok = DistilBertTokenizer.from_pretrained(
        "distilbert-base-uncased", clean_up_tokenization_spaces=True
    )
    logger.info("Pre-workout model loaded and set to eval mode.")
    return net, tok
def load_post_model(
    model_path: str,
    device: torch.device,
    num_moods: int = 8,
    num_exertion_levels: int = 3,
    num_soreness_region_classes: int = 7,
    num_soreness_severity_classes: int = 4,
    num_completion_statuses: int = 2,
):
    """
    Build a PostWorkoutDistilBERT, restore its weights from disk, and ready it for inference.

    Constructing the model loads the pretrained backbone. ``load_state_dict``
    is strict by default, so every parameter (backbone included) is then
    overwritten from ``model_path``. A head size that disagrees with the
    checkpoint makes that call raise.

    Args:
        model_path: path to a ``state_dict`` saved from PostWorkoutDistilBERT.
        device: device the checkpoint tensors are mapped to and the model is
            moved to.
        num_moods, num_exertion_levels, num_soreness_region_classes,
        num_soreness_severity_classes, num_completion_statuses:
            head output sizes forwarded to the model constructor.

    Returns:
        (model, tokenizer): the model on ``device`` in eval mode, and the
        ``distilbert-base-uncased`` tokenizer.
    """
    logger.info(f"Loading post-workout model weights from: {model_path}")

    head_sizes = {
        "num_moods": num_moods,
        "num_exertion_levels": num_exertion_levels,
        "num_soreness_region_classes": num_soreness_region_classes,
        "num_soreness_severity_classes": num_soreness_severity_classes,
        "num_completion_statuses": num_completion_statuses,
    }
    net = PostWorkoutDistilBERT(**head_sizes)

    # weights_only=True limits unpickling to tensors and plain containers.
    net.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=True)
    )
    net.to(device).eval()

    tok = DistilBertTokenizer.from_pretrained(
        "distilbert-base-uncased", clean_up_tokenization_spaces=True
    )
    logger.info("Post-workout model loaded and set to eval mode.")
    return net, tok