# bert-orality-regressor / modeling_havelock.py
# Uploaded via huggingface_hub (revision 2796c1f).
import torch
import torch.nn as nn
from transformers import AutoModel, ModernBertConfig, ModernBertModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
class HavelockOralityConfig(PretrainedConfig):
    """Configuration for :class:`HavelockOralityRegressor`.

    Extends the backbone's ``PretrainedConfig`` with a single extra
    hyperparameter: the dropout probability applied before the
    regression head. All other keys are forwarded to the parent config.
    """

    model_type = "havelock-orality-regressor"

    def __init__(self, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        # Probability used by the nn.Dropout layer in front of the head.
        self.dropout = dropout
class HavelockOralityRegressor(PreTrainedModel):
    """ModernBERT backbone + masked mean pooling + linear head.

    Produces one scalar "orality" score per input sequence. When
    ``labels`` are supplied to :meth:`forward`, an MSE regression loss
    is returned alongside the scores.
    """

    config_class = HavelockOralityConfig

    def __init__(self, config, backbone=None):
        """Build the model.

        Args:
            config: A :class:`HavelockOralityConfig` (must carry the
                backbone hyperparameters, including ``hidden_size``).
            backbone: Optional pre-loaded encoder. If omitted, a fresh
                ``ModernBertModel`` is built from ``config``.
        """
        super().__init__(config)
        if backbone is not None:
            self.backbone = backbone
        else:
            # Rebuild the encoder from the merged config dict so that
            # from_pretrained() round-trips without the original backbone.
            backbone_config = ModernBertConfig.from_dict(config.to_dict())
            self.backbone = ModernBertModel(backbone_config)
        self.dropout = nn.Dropout(config.dropout)
        self.regressor = nn.Linear(config.hidden_size, 1)
        # NOTE(review): post_init() runs the standard HF weight init; with a
        # pretrained `backbone` passed in, verify it does not re-initialize
        # the backbone weights on your transformers version.
        self.post_init()

    @classmethod
    def from_backbone(cls, model_name: str, dropout: float = 0.1) -> "HavelockOralityRegressor":
        """Alternate constructor: wrap a pretrained encoder by name.

        Args:
            model_name: Hub id or local path of the encoder to load.
            dropout: Dropout probability for the regression head.
        """
        backbone = AutoModel.from_pretrained(model_name)
        config = HavelockOralityConfig(
            dropout=dropout,
            **backbone.config.to_dict(),
        )
        return cls(config, backbone=backbone)

    def _pool(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor:
        """Mean-pool token embeddings, ignoring padding positions.

        Args:
            last_hidden_state: Token embeddings of shape (batch, seq, hidden).
            attention_mask: 1/0 mask of shape (batch, seq), or ``None``
                meaning every position is valid.

        Returns:
            Pooled embeddings of shape (batch, hidden).
        """
        # Bug fix: forward() may legitimately be called with
        # attention_mask=None; previously this crashed with an
        # AttributeError on None. Fall back to an unmasked mean.
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        # Cast the mask to the hidden-state dtype so the arithmetic stays
        # in fp16/bf16 when the model runs in reduced precision.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        # clamp avoids division by zero for an all-padding row.
        return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Score each sequence; optionally compute an MSE loss.

        Args:
            input_ids: Token ids of shape (batch, seq).
            attention_mask: Optional 1/0 padding mask of shape (batch, seq).
            labels: Optional regression targets of shape (batch,).
            **kwargs: Ignored extra arguments (e.g. from Trainer collators).

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` of shape (batch,)
            and ``loss`` set when ``labels`` were given.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Prefer the backbone's own pooled output when it exists;
        # ModernBERT has none, so we mean-pool ourselves.
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            pooled = outputs.pooler_output
        else:
            pooled = self._pool(outputs.last_hidden_state, attention_mask)
        pooled = self.dropout(pooled)
        scores = self.regressor(pooled).squeeze(-1)
        loss = None
        if labels is not None:
            # Cast labels to the score dtype: integer label tensors would
            # otherwise make MSELoss raise a dtype error.
            loss = nn.MSELoss()(scores, labels.to(scores.dtype))
        return SequenceClassifierOutput(loss=loss, logits=scores)