| | |
| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import AutoConfig, AutoModel, PreTrainedModel |
| | from transformers.utils import ModelOutput |
| |
|
| | try: |
| | from .model import hidden_size_from_config |
| | except ImportError: |
| | from model import hidden_size_from_config |
| |
|
| |
|
@dataclass
class MultitaskSpanOutput(ModelOutput):
    """Output container for the multitask span model.

    All fields are optional so the object can be returned with or without
    a computed loss (e.g. at inference time no labels are supplied).
    """

    # Combined weighted training loss; None when labels were not provided.
    loss: Optional[torch.Tensor] = None
    # Per-token label-presence logits (pre-sigmoid).
    token_logits: Optional[torch.Tensor] = None
    # Per-token span-start logits (pre-sigmoid).
    start_logits: Optional[torch.Tensor] = None
    # Per-token span-end logits (pre-sigmoid).
    end_logits: Optional[torch.Tensor] = None
| |
|
| |
|
class IrishCoreTokenSpanModel(PreTrainedModel):
    """Encoder + three sigmoid heads for multitask span detection.

    On top of a Hugging Face encoder, three independent linear heads produce
    per-token logits for (a) label presence, (b) span start, and (c) span end,
    each over ``config.num_span_labels`` label channels.

    Required config attribute: ``num_span_labels``. Optional knobs (with
    defaults): ``seq_classif_dropout``/``dropout`` (0.1),
    ``span_positive_weight`` (6.0), ``token_positive_weight`` (4.0),
    ``token_presence_weight`` (1.0), ``boundary_loss_weight`` (1.0).
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        # num_span_labels is mandatory; plain attribute access makes the
        # failure mode (AttributeError) explicit instead of hiding it in getattr.
        num_span_labels = int(config.num_span_labels)
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        # Fall back through the common config field names for dropout.
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        self.token_classifier = nn.Linear(hidden_size, num_span_labels)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        boundary_pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        presence_pos_weight = float(getattr(config, "token_positive_weight", 4.0))
        # Non-persistent buffers: they follow the model across devices via
        # .to()/.cuda() but are not written into checkpoints.
        self.register_buffer("boundary_pos_weight", torch.full((num_span_labels,), boundary_pos_weight), persistent=False)
        self.register_buffer("presence_pos_weight", torch.full((num_span_labels,), presence_pos_weight), persistent=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        token_labels=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultitaskSpanOutput:
        """Run the encoder and heads; compute the loss when all labels are given.

        Args:
            input_ids / attention_mask / token_type_ids: standard encoder inputs.
                ``token_type_ids`` is dropped for encoders that do not accept it
                (distilbert, roberta).
            token_labels: multi-hot presence targets, same shape as
                ``token_logits`` (batch, seq, num_span_labels) — assumed, confirm
                against the data pipeline.
            start_positions / end_positions: multi-hot boundary targets with the
                same shape as the boundary logits.
            token_mask: optional per-token validity mask; falls back to
                ``attention_mask``, then to all-ones if neither is provided.

        Returns:
            MultitaskSpanOutput with ``loss`` set only when ``token_labels``,
            ``start_positions`` and ``end_positions`` are all present.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        hidden = self.dropout(outputs.last_hidden_state)
        token_logits = self.token_classifier(hidden)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if token_labels is not None and start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            if token_mask is None:
                # Neither token_mask nor attention_mask supplied: previously this
                # crashed with AttributeError; treat every position as valid.
                token_mask = torch.ones_like(token_logits[..., 0])
            mask = token_mask.float().unsqueeze(-1)
            boundary_pos_weight = self.boundary_pos_weight.to(token_logits.device)
            presence_pos_weight = self.presence_pos_weight.to(token_logits.device)
            # Stateless functional BCE avoids constructing new loss modules on
            # every forward pass.
            bce = nn.functional.binary_cross_entropy_with_logits
            token_loss = bce(token_logits, token_labels.float(), pos_weight=presence_pos_weight, reduction="none") * mask
            start_loss = bce(start_logits, start_positions.float(), pos_weight=boundary_pos_weight, reduction="none") * mask
            end_loss = bce(end_logits, end_positions.float(), pos_weight=boundary_pos_weight, reduction="none") * mask
            # Mean over valid tokens x label channels; clamp guards against an
            # all-zero mask.
            denom = mask.sum().clamp_min(1.0) * token_logits.shape[-1]
            token_loss = token_loss.sum() / denom
            boundary_loss = (start_loss.sum() + end_loss.sum()) / (2.0 * denom)
            token_weight = float(getattr(self.config, "token_presence_weight", 1.0))
            boundary_weight = float(getattr(self.config, "boundary_loss_weight", 1.0))
            loss = token_weight * token_loss + boundary_weight * boundary_loss

        return MultitaskSpanOutput(
            loss=loss,
            token_logits=token_logits,
            start_logits=start_logits,
            end_logits=end_logits,
        )
| |
|