"""Scalar "orality" regression head on top of a ModernBERT (or compatible) encoder."""

import torch
import torch.nn as nn
from transformers import AutoModel, ModernBertConfig, ModernBertModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class HavelockOralityConfig(PretrainedConfig):
    """Configuration for HavelockOralityRegressor.

    Extends the backbone's config fields (passed through ``**kwargs``) with the
    head's own dropout probability.
    """

    model_type = "havelock-orality-regressor"

    def __init__(self, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        # Dropout applied to the pooled sentence representation before the
        # linear regression head.
        self.dropout = dropout


class HavelockOralityRegressor(PreTrainedModel):
    """Encoder + masked mean-pooling + linear head producing one score per input."""

    config_class = HavelockOralityConfig

    def __init__(self, config, backbone=None):
        """
        Args:
            config: HavelockOralityConfig carrying both backbone and head settings.
            backbone: optional pre-loaded encoder module. When None, a fresh
                ModernBertModel is rebuilt from the fields stored in ``config``
                (extra head-only keys in the dict are ignored by ``from_dict``).
        """
        super().__init__(config)
        if backbone is not None:
            self.backbone = backbone
        else:
            backbone_config = ModernBertConfig.from_dict(config.to_dict())
            self.backbone = ModernBertModel(backbone_config)
        self.dropout = nn.Dropout(config.dropout)
        self.regressor = nn.Linear(config.hidden_size, 1)
        self.post_init()

    @classmethod
    def from_backbone(cls, model_name: str, dropout: float = 0.1) -> "HavelockOralityRegressor":
        """Build a regressor around a pretrained encoder loaded by name.

        Args:
            model_name: HuggingFace model id or local path for ``AutoModel``.
            dropout: dropout probability for the regression head.
        """
        backbone = AutoModel.from_pretrained(model_name)
        backbone_dict = backbone.config.to_dict()
        # Bug fix: if the backbone's config serializes its own "dropout" key,
        # passing it alongside the explicit dropout= kwarg raises
        # "got multiple values for keyword argument 'dropout'". The explicit
        # head value wins.
        backbone_dict.pop("dropout", None)
        config = HavelockOralityConfig(dropout=dropout, **backbone_dict)
        return cls(config, backbone=backbone)

    def _pool(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Mean-pool token embeddings over non-padding positions.

        Args:
            last_hidden_state: (batch, seq, hidden) token embeddings.
            attention_mask: (batch, seq) with 1 for real tokens, 0 for padding.

        Returns:
            (batch, hidden) masked mean. The clamp guards against division by
            zero for an all-padding row.
        """
        mask = attention_mask.unsqueeze(-1).float()
        return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Score inputs; compute MSE loss when labels are provided.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) padding mask.
            labels: optional (batch,) or (batch, 1) regression targets.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape (batch,) and
            ``loss`` set iff labels were given.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            pooled = outputs.pooler_output
        else:
            # Bug fix: _pool dereferences the mask (attention_mask.unsqueeze),
            # so calling without a mask crashed with AttributeError. Treat a
            # missing mask as "no padding anywhere".
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            pooled = self._pool(outputs.last_hidden_state, attention_mask)
        pooled = self.dropout(pooled)
        scores = self.regressor(pooled).squeeze(-1)
        loss = None
        if labels is not None:
            # Bug fix: MSELoss requires float targets, and (batch, 1)-shaped
            # labels would broadcast against (batch,) scores and corrupt the
            # loss — cast and flatten defensively.
            loss = nn.MSELoss()(scores, labels.float().view(-1))
        return SequenceClassifierOutput(loss=loss, logits=scores)