#!/usr/bin/env python3
"""Encoder with three per-token multi-label heads: presence, span start, span end."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput

from model import hidden_size_from_config

# Encoder families that are never given ``token_type_ids`` by this model.
_NO_TOKEN_TYPE_MODEL_TYPES = frozenset({"distilbert", "roberta"})


@dataclass
class MultitaskSpanOutput(ModelOutput):
    """Outputs of ``IrishCoreTokenSpanModel.forward``.

    Attributes:
        loss: Weighted sum of the presence and boundary losses; ``None`` unless
            ``token_labels``, ``start_positions`` and ``end_positions`` were all given.
        token_logits: Per-token presence logits, one column per span label.
        start_logits: Per-token span-start logits, one column per span label.
        end_logits: Per-token span-end logits, one column per span label.
    """

    loss: Optional[torch.Tensor] = None
    token_logits: Optional[torch.Tensor] = None
    start_logits: Optional[torch.Tensor] = None
    end_logits: Optional[torch.Tensor] = None


class IrishCoreTokenSpanModel(PreTrainedModel):
    """Shared encoder followed by three linear heads scored with binary cross-entropy.

    Every head maps each token's hidden state to ``config.num_span_labels`` logits.
    Because the loss is BCE, each label column is an independent binary decision.

    Config attributes read:
        num_span_labels (required): width of every head.
        seq_classif_dropout / dropout (default 0.1): dropout on the encoder output.
        token_positive_weight (default 4.0): BCE ``pos_weight`` for the presence head.
        span_positive_weight (default 6.0): BCE ``pos_weight`` for the start/end heads.
        token_presence_weight (default 1.0): weight of the presence loss.
        boundary_loss_weight (default 1.0): weight of the boundary loss.
    """

    # Generic config: the encoder architecture is chosen by AutoModel.from_config.
    config_class = AutoConfig
    base_model_prefix = "encoder"
    # NOTE(review): ``forward`` accepts ``**kwargs``. Recent ``transformers.Trainer``
    # versions check this flag (when present), instead of inspecting that signature,
    # to decide whether the model normalises its loss by ``num_items_in_batch`` itself.
    # This model returns a per-batch mean, so it opts out. Confirm against the
    # installed transformers version.
    accepts_loss_kwargs = False

    def __init__(self, config):
        super().__init__(config)
        # Required attribute: a config without it raises AttributeError here.
        num_span_labels = int(getattr(config, "num_span_labels"))
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        self.token_classifier = nn.Linear(hidden_size, num_span_labels)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)

        boundary_pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        presence_pos_weight = float(getattr(config, "token_positive_weight", 4.0))
        # Non-persistent: the weights are rebuilt from the config and are not saved
        # in the state dict.
        self.register_buffer(
            "boundary_pos_weight",
            torch.full((num_span_labels,), boundary_pos_weight),
            persistent=False,
        )
        self.register_buffer(
            "presence_pos_weight",
            torch.full((num_span_labels,), presence_pos_weight),
            persistent=False,
        )
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        token_labels=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultitaskSpanOutput:
        """Run the encoder and the three heads, and compute the loss when labels are given.

        Args:
            input_ids: Passed straight to the encoder.
            attention_mask: Passed straight to the encoder.
            token_type_ids: Forwarded to the encoder unless ``config.model_type``
                is ``distilbert`` or ``roberta``.
            token_labels: Multi-hot presence targets.
            start_positions: Multi-hot span-start targets.
            end_positions: Multi-hot span-end targets.
                Each target must match the shape of its logits, presumably
                ``(batch, seq_len, num_span_labels)`` (TODO: confirm with the
                collator). The loss is computed only when all three targets are provided.
            token_mask: Marks the positions that count towards the loss, with one
                value per token (it is broadcast over the label axis). Defaults to
                ``attention_mask``; if that is also ``None``, every position counts.
            **kwargs: Extra encoder arguments. ``num_items_in_batch`` is dropped.

        Returns:
            A ``MultitaskSpanOutput`` holding the three logit tensors and ``loss``.
            ``loss`` is ``None`` when any target tensor is missing.
        """
        # Loss-only kwarg that some Trainer versions inject; the encoder's forward
        # may reject unknown keyword arguments.
        kwargs.pop("num_items_in_batch", None)

        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, **kwargs}
        model_type = getattr(self.config, "model_type", "")
        if token_type_ids is not None and model_type not in _NO_TOKEN_TYPE_MODEL_TYPES:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)

        # Index 0 is the last hidden state for both ModelOutput and tuple
        # (return_dict=False) results.
        hidden = self.dropout(outputs[0])
        token_logits = self.token_classifier(hidden)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if token_labels is not None and start_positions is not None and end_positions is not None:
            loss = self._multitask_loss(
                token_logits,
                start_logits,
                end_logits,
                token_labels,
                start_positions,
                end_positions,
                token_mask if token_mask is not None else attention_mask,
            )

        return MultitaskSpanOutput(
            loss=loss,
            token_logits=token_logits,
            start_logits=start_logits,
            end_logits=end_logits,
        )

    def _multitask_loss(
        self,
        token_logits,
        start_logits,
        end_logits,
        token_labels,
        start_positions,
        end_positions,
        mask,
    ) -> torch.Tensor:
        """Return the weighted sum of the masked presence and boundary BCE losses.

        Each loss is the masked BCE sum divided by (number of counted tokens x number
        of labels). The boundary loss averages the start and end terms.
        """
        device = token_logits.device
        if mask is None:
            # Neither token_mask nor attention_mask was given: every position counts.
            mask = torch.ones(token_logits.shape[:-1], device=device)
        # Trailing singleton axis so the per-token mask broadcasts over labels.
        mask = mask.to(device=device, dtype=torch.float32).unsqueeze(-1)

        # Compute in float32 so half-precision logits do not degrade the loss.
        presence_pos_weight = self.presence_pos_weight.to(device=device, dtype=torch.float32)
        boundary_pos_weight = self.boundary_pos_weight.to(device=device, dtype=torch.float32)

        def masked_bce_sum(logits, targets, pos_weight):
            per_element = F.binary_cross_entropy_with_logits(
                logits.float(),
                targets.to(device=device, dtype=torch.float32),
                pos_weight=pos_weight,
                reduction="none",
            )
            return (per_element * mask).sum()

        # Guard against an all-zero mask producing a division by zero.
        denom = mask.sum().clamp_min(1.0) * token_logits.shape[-1]
        token_loss = masked_bce_sum(token_logits, token_labels, presence_pos_weight) / denom
        boundary_loss = (
            masked_bce_sum(start_logits, start_positions, boundary_pos_weight)
            + masked_bce_sum(end_logits, end_positions, boundary_pos_weight)
        ) / (2.0 * denom)

        token_weight = float(getattr(self.config, "token_presence_weight", 1.0))
        boundary_weight = float(getattr(self.config, "boundary_loss_weight", 1.0))
        return token_weight * token_loss + boundary_weight * boundary_loss