#!/usr/bin/env python3
# Multi-label span head on top of a Hugging Face encoder.
#
# For every token the model emits one "start" logit and one "end" logit per
# span label, and (when gold positions are supplied) a masked, positively
# weighted binary cross-entropy loss over those logits.
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput

# Config attributes that may carry the encoder width, in lookup order.
_HIDDEN_SIZE_ATTRS = ("hidden_size", "dim")

# Model types for which `token_type_ids` are never forwarded to the encoder.
_NO_TOKEN_TYPE_MODELS = frozenset({"distilbert", "roberta"})


def hidden_size_from_config(config) -> int:
    """Return the encoder hidden width stored on *config*.

    ``hidden_size`` is preferred and ``dim`` is used as a fallback. The
    fallback is only consulted when the preferred attribute is missing (or
    ``None``), so a config that defines ``hidden_size`` but not ``dim`` no
    longer fails.

    Raises:
        AttributeError: if *config* provides neither attribute.
    """
    for attr in _HIDDEN_SIZE_ATTRS:
        value = getattr(config, attr, None)
        if value is not None:
            return int(value)
    raise AttributeError(
        f"{type(config).__name__!r} object has no attribute 'hidden_size' or 'dim'"
    )


@dataclass
class MultilabelSpanOutput(ModelOutput):
    """Output container returned by ``IrishCoreSpanHeadModel.forward``."""

    # Scalar training loss; None when gold start/end positions were not given.
    loss: Optional[torch.Tensor] = None
    # Per-token start logits; the last dimension has `num_span_labels` entries.
    start_logits: Optional[torch.Tensor] = None
    # Per-token end logits; the last dimension has `num_span_labels` entries.
    end_logits: Optional[torch.Tensor] = None


class IrishCoreSpanHeadModel(PreTrainedModel):
    """Encoder with two linear heads predicting span starts and span ends.

    Config attributes read here:
        num_span_labels: required; output width of both heads.
        seq_classif_dropout / dropout: dropout probability (default 0.1).
        span_positive_weight: ``pos_weight`` applied to positive targets in
            the BCE loss (default 6.0).
        model_type: decides whether ``token_type_ids`` reach the encoder.
    """

    # NOTE(review): AutoConfig is a factory rather than a concrete config
    # class; presumably chosen so any encoder config can be loaded -- confirm
    # this works with the installed transformers version.
    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        """Build the encoder, dropout, both heads and the loss weight buffer."""
        super().__init__(config)
        # Plain attribute access: a missing `num_span_labels` must fail loudly.
        num_span_labels = int(config.num_span_labels)
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)

        dropout = getattr(config, "seq_classif_dropout", None)
        if dropout is None:
            dropout = getattr(config, "dropout", 0.1)
        self.dropout = nn.Dropout(float(dropout))

        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)

        pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        # Non-persistent: rebuilt from the config on load, never stored in
        # the state dict.
        self.register_buffer(
            "loss_pos_weight",
            torch.full((num_span_labels,), pos_weight),
            persistent=False,
        )
        self.post_init()

    def _span_loss(
        self,
        start_logits,
        end_logits,
        start_positions,
        end_positions,
        token_mask,
    ):
        """Return the masked, weighted BCE loss averaged over both heads.

        Element-wise BCE-with-logits losses are zeroed where the mask is 0,
        summed, and divided by ``2 * kept_tokens * num_labels`` (with
        ``kept_tokens`` clamped to at least 1 so an all-zero mask yields 0
        instead of NaN).

        ``token_mask`` may be ``None``, in which case every position counts.
        """
        if token_mask is None:
            # No mask information at all: score every position.
            token_mask = torch.ones(
                start_logits.shape[:-1], device=start_logits.device
            )
        # Trailing axis lets the per-token mask broadcast across labels.
        mask = token_mask.float().unsqueeze(-1)

        pos_weight = self.loss_pos_weight.to(start_logits.device)
        bce = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
        start_loss = bce(start_logits, start_positions.float()) * mask
        end_loss = bce(end_logits, end_positions.float()) * mask

        denom = mask.sum().clamp_min(1.0) * start_logits.shape[-1]
        return (start_loss.sum() + end_loss.sum()) / (2.0 * denom)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultilabelSpanOutput:
        """Run the encoder and both span heads; optionally compute the loss.

        Args:
            input_ids: forwarded to the encoder.
            attention_mask: forwarded to the encoder; also serves as the loss
                mask when ``token_mask`` is not given.
            token_type_ids: forwarded to the encoder unless the config's
                ``model_type`` is ``"distilbert"`` or ``"roberta"``.
            start_positions: multi-hot start targets, same shape as
                ``start_logits``; the loss needs both target tensors.
            end_positions: multi-hot end targets, same shape as
                ``end_logits``.
            token_mask: optional per-token mask selecting the positions that
                contribute to the loss.
            **kwargs: passed through to the encoder unchanged.

        Returns:
            MultilabelSpanOutput with ``start_logits``, ``end_logits`` and,
            when both target tensors are supplied, ``loss``.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        model_type = getattr(self.config, "model_type", "")
        if token_type_ids is not None and model_type not in _NO_TOKEN_TYPE_MODELS:
            encoder_kwargs["token_type_ids"] = token_type_ids

        outputs = self.encoder(**encoder_kwargs)
        hidden = self.dropout(outputs.last_hidden_state)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self._span_loss(
                start_logits,
                end_logits,
                start_positions,
                end_positions,
                token_mask if token_mask is not None else attention_mask,
            )

        return MultilabelSpanOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
        )