#!/usr/bin/env python3 from __future__ import annotations from dataclasses import dataclass import math from typing import Optional import torch import torch.nn as nn from transformers import AutoConfig, AutoModel, PreTrainedModel from transformers.utils import ModelOutput try: from .model import hidden_size_from_config except ImportError: from model import hidden_size_from_config @dataclass class MultitaskSpanOutput(ModelOutput): loss: Optional[torch.Tensor] = None token_logits: Optional[torch.Tensor] = None start_logits: Optional[torch.Tensor] = None end_logits: Optional[torch.Tensor] = None class IrishCoreTokenSpanModel(PreTrainedModel): config_class = AutoConfig base_model_prefix = "encoder" def __init__(self, config): super().__init__(config) num_span_labels = int(getattr(config, "num_span_labels")) self.encoder = AutoModel.from_config(config) hidden_size = hidden_size_from_config(config) dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1))) self.dropout = nn.Dropout(dropout) self.token_classifier = nn.Linear(hidden_size, num_span_labels) self.start_classifier = nn.Linear(hidden_size, num_span_labels) self.end_classifier = nn.Linear(hidden_size, num_span_labels) boundary_pos_weight = float(getattr(config, "span_positive_weight", 6.0)) presence_pos_weight = float(getattr(config, "token_positive_weight", 4.0)) self.register_buffer("boundary_pos_weight", torch.full((num_span_labels,), boundary_pos_weight), persistent=False) self.register_buffer("presence_pos_weight", torch.full((num_span_labels,), presence_pos_weight), persistent=False) self.post_init() def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, token_labels=None, start_positions=None, end_positions=None, token_mask=None, **kwargs, ) -> MultitaskSpanOutput: encoder_kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, **kwargs, } if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}: encoder_kwargs["token_type_ids"] = token_type_ids outputs = self.encoder(**encoder_kwargs) hidden = self.dropout(outputs.last_hidden_state) token_logits = self.token_classifier(hidden) start_logits = self.start_classifier(hidden) end_logits = self.end_classifier(hidden) loss = None if token_labels is not None and start_positions is not None and end_positions is not None: if token_mask is None: token_mask = attention_mask mask = token_mask.float().unsqueeze(-1) boundary_pos_weight = self.boundary_pos_weight.to(token_logits.device) presence_pos_weight = self.presence_pos_weight.to(token_logits.device) bce_boundary = nn.BCEWithLogitsLoss(reduction="none", pos_weight=boundary_pos_weight) bce_presence = nn.BCEWithLogitsLoss(reduction="none", pos_weight=presence_pos_weight) token_targets = token_labels.float() start_targets = start_positions.float() end_targets = end_positions.float() token_loss = bce_presence(token_logits, token_targets) start_loss = bce_boundary(start_logits, start_targets) end_loss = bce_boundary(end_logits, end_targets) token_focal_gamma = float(getattr(self.config, "token_focal_gamma", getattr(self.config, "focal_gamma", 0.0))) boundary_focal_gamma = float(getattr(self.config, "boundary_focal_gamma", getattr(self.config, "focal_gamma", 0.0))) if token_focal_gamma > 0.0: token_loss = apply_focal_weight(token_loss, token_logits, token_targets, token_focal_gamma) if boundary_focal_gamma > 0.0: start_loss = apply_focal_weight(start_loss, start_logits, start_targets, boundary_focal_gamma) end_loss = apply_focal_weight(end_loss, end_logits, end_targets, boundary_focal_gamma) token_hard_fraction = float(getattr(self.config, "token_hard_fraction", getattr(self.config, "hard_fraction", 1.0))) boundary_hard_fraction = float(getattr(self.config, "boundary_hard_fraction", getattr(self.config, "hard_fraction", 1.0))) token_loss = reduce_masked_loss(token_loss, mask, token_hard_fraction) start_loss = reduce_masked_loss(start_loss, mask, boundary_hard_fraction) end_loss = reduce_masked_loss(end_loss, mask, boundary_hard_fraction) boundary_loss = 0.5 * (start_loss + end_loss) token_weight = float(getattr(self.config, "token_presence_weight", 1.0)) boundary_weight = float(getattr(self.config, "boundary_loss_weight", 1.0)) loss = token_weight * token_loss + boundary_weight * boundary_loss return MultitaskSpanOutput( loss=loss, token_logits=token_logits, start_logits=start_logits, end_logits=end_logits, ) def apply_focal_weight(loss: torch.Tensor, logits: torch.Tensor, targets: torch.Tensor, gamma: float) -> torch.Tensor: probs = torch.sigmoid(logits) pt = probs * targets + (1.0 - probs) * (1.0 - targets) return loss * (1.0 - pt).pow(gamma) def reduce_masked_loss(loss: torch.Tensor, mask: torch.Tensor, hard_fraction: float) -> torch.Tensor: expanded_mask = mask.expand_as(loss).bool() masked = loss.masked_select(expanded_mask) if masked.numel() == 0: return loss.sum() * 0.0 if 0.0 < hard_fraction < 1.0: keep = max(1, math.ceil(masked.numel() * hard_fraction)) masked = torch.topk(masked, keep).values return masked.mean()