| from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification.

    Scales per-sample cross-entropy by ``(1 - p_t) ** gamma`` so that
    confidently-classified (easy) examples are down-weighted and hard
    examples dominate the gradient.

    Args:
        gamma: focusing parameter; ``gamma == 0`` recovers plain
            cross-entropy (up to the ``alpha`` factor).
        alpha: class weighting. A float applies one uniform factor to all
            classes; a 1-D tensor of shape ``[num_classes]`` applies a
            per-class weight gathered by each sample's target class.
        reduction: ``'none'``, ``'mean'``, or ``'sum'``.
    """

    def __init__(self, gamma=2.0, alpha=1.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the focal loss.

        Args:
            logits: ``(N, C)`` raw (unnormalized) class scores.
            targets: ``(N,)`` integer class indices in ``[0, C)``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'`` reduction, otherwise a
            ``(N,)`` tensor of per-sample losses.
        """
        per_sample_ce = F.cross_entropy(logits, targets, reduction='none')

        # p_t: the model's probability for the true class of each sample,
        # recovered from the cross-entropy (ce = -log(p_t)).
        true_class_prob = torch.exp(-per_sample_ce)

        modulated = ((1 - true_class_prob) ** self.gamma) * per_sample_ce

        if isinstance(self.alpha, torch.Tensor):
            # Per-class weights, gathered per sample via its target index.
            modulated = self.alpha[targets] * modulated
        else:
            # Uniform scalar factor for every class.
            modulated = self.alpha * modulated

        if self.reduction == 'mean':
            return modulated.mean()
        if self.reduction == 'sum':
            return modulated.sum()
        # reduction == 'none': hand back the per-sample vector untouched.
        return modulated
|
|
|
|
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 configuration extended with per-head label metadata.

    Attributes:
        label_maps: mapping of head name -> label mapping (stored and
            serialized as-is; structure is opaque to this class).
        num_labels_dict: mapping of head name -> number of classes for
            that head.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Fall back to empty dicts so downstream iteration is always safe.
        self.label_maps = {} if not label_maps else label_maps
        self.num_labels_dict = {} if not num_labels_dict else num_labels_dict

    def to_dict(self):
        """Serialize the config, ensuring the custom fields are included."""
        serialized = super().to_dict()
        serialized["label_maps"] = self.label_maps
        serialized["num_labels_dict"] = self.num_labels_dict
        return serialized
|
|
|
|
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one token-classification head per label set.

    Each entry in ``config.num_labels_dict`` creates an independent head
    (dropout -> linear -> GELU -> linear) applied to every token
    representation. The training loss is the sum of per-head focal losses
    over non-ignored (label != -100) positions.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        # Shared loss; -100 positions are masked out manually in forward().
        self.loss_fct = FocalLoss(gamma=2.0, alpha=1.0, reduction='mean')

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            # Two-layer head with dropout regularization, one per label set.
            self.classifiers[label_name] = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(hidden_size, hidden_size),
                nn.GELU(),
                nn.Linear(hidden_size, n_labels),
            )

        # Standard HF weight initialization hook.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """Encode the inputs and run every classification head.

        Args:
            input_ids / attention_mask / token_type_ids: standard encoder
                inputs of shape ``(batch_size, seq_len)``.
            labels_dict: optional mapping of head name ->
                ``(batch_size, seq_len)`` label ids; positions labeled
                ``-100`` are ignored (HF convention).

        Returns:
            ``(total_loss, logits_dict)`` when ``labels_dict`` is given,
            otherwise just ``logits_dict``, a mapping of head name ->
            ``(batch_size, seq_len, n_labels)`` logits.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )
        sequence_output = outputs.last_hidden_state

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            return logits_dict

        total_loss = 0.0
        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]

            # Flattened mask of positions that carry a real label.
            active_mask = (label_ids != -100).view(-1)
            # BUGFIX: a head whose labels were all -100 used to select empty
            # tensors, and FocalLoss with 'mean' reduction over zero elements
            # is NaN, which poisoned the summed loss. Skip such heads.
            if not active_mask.any():
                continue

            active_logits = logits.view(-1, logits.shape[-1])[active_mask]
            active_labels = label_ids.view(-1)[active_mask]
            total_loss = total_loss + self.loss_fct(active_logits, active_labels)

        return total_loss, logits_dict
|
|