import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel


class FocalLoss(nn.Module):
    """Focal Loss for multi-class classification.

    Focal loss down-weights easy examples so training focuses on hard,
    misclassified ones: loss = alpha * (1 - pt)^gamma * CE, where pt is the
    model's probability for the true class.

    Args:
        gamma: Focusing parameter; larger values down-weight easy examples
            more aggressively (gamma=0 recovers plain cross-entropy).
        alpha: Class weighting. Either a scalar applied uniformly, or a 1-D
            tensor of shape [num_classes] giving a per-class weight.
        reduction: 'none', 'mean', or 'sum'.
    """

    def __init__(self, gamma=2.0, alpha=1.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the focal loss.

        Args:
            logits: Float tensor of shape (N, C), raw class scores.
            targets: Long tensor of shape (N,) with class indices in [0, C-1].

        Returns:
            Scalar tensor if reduction is 'mean'/'sum', else a (N,) tensor.
        """
        # Unreduced cross-entropy per example; pt = exp(-CE) is the predicted
        # probability of the true class.
        ce_loss = F.cross_entropy(logits, targets, reduction='none')  # (N,)
        pt = torch.exp(-ce_loss)  # (N,)

        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if isinstance(self.alpha, torch.Tensor):
            # Per-class weighting: gather alpha for each target.
            # Fix: move alpha to the targets' device first — a CPU-resident
            # alpha tensor would otherwise fail to index GPU targets.
            alpha_t = self.alpha.to(targets.device)[targets]  # (N,)
            focal_loss = alpha_t * focal_loss
        else:
            # Scalar alpha: uniform weighting.
            focal_loss = self.alpha * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss  # 'none'


class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head label metadata.

    Attributes:
        label_maps: Mapping of head name -> label-id/label-name map
            (opaque to the model; carried for serialization).
        num_labels_dict: Mapping of head name -> number of classes, used to
            size each classification head.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        self.label_maps = label_maps or {}
        self.num_labels_dict = num_labels_dict or {}

    def to_dict(self):
        """Serialize, including the extra head metadata."""
        output = super().to_dict()
        output["label_maps"] = self.label_maps
        output["num_labels_dict"] = self.num_labels_dict
        return output


class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one token-classification head per label set.

    Each entry in ``config.num_labels_dict`` gets its own small feed-forward
    classifier over the encoder's token embeddings. Training sums a focal
    loss over all heads; positions labeled -100 are ignored.
    """

    # Needed so from_pretrained() rebuilds the custom config (and with it the
    # heads) instead of falling back to plain DebertaV2Config.
    config_class = MultiHeadModelConfig

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        self.loss_fct = FocalLoss(gamma=2.0, alpha=1.0, reduction='mean')

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            # Small feedforward module for each head.
            self.classifiers[label_name] = nn.Sequential(
                nn.Dropout(
                    0.2  # Try 0.2 or 0.3 to see if overfitting reduces, if dataset is small or has noisy labels
                ),
                nn.Linear(hidden_size, hidden_size),
                nn.GELU(),
                nn.Linear(hidden_size, n_labels),
            )

        # Initialize newly added weights.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """Run the encoder and all classification heads.

        Args:
            input_ids, attention_mask, token_type_ids: Standard encoder
                inputs, each (batch_size, seq_len); forwarded to DeBERTa.
            labels_dict: Optional dict of {head name: (batch_size, seq_len)}
                label-id tensors; positions with -100 are excluded from the
                loss. Heads without an entry are skipped.

        Returns:
            ``(total_loss, logits_dict)`` when ``labels_dict`` is given,
            otherwise just ``logits_dict`` mapping head name to
            (batch_size, seq_len, n_labels) logits.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )
        sequence_output = outputs.last_hidden_state  # (bs, seq_len, hidden)

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            # Inference: just return per-head predictions.
            return logits_dict

        # Training: sum focal losses over every head that has labels.
        total_loss = 0.0
        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]

            # Token-classification convention: ignore -100 positions.
            active = (label_ids != -100).view(-1)
            active_logits = logits.view(-1, logits.shape[-1])[active]
            active_labels = label_ids.view(-1)[active]

            # Fix: an all-ignored head would make the mean-reduced focal loss
            # NaN (mean over an empty tensor) and poison total_loss — skip it.
            if active_labels.numel() == 0:
                continue

            total_loss = total_loss + self.loss_fct(active_logits, active_labels)

        return total_loss, logits_dict