# temsa's picture
# Release IrishCore-DiffMask-135M-v1-rc1
# 58f9459 verified
#!/usr/bin/env python3
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput
def hidden_size_from_config(config) -> int:
    """Return the encoder hidden width declared by ``config``.

    Reads ``config.hidden_size`` and falls back to ``config.dim`` (the name
    some configs, e.g. DistilBERT's, use instead) only when ``hidden_size`` is
    absent or ``None``.

    Raises:
        AttributeError: if the config defines neither ``hidden_size`` nor ``dim``.
    """
    # Resolve the fallback lazily: passing getattr(config, "dim") as the
    # default of the outer getattr would evaluate it eagerly and raise for
    # every config that has ``hidden_size`` but no ``dim``.
    hidden_size = getattr(config, "hidden_size", None)
    if hidden_size is None:
        hidden_size = getattr(config, "dim", None)
    if hidden_size is None:
        raise AttributeError(
            f"{type(config).__name__} defines neither 'hidden_size' nor 'dim'"
        )
    return int(hidden_size)
@dataclass
class MultilabelSpanOutput(ModelOutput):
    """Outputs of the multi-label span head.

    The declaration order below is part of the interface: ``ModelOutput``
    supports tuple/index access that follows field order.
    """

    # Scalar training loss; left as None when no position targets are supplied.
    loss: Optional[torch.Tensor] = None
    # Start-of-span logits, one column per span label.
    start_logits: Optional[torch.Tensor] = None
    # End-of-span logits, one column per span label.
    end_logits: Optional[torch.Tensor] = None
class IrishCoreSpanHeadModel(PreTrainedModel):
    """Transformer encoder with independent start/end classifiers per span label.

    Every token receives ``num_span_labels`` start logits and
    ``num_span_labels`` end logits. Labels are scored independently with a
    sigmoid/BCE loss, so one token may begin or end spans of several labels.

    Config attributes read here:
      * ``num_span_labels`` (required): output width of both classifier heads.
      * ``seq_classif_dropout``, else ``dropout``, else 0.1: dropout applied to
        the encoder output before the heads.
      * ``span_positive_weight`` (default 6.0): BCE ``pos_weight`` used for
        every label.
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        num_span_labels = int(getattr(config, "num_span_labels"))
        # The backbone is instantiated from the same config object as this wrapper.
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        # Non-persistent: not written to the state dict; rebuilt from config in __init__.
        self.register_buffer("loss_pos_weight", torch.full((num_span_labels,), pos_weight), persistent=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultilabelSpanOutput:
        """Run the encoder and both span heads, optionally computing the loss.

        Args:
            input_ids, attention_mask: forwarded to the encoder unchanged.
            token_type_ids: forwarded unless ``config.model_type`` is
                ``"distilbert"`` or ``"roberta"``.
            start_positions, end_positions: multi-hot targets with the same
                shape as the logits. The loss is computed only when both are given.
            token_mask: optional per-position weight (logits shape without the
                label dimension) that selects which positions count in the
                loss. Defaults to ``attention_mask``, then to all positions.
            **kwargs: forwarded to the encoder.

        Returns:
            MultilabelSpanOutput with ``start_logits``, ``end_logits`` and
            ``loss`` (``None`` unless both targets were supplied).
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        # Index 0 is the final hidden states whether the encoder returned a
        # ModelOutput or a plain tuple (return_dict disabled); the attribute
        # access `.last_hidden_state` would fail in the tuple case.
        hidden = self.dropout(outputs[0])
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)
        loss = None
        if start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            loss = self._span_loss(start_logits, end_logits, start_positions, end_positions, token_mask)
        return MultilabelSpanOutput(loss=loss, start_logits=start_logits, end_logits=end_logits)

    def _span_loss(self, start_logits, end_logits, start_positions, end_positions, token_mask):
        """Masked, pos-weighted BCE averaged over kept positions, labels and both heads."""
        device = start_logits.device
        if token_mask is None:
            # Neither token_mask nor attention_mask was given: count every position.
            token_mask = torch.ones(start_logits.shape[:-1], device=device)
        # Broadcast the per-position mask across the label dimension.
        mask = token_mask.to(device=device, dtype=torch.float32).unsqueeze(-1)
        # Compute the loss in float32 so half-precision logits neither lose
        # precision nor hit dtype mismatches against the float targets/weights.
        pos_weight = self.loss_pos_weight.to(device=device, dtype=torch.float32)
        bce = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
        start_targets = start_positions.to(device=device, dtype=torch.float32)
        end_targets = end_positions.to(device=device, dtype=torch.float32)
        start_loss = bce(start_logits.float(), start_targets) * mask
        end_loss = bce(end_logits.float(), end_targets) * mask
        # clamp_min avoids division by zero when the mask keeps no positions.
        denom = mask.sum().clamp_min(1.0) * start_logits.shape[-1]
        return (start_loss.sum() + end_loss.sum()) / (2.0 * denom)