File size: 4,027 Bytes
49a55aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput

from model import hidden_size_from_config


@dataclass
class MultitaskSpanOutput(ModelOutput):
    """Return value of ``IrishCoreTokenSpanModel.forward``.

    Holds the optional training loss and the logits of the three per-token
    heads (token, span-start, span-end).
    """

    # Weighted sum of the token-head loss and the start/end boundary loss;
    # None unless token_labels, start_positions and end_positions were all passed.
    loss: Optional[torch.Tensor] = None
    # Logits of the token head, presumably (batch, seq_len, num_span_labels) — TODO confirm.
    token_logits: Optional[torch.Tensor] = None
    # Logits of the span-start head (same shape as token_logits).
    start_logits: Optional[torch.Tensor] = None
    # Logits of the span-end head (same shape as token_logits).
    end_logits: Optional[torch.Tensor] = None


class IrishCoreTokenSpanModel(PreTrainedModel):
    """Encoder plus three parallel per-token linear heads for span labelling.

    ``AutoModel.from_config`` builds the encoder. Its first output (the final
    hidden states) goes through one shared dropout layer and then three
    ``nn.Linear`` heads, each with ``config.num_span_labels`` outputs per token:
    a token head, a start head and an end head. Judging by the names they
    presumably score "token is inside a span of label k", "a span of label k
    starts here" and "a span of label k ends here" — NOTE(review): confirm
    against the data pipeline.

    Config attributes read (defaults in parentheses):
        num_span_labels (required): number of outputs of every head.
        seq_classif_dropout / dropout (0.1): head dropout; the first of the
            two that is set and not None wins.
        span_positive_weight (6.0): BCE ``pos_weight`` for start/end heads.
        token_positive_weight (4.0): BCE ``pos_weight`` for the token head.
        token_presence_weight (1.0): weight of the token loss term.
        boundary_loss_weight (1.0): weight of the start/end loss term.
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"
    # forward() accepts **kwargs only to pass them through to the encoder, and
    # its loss is already a per-element mean that ignores Trainer's
    # ``num_items_in_batch``. transformers' Trainer (in versions that read this
    # attribute) would otherwise infer "accepts loss kwargs" from **kwargs and
    # stop dividing the loss by gradient_accumulation_steps.
    accepts_loss_kwargs = False
    # Trainer-injected keyword(s) that the encoder's forward() may not accept.
    _NON_ENCODER_KWARGS = ("num_items_in_batch",)

    def __init__(self, config):
        super().__init__(config)
        num_span_labels = int(getattr(config, "num_span_labels"))
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        self.dropout = nn.Dropout(self._dropout_from_config(config))
        self.token_classifier = nn.Linear(hidden_size, num_span_labels)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        boundary_pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        presence_pos_weight = float(getattr(config, "token_positive_weight", 4.0))
        # Per-label BCE positive weights. Non-persistent, so they are rebuilt
        # from the config on construction rather than stored in checkpoints.
        self.register_buffer("boundary_pos_weight", torch.full((num_span_labels,), boundary_pos_weight), persistent=False)
        self.register_buffer("presence_pos_weight", torch.full((num_span_labels,), presence_pos_weight), persistent=False)
        self.post_init()

    @staticmethod
    def _dropout_from_config(config) -> float:
        """Return the head dropout: seq_classif_dropout, then dropout, then 0.1, skipping None values."""
        for name in ("seq_classif_dropout", "dropout"):
            value = getattr(config, name, None)
            if value is not None:
                return float(value)
        return 0.1

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        token_labels=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultitaskSpanOutput:
        """Run the encoder and the three heads; optionally compute the loss.

        Args:
            input_ids, attention_mask: passed to the encoder unchanged.
            token_type_ids: passed to the encoder unless ``config.model_type``
                is "distilbert" or "roberta".
            token_labels, start_positions, end_positions: targets for the
                token/start/end heads. They are cast to float and must match the
                logits' shape, presumably multi-hot
                (batch, seq_len, num_span_labels) — TODO confirm. The loss is
                computed only when all three are given.
            token_mask: positions that count towards the loss. Defaults to
                ``attention_mask``, or to every position if that is None too.
            **kwargs: forwarded to the encoder, minus Trainer-only keys.

        Returns:
            MultitaskSpanOutput with the three logit tensors and ``loss``
            (None unless all three label tensors were supplied).
        """
        for name in self._NON_ENCODER_KWARGS:
            kwargs.pop(name, None)
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        # outputs[0] is the last hidden state for both ModelOutput results and
        # plain tuples (e.g. when return_dict=False arrives via kwargs).
        hidden = self.dropout(outputs[0])
        token_logits = self.token_classifier(hidden)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if token_labels is not None and start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            loss = self._span_loss(
                token_logits,
                start_logits,
                end_logits,
                token_labels,
                start_positions,
                end_positions,
                token_mask,
            )

        return MultitaskSpanOutput(
            loss=loss,
            token_logits=token_logits,
            start_logits=start_logits,
            end_logits=end_logits,
        )

    def _span_loss(self, token_logits, start_logits, end_logits, token_labels, start_positions, end_positions, token_mask):
        """Compute the masked, pos-weighted BCE loss over the three heads.

        Each head's element-wise loss is zeroed at masked-out positions and
        summed. The sum is divided by (number of unmasked positions, at least 1)
        * num_span_labels. The start and end terms are averaged into a single
        boundary term. The result is
        ``token_presence_weight * token + boundary_loss_weight * boundary``.
        """
        if token_mask is None:
            # Neither token_mask nor attention_mask given: every position counts.
            token_mask = torch.ones(token_logits.shape[:-1], device=token_logits.device)
        # Trailing singleton axis so the mask broadcasts over the label axis.
        mask = token_mask.float().unsqueeze(-1)
        boundary_pos_weight = self.boundary_pos_weight.to(token_logits.device)
        presence_pos_weight = self.presence_pos_weight.to(token_logits.device)
        bce_boundary = nn.BCEWithLogitsLoss(reduction="none", pos_weight=boundary_pos_weight)
        bce_presence = nn.BCEWithLogitsLoss(reduction="none", pos_weight=presence_pos_weight)
        token_loss = bce_presence(token_logits, token_labels.float()) * mask
        start_loss = bce_boundary(start_logits, start_positions.float()) * mask
        end_loss = bce_boundary(end_logits, end_positions.float()) * mask
        # clamp_min avoids division by zero when every position is masked out.
        denom = mask.sum().clamp_min(1.0) * token_logits.shape[-1]
        token_loss = token_loss.sum() / denom
        boundary_loss = (start_loss.sum() + end_loss.sum()) / (2.0 * denom)
        token_weight = float(getattr(self.config, "token_presence_weight", 1.0))
        boundary_weight = float(getattr(self.config, "boundary_loss_weight", 1.0))
        return token_weight * token_loss + boundary_weight * boundary_loss