File size: 5,305 Bytes
c5081c8
8e63bf6
c5081c8
8e63bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5081c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0dfbd20
c5081c8
 
 
8e63bf6
 
 
 
 
 
 
 
 
c5081c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0dfbd20
c5081c8
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    Multi-class focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Down-weights well-classified examples so optimization focuses on hard ones.

    Args:
        gamma: focusing exponent applied to (1 - p_t); larger values suppress
            easy examples more strongly.
        alpha: class weighting. A scalar applies one uniform factor; a 1D
            tensor of shape [num_classes] applies a per-class weight selected
            by each example's target.
        reduction: 'none', 'mean', or 'sum'.
    """

    def __init__(self, gamma=2.0, alpha=1.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        Args:
            logits: (N, C) unnormalized class scores.
            targets: (N,) integer class indices in [0, C-1].

        Returns:
            A scalar for 'mean'/'sum' reduction, or an (N,) tensor for 'none'.
        """
        # Per-example cross entropy; since CE = -log(p_t), the true-class
        # probability p_t is recovered as exp(-CE).
        per_example_ce = F.cross_entropy(logits, targets, reduction='none')
        true_class_prob = torch.exp(-per_example_ce)

        # Modulating factor (1 - p_t)^gamma shrinks the loss of confident,
        # correct predictions.
        loss = (1 - true_class_prob) ** self.gamma * per_example_ce

        # Per-class alpha when a tensor was supplied, uniform scalar otherwise.
        if torch.is_tensor(self.alpha):
            loss = self.alpha[targets] * loss
        else:
            loss = self.alpha * loss

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss  # 'none'


class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head label metadata.

    Attributes:
        label_maps: dict of label-name -> label-id mapping per head.
        num_labels_dict: dict of label-name -> number of classes per head.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Default to empty dicts so the config always serializes cleanly.
        self.label_maps = {} if label_maps is None else label_maps
        self.num_labels_dict = {} if num_labels_dict is None else num_labels_dict

    def to_dict(self):
        """Serialize the base config plus the extra head metadata."""
        serialized = super().to_dict()
        serialized["label_maps"] = self.label_maps
        serialized["num_labels_dict"] = self.num_labels_dict
        return serialized


class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one token-classification head per label type.

    Each entry of ``config.num_labels_dict`` gets its own small feedforward
    classifier over the shared encoder output. During training the focal
    losses of all heads that received labels are summed.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        self.loss_fct = FocalLoss(gamma=2.0, alpha=1.0, reduction='mean')

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            # Small feedforward module for each head
            self.classifiers[label_name] = nn.Sequential(
                nn.Dropout(
                    0.2  # Try 0.2 or 0.3 to see if overfitting reduces, if dataset is small or has noisy labels
                ),
                nn.Linear(hidden_size, hidden_size),
                nn.GELU(),
                nn.Linear(hidden_size, n_labels)
            )

        # Initialize newly added weights
        self.post_init()

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels_dict=None,
            **kwargs
    ):
        """
        Args:
            input_ids / attention_mask / token_type_ids: standard encoder inputs.
            labels_dict: optional dict of { label_name: (batch_size, seq_len) }
                with label ids; positions set to -100 are ignored.

        Returns:
            (total_loss, logits_dict) when ``labels_dict`` is provided,
            otherwise just ``logits_dict``. ``logits_dict`` maps each label
            name to a (batch_size, seq_len, n_labels) tensor.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )

        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            # Inference path: just return predictions.
            return logits_dict

        # Training path: sum the losses of all heads that received labels.
        total_loss = 0.0
        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]

            # Conventional token-classification masking: positions labelled
            # -100 are excluded from the loss.
            active = (label_ids != -100).view(-1)  # (bs * seq_len,)
            active_logits = logits.view(-1, logits.shape[-1])[active]
            active_labels = label_ids.view(-1)[active]

            # Guard: a head whose labels are entirely -100 would hit a
            # mean-reduction over zero elements (NaN) and poison the sum.
            if active_labels.numel() == 0:
                continue

            total_loss = total_loss + self.loss_fct(active_logits, active_labels)

        return total_loss, logits_dict