from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
import torch.nn as nn


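class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head metadata.

    num_labels_dict maps each head name to its number of classes;
    label_maps stores the per-head label mappings alongside the config.
    """
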
    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        self.label_maps = label_maps or {}
        self.num_labels_dict = num_labels_dict or {}

    def to_dict(self):
        output = super().to_dict()
        output["label_maps"] = self.label_maps
        output["num_labels_dict"] = self.num_labels_dict
        return output


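class MultiHeadModel(DebertaV2PreTrainedModel):
    """Shared DeBERTa-v2 encoder with one token-level linear classifier per label set."""

    # Tie the model to its custom config so from_pretrained() builds the right class.
    config_class = MultiHeadModelConfig

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)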

        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)

        # Initialize newly added weights
        self.post_init()

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels_dict=None,
            **kwargs
    ):
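        """
        labels_dict: optional dict of { label_name: LongTensor of shape
                     (batch_size, seq_len) } holding label ids, with -100
                     marking positions to ignore.

        Returns logits_dict, a dict of { label_name: (batch_size, seq_len,
        n_labels) } logits. If labels_dict is provided, returns the tuple
        (total_loss, logits_dict), where total_loss is the sum of the
        per-head cross-entropy losses.
        """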
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )

        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        logits_dict = {}
        for label_name, classifier in self.classifiers.items():
            logits_dict[label_name] = classifier(sequence_output)

        total_loss = None
        loss_dict = {}
        if labels_dict is not None:
            # We'll sum the losses from each head
            loss_fct = nn.CrossEntropyLoss()
            total_loss = 0.0

            for label_name, logits in logits_dict.items():
                if label_name not in labels_dict:
                    continue
                label_ids = labels_dict[label_name]

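                # Standard token-classification masking: positions labeled
                # -100 (padding, special tokens) do not contribute to the loss.
                active_loss = label_ids.reshape(-1) != -100  # shape (bs * seq_len,)
                if not active_loss.any():
                    # No supervised tokens for this head in the batch; skipping
                    # avoids the NaN that CrossEntropyLoss returns on empty input.
                    continue

                # Flatten to (bs * seq_len, n_labels) and keep only active positions
                active_logits = logits.reshape(-1, logits.shape[-1])[active_loss]
                active_labels = label_ids.reshape(-1)[active_loss]

                loss = loss_fct(active_logits, active_labels)
                loss_dict[label_name] = loss.item()
                total_loss += loss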

        if labels_dict is not None:
            # return (loss, predictions)
            return total_loss, logits_dict
        else:
            # just return predictions
            return logits_dict
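

# --- Illustrative reference (not part of the model) -------------------------
# A minimal pure-Python sketch of what each head's loss term amounts to:
# the mean negative log-likelihood over positions whose label is not -100.
# The name `_reference_masked_ce` and its list-based inputs are hypothetical,
# introduced only to make the masking and averaging explicit; the model itself
# relies on nn.CrossEntropyLoss over tensors. Returning None when every
# position is ignored is an assumption of this sketch.
import math


def _reference_masked_ce(logit_rows, label_ids, ignore_index=-100):
    """logit_rows: list of per-token logit lists; label_ids: list of ints."""
    losses = []
    for row, label in zip(logit_rows, label_ids):
        if label == ignore_index:
            continue  # masked position: contributes nothing
        # Numerically stable log-sum-exp over the class logits
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_z - row[label])  # -log softmax(row)[label]
    if not losses:
        return None  # no supervised tokens for this head
    return sum(losses) / len(losses)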