File size: 5,931 Bytes
b08ade7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
from __future__ import annotations

from dataclasses import dataclass
import math
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput

try:
    from .model import hidden_size_from_config
except ImportError:
    from model import hidden_size_from_config


@dataclass
class MultitaskSpanOutput(ModelOutput):
    """Output container returned by ``IrishCoreTokenSpanModel.forward``.

    NOTE(review): ``ModelOutput`` subclasses can also be indexed / converted to
    tuples (per the transformers ``ModelOutput`` docs), so the field order below
    is presumably part of the interface -- avoid reordering.
    """

    # Combined training loss; None unless token_labels, start_positions and
    # end_positions were all passed to forward.
    loss: Optional[torch.Tensor] = None
    # Output of the model's token_classifier head (one logit per span label).
    token_logits: Optional[torch.Tensor] = None
    # Output of the model's start_classifier head (span-start logits per label).
    start_logits: Optional[torch.Tensor] = None
    # Output of the model's end_classifier head (span-end logits per label).
    end_logits: Optional[torch.Tensor] = None


class IrishCoreTokenSpanModel(PreTrainedModel):
    """Transformer encoder with three per-position, multi-label span heads.

    The encoder is built with ``AutoModel.from_config(config)``. On top of its
    ``last_hidden_state`` (after dropout) sit three ``nn.Linear`` heads, each with
    ``config.num_span_labels`` outputs:

    * ``token_classifier`` -- label-presence logits,
    * ``start_classifier`` -- span-start logits,
    * ``end_classifier`` -- span-end logits.

    Config attributes read by this class (fallback attribute / default in brackets):
    ``num_span_labels`` (required), ``seq_classif_dropout`` [``dropout`` / 0.1],
    ``span_positive_weight`` [6.0], ``token_positive_weight`` [4.0],
    ``token_focal_gamma`` and ``boundary_focal_gamma`` [``focal_gamma`` / 0.0],
    ``token_hard_fraction`` and ``boundary_hard_fraction`` [``hard_fraction`` / 1.0],
    ``token_presence_weight`` [1.0] and ``boundary_loss_weight`` [1.0].
    """

    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        # Required: width of every head (raises AttributeError when missing).
        num_span_labels = int(config.num_span_labels)
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        dropout = self._config_float(config, ("seq_classif_dropout", "dropout"), 0.1)
        self.dropout = nn.Dropout(dropout)
        self.token_classifier = nn.Linear(hidden_size, num_span_labels)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        boundary_pos_weight = self._config_float(config, ("span_positive_weight",), 6.0)
        presence_pos_weight = self._config_float(config, ("token_positive_weight",), 4.0)
        # Per-label BCE pos_weight vectors. They are non-persistent: rebuilt from
        # the config on construction rather than stored in checkpoints.
        self.register_buffer("boundary_pos_weight", torch.full((num_span_labels,), boundary_pos_weight), persistent=False)
        self.register_buffer("presence_pos_weight", torch.full((num_span_labels,), presence_pos_weight), persistent=False)
        self.post_init()

    @staticmethod
    def _config_float(config, names, default):
        """Return ``float`` of the first attribute in *names* present on *config*, else *default*."""
        for name in names:
            if hasattr(config, name):
                return float(getattr(config, name))
        return float(default)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        token_labels=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultitaskSpanOutput:
        """Run the encoder and the three heads; compute the loss when targets are given.

        Args:
            input_ids: Forwarded to the encoder.
            attention_mask: Forwarded to the encoder; also the default loss mask.
            token_type_ids: Forwarded only when ``config.model_type`` is neither
                ``"distilbert"`` nor ``"roberta"``.
            token_labels: Targets for the presence head.
            start_positions: Targets for the start head.
            end_positions: Targets for the end head. Each target tensor is cast to
                float and must match the shape of its logits (a
                ``BCEWithLogitsLoss`` requirement). The loss is computed only when
                all three targets are given.
            token_mask: Positions that contribute to the loss (non-zero = counted).
                Falls back to ``attention_mask``, then to every position.
            **kwargs: Extra encoder arguments; ``num_items_in_batch`` is dropped.

        Returns:
            MultitaskSpanOutput holding the three logit tensors and ``loss``
            (None when any target is missing).
        """
        # Recent transformers ``Trainer`` versions may add ``num_items_in_batch`` to
        # the inputs of a model whose forward accepts ``**kwargs``. It is a loss
        # normalisation hint, not an encoder argument, and this model averages its
        # own loss, so it is discarded instead of being passed to the encoder.
        kwargs.pop("num_items_in_batch", None)
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        hidden = self.dropout(outputs.last_hidden_state)
        token_logits = self.token_classifier(hidden)
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if token_labels is not None and start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            if token_mask is None:
                # No mask of any kind was supplied: every position counts.
                token_mask = torch.ones(token_logits.shape[:-1], device=token_logits.device)
            loss = self._multitask_loss(
                token_logits,
                start_logits,
                end_logits,
                token_labels,
                start_positions,
                end_positions,
                token_mask,
            )

        return MultitaskSpanOutput(
            loss=loss,
            token_logits=token_logits,
            start_logits=start_logits,
            end_logits=end_logits,
        )

    def _multitask_loss(
        self,
        token_logits,
        start_logits,
        end_logits,
        token_labels,
        start_positions,
        end_positions,
        token_mask,
    ):
        """Return ``token_weight * presence_loss + boundary_weight * mean(start_loss, end_loss)``."""
        config = self.config
        device = token_logits.device
        # A trailing axis lets the mask expand over the label dimension of the losses.
        mask = token_mask.to(device).float().unsqueeze(-1)
        bce_presence = nn.BCEWithLogitsLoss(reduction="none", pos_weight=self.presence_pos_weight.to(device))
        bce_boundary = nn.BCEWithLogitsLoss(reduction="none", pos_weight=self.boundary_pos_weight.to(device))

        token_targets = token_labels.float()
        start_targets = start_positions.float()
        end_targets = end_positions.float()
        token_loss = bce_presence(token_logits, token_targets)
        start_loss = bce_boundary(start_logits, start_targets)
        end_loss = bce_boundary(end_logits, end_targets)

        # Optional focal down-weighting of easy elements; non-positive gamma disables it.
        token_gamma = self._config_float(config, ("token_focal_gamma", "focal_gamma"), 0.0)
        boundary_gamma = self._config_float(config, ("boundary_focal_gamma", "focal_gamma"), 0.0)
        if token_gamma > 0.0:
            token_loss = apply_focal_weight(token_loss, token_logits, token_targets, token_gamma)
        if boundary_gamma > 0.0:
            start_loss = apply_focal_weight(start_loss, start_logits, start_targets, boundary_gamma)
            end_loss = apply_focal_weight(end_loss, end_logits, end_targets, boundary_gamma)

        # Masked mean, optionally restricted to the hardest fraction of elements.
        token_fraction = self._config_float(config, ("token_hard_fraction", "hard_fraction"), 1.0)
        boundary_fraction = self._config_float(config, ("boundary_hard_fraction", "hard_fraction"), 1.0)
        token_loss = reduce_masked_loss(token_loss, mask, token_fraction)
        start_loss = reduce_masked_loss(start_loss, mask, boundary_fraction)
        end_loss = reduce_masked_loss(end_loss, mask, boundary_fraction)

        boundary_loss = 0.5 * (start_loss + end_loss)
        token_weight = self._config_float(config, ("token_presence_weight",), 1.0)
        boundary_weight = self._config_float(config, ("boundary_loss_weight",), 1.0)
        return token_weight * token_loss + boundary_weight * boundary_loss


