File size: 4,103 Bytes
9416dba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput

try:
    from .model import hidden_size_from_config
except ImportError:
    from model import hidden_size_from_config


@dataclass
class MultitaskSpanOutput(ModelOutput):
    """Return value of ``IrishCoreTokenSpanModel.forward``.

    Holds the optional training loss plus the raw logits of the three
    per-token heads. Each logit tensor is produced by its own ``nn.Linear``
    over the encoder's last hidden state, so its last dimension is the number
    of span labels. Leading dimensions follow the encoder output (presumably
    ``(batch, seq_len)`` -- TODO confirm).

    Field order is kept as declared; ``ModelOutput`` also supports
    tuple/index access, so reordering fields would change that view.
    """

    # Weighted sum of presence and boundary losses; None unless token_labels,
    # start_positions and end_positions were all supplied to forward().
    loss: Optional[torch.Tensor] = None
    # Raw (pre-sigmoid) per-label presence logits from the token head.
    token_logits: Optional[torch.Tensor] = None
    # Raw (pre-sigmoid) per-label span-start logits.
    start_logits: Optional[torch.Tensor] = None
    # Raw (pre-sigmoid) per-label span-end logits.
    end_logits: Optional[torch.Tensor] = None


class IrishCoreTokenSpanModel(PreTrainedModel):
    """Shared encoder with three per-token, multi-label heads for span extraction.

    An ``AutoModel`` encoder built from ``config`` feeds three linear heads,
    each emitting one logit per span label at every position:

    * ``token_classifier`` -- presence logits,
    * ``start_classifier`` -- span-start logits,
    * ``end_classifier``   -- span-end logits.

    Every (position, label) logit is trained with an independent
    BCE-with-logits loss, so several labels may be active at one position.

    Config attributes read here (in addition to what the encoder needs):
        num_span_labels (required): number of logits per head.
        seq_classif_dropout / dropout: dropout on the encoder output (default 0.1).
        span_positive_weight: BCE ``pos_weight`` for start/end heads (default 6.0).
        token_positive_weight: BCE ``pos_weight`` for the presence head (default 4.0).
        token_presence_weight: multiplier on the presence loss (default 1.0).
        boundary_loss_weight: multiplier on the boundary loss (default 1.0).
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    # Model types whose encoder is never given ``token_type_ids`` (the original
    # hard-coded list). Encoders whose forward() cannot accept the argument are
    # skipped as well -- see ``_encoder_accepts_token_type_ids``.
    _NO_TOKEN_TYPE_MODELS = frozenset({"distilbert", "roberta"})

    def __init__(self, config):
        super().__init__(config)
        # Raises AttributeError if the config lacks num_span_labels.
        num_span_labels = int(config.num_span_labels)
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        self.token_classifier = nn.Linear(hidden_size, num_span_labels)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        boundary_pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        presence_pos_weight = float(getattr(config, "token_positive_weight", 4.0))
        # Non-persistent buffers: they follow the module across devices/dtypes but
        # are excluded from state_dict, so they are always rebuilt from the config.
        self.register_buffer("boundary_pos_weight", torch.full((num_span_labels,), boundary_pos_weight), persistent=False)
        self.register_buffer("presence_pos_weight", torch.full((num_span_labels,), presence_pos_weight), persistent=False)
        self.post_init()

    def _encoder_accepts_token_type_ids(self) -> bool:
        """Return False only when the encoder's forward() definitely cannot take ``token_type_ids``."""
        try:
            params = inspect.signature(self.encoder.forward).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: keep the original behaviour and pass it through.
            return True
        return "token_type_ids" in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        token_labels=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultitaskSpanOutput:
        """Encode the batch, run all three heads and optionally compute the loss.

        Args:
            input_ids, attention_mask: forwarded to the encoder.
            token_type_ids: forwarded only when the model type is not in
                ``_NO_TOKEN_TYPE_MODELS`` and the encoder's forward() accepts it.
            token_labels, start_positions, end_positions: 0/1 targets, each with
                the same shape as the matching logits (required by BCE). The loss
                is computed only when all three are given.
            token_mask: positions that contribute to the loss. Defaults to
                ``attention_mask``, and to every position when that is also None.
            **kwargs: passed to the encoder unchanged.
                NOTE(review): any extra key a caller injects reaches the encoder
                verbatim -- confirm callers only pass encoder arguments.

        Returns:
            MultitaskSpanOutput with the three logit tensors; ``loss`` is None
            unless all three label tensors were supplied.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        model_type = getattr(self.config, "model_type", "")
        if (
            token_type_ids is not None
            and model_type not in self._NO_TOKEN_TYPE_MODELS
            and self._encoder_accepts_token_type_ids()
        ):
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        hidden = self.dropout(outputs.last_hidden_state)
        token_logits = self.token_classifier(hidden)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if token_labels is not None and start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            if token_mask is None:
                # No mask supplied at all: every position counts (this used to
                # crash with AttributeError on ``None.float()``).
                token_mask = token_logits.new_ones(token_logits.shape[:-1])
            loss = self._span_loss(
                token_logits,
                start_logits,
                end_logits,
                token_labels,
                start_positions,
                end_positions,
                token_mask,
            )

        return MultitaskSpanOutput(
            loss=loss,
            token_logits=token_logits,
            start_logits=start_logits,
            end_logits=end_logits,
        )

    def _span_loss(
        self,
        token_logits,
        start_logits,
        end_logits,
        token_labels,
        start_positions,
        end_positions,
        token_mask,
    ) -> torch.Tensor:
        """Masked, pos-weighted BCE over the three heads, combined with config weights.

        ``token_mask`` is assumed to match the logits minus their last (label)
        dimension; it is broadcast across labels via ``unsqueeze(-1)``.
        """
        mask = token_mask.float().unsqueeze(-1)
        device = token_logits.device
        presence_pw = self.presence_pos_weight.to(device)
        boundary_pw = self.boundary_pos_weight.to(device)
        bce = nn.functional.binary_cross_entropy_with_logits

        token_loss = bce(token_logits, token_labels.float(), pos_weight=presence_pw, reduction="none") * mask
        start_loss = bce(start_logits, start_positions.float(), pos_weight=boundary_pw, reduction="none") * mask
        end_loss = bce(end_logits, end_positions.float(), pos_weight=boundary_pw, reduction="none") * mask

        # Mean over (unmasked positions x labels); the clamp keeps an all-masked
        # batch at 0 instead of producing 0/0 = NaN.
        denom = mask.sum().clamp_min(1.0) * token_logits.shape[-1]
        presence_loss = token_loss.sum() / denom
        # Start and end share one term, averaged so it is on the same scale as presence_loss.
        boundary_loss = (start_loss.sum() + end_loss.sum()) / (2.0 * denom)

        token_weight = float(getattr(self.config, "token_presence_weight", 1.0))
        boundary_weight = float(getattr(self.config, "boundary_loss_weight", 1.0))
        return token_weight * presence_loss + boundary_weight * boundary_loss