|
|
from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal Loss for multi-class classification (Lin et al., 2017).

    Down-weights well-classified examples so training focuses on hard ones:
    FL(pt) = -alpha * (1 - pt)^gamma * log(pt), where pt is the predicted
    probability of the true class.

    Args:
        gamma: focusing parameter that re-weights hard vs. easy examples.
            gamma=0 recovers plain cross-entropy.
        alpha: optional weight for classes. Can be a single float or a
            tensor of shape [num_classes]. If float, it's a uniform factor
            for all classes; a 1D tensor gives per-class weights.
        reduction: 'none', 'mean', or 'sum'.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """

    def __init__(self, gamma=2.0, alpha=1.0, reduction='mean'):
        super().__init__()
        # Validate eagerly so a typo fails at construction, not mid-training.
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(
                f"reduction must be 'none', 'mean', or 'sum', got {reduction!r}"
            )
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the focal loss.

        Args:
            logits: tensor of shape (N, C), where C is number of classes.
            targets: tensor of shape (N,), with class indices [0..C-1].

        Returns:
            A scalar tensor for 'mean'/'sum' reduction, or a (N,) tensor
            for 'none'.
        """
        # Per-sample CE; reduction deferred so we can re-weight first.
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

        # pt = softmax probability assigned to the true class, recovered
        # from the per-sample cross-entropy: ce = -log(pt).
        pt = torch.exp(-ce_loss)

        # (1 - pt)^gamma -> ~0 for confident correct predictions, ~1 for
        # badly misclassified ones.
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if isinstance(self.alpha, torch.Tensor):
            # Per-class weighting. Move alpha to the logits' device before
            # indexing (fix: a CPU-resident alpha indexed by CUDA targets
            # raises a device-mismatch error).
            alpha_t = self.alpha.to(logits.device)[targets]
            focal_loss = alpha_t * focal_loss
        else:
            # Scalar alpha: uniform scaling of every sample.
            focal_loss = self.alpha * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            # 'none': caller handles reduction.
            return focal_loss
|
|
|
|
|
|
|
|
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with metadata for multiple label heads.

    Attributes:
        label_maps: mapping from label name to its id<->string maps.
        num_labels_dict: mapping from label name to its class count.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Default to empty dicts so downstream iteration is always safe.
        self.label_maps = label_maps if label_maps else {}
        self.num_labels_dict = num_labels_dict if num_labels_dict else {}

    def to_dict(self):
        """Serialize the config, ensuring the head metadata is included."""
        serialized = super().to_dict()
        serialized.update(
            label_maps=self.label_maps,
            num_labels_dict=self.num_labels_dict,
        )
        return serialized
|
|
|
|
|
|
|
|
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one token-classification head per label type.

    For each entry in ``config.num_labels_dict`` a small MLP head is built;
    all heads share the same encoder. Losses are computed per head with
    focal loss and summed.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        # Shared criterion across heads; 'mean' averages over active tokens.
        self.loss_fct = FocalLoss(gamma=2.0, alpha=1.0, reduction='mean')

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            # Per-head MLP: dropout -> dense -> GELU -> projection to classes.
            self.classifiers[label_name] = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(hidden_size, hidden_size),
                nn.GELU(),
                nn.Linear(hidden_size, n_labels),
            )

        # Weight init + tying, per HF convention.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """Run the encoder and every classification head.

        Args:
            input_ids: token ids, shape (batch, seq_len).
            attention_mask: attention mask, shape (batch, seq_len).
            token_type_ids: segment ids, shape (batch, seq_len).
            labels_dict: dict of { label_name: (batch, seq_len) } label ids,
                with -100 marking positions to ignore. If provided, the sum
                of per-head focal losses is computed.

        Returns:
            ``(total_loss, logits_dict)`` when ``labels_dict`` is given,
            otherwise ``logits_dict`` alone, where logits_dict maps each
            label name to a (batch, seq_len, n_labels) tensor.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )

        sequence_output = outputs.last_hidden_state

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            return logits_dict

        # Start from a zero *tensor* on the right device so the returned
        # loss supports .backward() even when no head contributes
        # (fix: original used a Python float 0.0).
        total_loss = sequence_output.new_zeros(())

        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]

            # Positions labelled -100 are padding/ignored tokens.
            active_mask = label_ids.view(-1) != -100

            # Fix: skip heads with zero active tokens — a mean over an
            # empty selection yields NaN and would poison total_loss.
            if not active_mask.any():
                continue

            active_logits = logits.view(-1, logits.shape[-1])[active_mask]
            active_labels = label_ids.view(-1)[active_mask]

            total_loss = total_loss + self.loss_fct(active_logits, active_labels)

        return total_loss, logits_dict
|
|
|