| import torch |
| import torch.nn as nn |
| from transformers import AutoModel, ModernBertConfig, ModernBertModel, PretrainedConfig, PreTrainedModel |
| from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
|
class HavelockOralityConfig(PretrainedConfig):
    """Configuration for the Havelock orality regressor.

    Wraps the backbone encoder's hyperparameters (passed through ``kwargs``
    to :class:`PretrainedConfig`) and adds the regression-head settings.
    """

    model_type = "havelock-orality-regressor"

    def __init__(self, dropout: float = 0.1, **kwargs):
        """Record the head dropout probability.

        Args:
            dropout: dropout probability applied to the pooled representation
                before the linear regression head.
            **kwargs: backbone hyperparameters, forwarded to the base config.
        """
        super().__init__(**kwargs)
        # Head-specific knob; everything else lives on the base config.
        self.dropout = dropout
|
|
|
|
class HavelockOralityRegressor(PreTrainedModel):
    """ModernBERT encoder with a mean-pooled linear head predicting one
    scalar "orality" score per input sequence.
    """

    config_class = HavelockOralityConfig

    def __init__(self, config, backbone=None):
        """Build (or adopt) the encoder and attach the regression head.

        Args:
            config: :class:`HavelockOralityConfig` carrying both the head
                settings (``dropout``) and the backbone hyperparameters.
            backbone: optional already-loaded encoder. When ``None``, a fresh
                ``ModernBertModel`` is rebuilt from ``config`` so that
                ``from_pretrained`` round-trips without the original checkpoint.
        """
        super().__init__(config)
        if backbone is not None:
            self.backbone = backbone
        else:
            # The merged config dict contains the backbone hyperparameters
            # (see from_backbone), so the encoder can be re-hydrated from it.
            backbone_config = ModernBertConfig.from_dict(config.to_dict())
            self.backbone = ModernBertModel(backbone_config)

        self.dropout = nn.Dropout(config.dropout)
        self.regressor = nn.Linear(config.hidden_size, 1)
        self.post_init()

    @classmethod
    def from_backbone(cls, model_name: str, dropout: float = 0.1) -> "HavelockOralityRegressor":
        """Wrap a pretrained encoder (loaded by name) with a fresh head.

        Args:
            model_name: HuggingFace model identifier or local path.
            dropout: dropout probability for the regression head.
        """
        backbone = AutoModel.from_pretrained(model_name)
        backbone_kwargs = backbone.config.to_dict()
        # A backbone config that already defines "dropout" would otherwise
        # collide with the explicit keyword below (TypeError: duplicate kwarg).
        # The caller-supplied head dropout wins.
        backbone_kwargs.pop("dropout", None)
        config = HavelockOralityConfig(
            dropout=dropout,
            **backbone_kwargs,
        )
        return cls(config, backbone=backbone)

    def _pool(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Mean-pool token embeddings over non-padding positions.

        Args:
            last_hidden_state: (batch, seq, hidden) token representations.
            attention_mask: (batch, seq) mask, 1 for real tokens, 0 for padding.

        Returns:
            (batch, hidden) masked mean of the token embeddings.
        """
        # Cast the mask to the hidden-state dtype so fp16/bf16 activations
        # are not silently upcast to fp32 during pooling.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        # clamp guards against division by zero for an all-padding row.
        return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Score each sequence; compute MSE loss when labels are given.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) padding mask. When omitted,
                every position is treated as a real token.
            labels: optional regression targets; any shape that flattens to
                (batch,) and any numeric dtype.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape (batch,) and
            ``loss`` set iff labels were provided.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        pooler_output = getattr(outputs, "pooler_output", None)
        if pooler_output is not None:
            pooled = pooler_output
        else:
            # ModernBERT exposes no pooler, so this branch is the common path.
            # Without a mask, mean-pool over every position instead of
            # crashing on None inside _pool.
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            pooled = self._pool(outputs.last_hidden_state, attention_mask)

        pooled = self.dropout(pooled)
        scores = self.regressor(pooled).squeeze(-1)

        loss = None
        if labels is not None:
            # Flatten and cast so (batch, 1) or integer labels cannot trigger
            # broadcasting surprises or dtype errors against (batch,) scores.
            loss = nn.functional.mse_loss(scores, labels.float().view(-1))

        return SequenceClassifierOutput(loss=loss, logits=scores)