def apply_focal_weight(loss: torch.Tensor, logits: torch.Tensor, targets: torch.Tensor, gamma: float) -> torch.Tensor:
    """Scale an element-wise loss by the focal modulating factor ``(1 - p_t) ** gamma``.

    With ``p = sigmoid(logits)``, the probability assigned to the target is
    ``p_t = p * t + (1 - p) * (1 - t)``. Its complement is computed as
    ``p * (1 - t) + sigmoid(-logits) * t``, which is algebraically equal to
    ``1 - p_t``. Subtracting ``p_t`` from 1 would lose precision: once ``sigmoid``
    rounds to exactly 0 or 1 for a confidently correct logit, ``1 - p_t`` becomes
    0, and for ``0 < gamma < 1`` the backward pass of ``pow`` then produces
    inf * 0 = NaN gradients. (Logits large enough that even ``sigmoid(-|x|)``
    underflows to 0 can still hit that limit.)

    Args:
        loss: Element-wise loss, broadcast-compatible with ``logits``.
        logits: Raw scores.
        targets: Target values, same shape as ``logits`` (soft values are allowed).
        gamma: Focusing exponent.

    Returns:
        ``loss`` multiplied element-wise by the focal weight.
    """
    probs = torch.sigmoid(logits)
    # sigmoid(-x) == 1 - sigmoid(x), but stays accurate when sigmoid(x) is near 1.
    complement = torch.sigmoid(-logits)
    one_minus_pt = probs * (1.0 - targets) + complement * targets
    return loss * one_minus_pt.pow(gamma)


def reduce_masked_loss(loss: torch.Tensor, mask: torch.Tensor, hard_fraction: float) -> torch.Tensor:
    """Average *loss* over the positions selected by *mask*, optionally hardest-only.

    Args:
        loss: Element-wise loss values.
        mask: Tensor expandable to ``loss`` via ``expand_as``; non-zero entries
            select the elements that count.
        hard_fraction: When strictly between 0 and 1, only the largest
            ``ceil(n * hard_fraction)`` selected values (at least one) are
            averaged, where ``n`` is the number of selected elements. Any other
            value averages every selected element.

    Returns:
        A scalar tensor. When nothing is selected, this is ``loss.sum() * 0.0``:
        a zero that stays attached to ``loss``'s autograd graph.
    """
    keep_mask = mask.expand_as(loss).bool()
    selected = loss[keep_mask]
    total = selected.numel()
    if total == 0:
        return loss.sum() * 0.0

    mine_hard_examples = hard_fraction > 0.0 and hard_fraction < 1.0
    if mine_hard_examples:
        k = max(1, math.ceil(total * hard_fraction))
        selected = selected.topk(k).values
    return selected.mean()