from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification.

    gamma: focusing parameter that down-weights easy examples; higher gamma
        concentrates training on hard (misclassified) examples.
    alpha: optional class weighting. Either a single float applied uniformly
        to all classes, or a 1D tensor of shape [num_classes] with one
        weight per class.
    reduction: 'none', 'mean', or 'sum'.
    """
    def __init__(self, gamma=2.0, alpha=1.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        # If alpha is a scalar it is applied uniformly; if it is a tensor it
        # must contain one entry per class.

    def forward(self, logits, targets):
        """
        logits: tensor of shape (N, C), where C is the number of classes
        targets: tensor of shape (N,), with class indices in [0..C-1]

        Returns the focal loss, reduced according to ``self.reduction``.
        """
        # Standard per-sample cross-entropy (not reduced yet).
        ce_loss = F.cross_entropy(logits, targets, reduction='none')  # (N,)
        # pt = exp(-CE) is the predicted probability of the true class.
        pt = torch.exp(-ce_loss)  # (N,)
        # Focal loss = alpha * (1 - pt)^gamma * CE — the (1 - pt)^gamma term
        # suppresses the contribution of already well-classified samples.
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if isinstance(self.alpha, torch.Tensor):
            # Per-class weighting: alpha[targets] has shape (N,).
            # Move alpha to the targets' device first — it is a plain
            # attribute (not a registered buffer), so Module.to()/cuda()
            # never relocates it and indexing would otherwise raise a
            # cross-device error.
            alpha_t = self.alpha.to(targets.device)[targets]
            focal_loss = alpha_t * focal_loss
        else:
            # alpha is just a scalar
            focal_loss = self.alpha * focal_loss
        # Apply the requested reduction.
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            # 'none'
            return focal_loss
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 configuration extended with multi-head label metadata.

    label_maps: per-head label mapping metadata (empty dict by default).
    num_labels_dict: number of output labels for each head
        (empty dict by default).
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        if label_maps is None:
            label_maps = {}
        if num_labels_dict is None:
            num_labels_dict = {}
        self.label_maps = label_maps
        self.num_labels_dict = num_labels_dict

    def to_dict(self):
        """Serialize the config, including the extra multi-head fields."""
        serialized = super().to_dict()
        serialized["label_maps"] = self.label_maps
        serialized["num_labels_dict"] = self.num_labels_dict
        return serialized
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one token-classification head per label type.

    One small feedforward classifier is built per entry in
    ``config.num_labels_dict``; training sums a shared focal loss over
    all heads, ignoring positions labelled -100.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        # One shared loss function for every head.
        self.loss_fct = FocalLoss(gamma=2.0, alpha=1.0, reduction='mean')
        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            # Small feedforward module for each head
            self.classifiers[label_name] = nn.Sequential(
                nn.Dropout(
                    0.2  # Try 0.2 or 0.3 to see if overfitting reduces, if dataset is small or has noisy labels
                ),
                nn.Linear(hidden_size, hidden_size),
                nn.GELU(),
                nn.Linear(hidden_size, n_labels)
            )
        # Initialize newly added weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """
        labels_dict: a dict of { label_name: (batch_size, seq_len) } with
            label ids; positions set to -100 are ignored.

        Returns:
            (total_loss, logits_dict) when labels_dict is provided, where
            total_loss is the sum of the per-head focal losses; otherwise
            just logits_dict — { label_name: (bs, seq_len, n_labels) }.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }
        if labels_dict is None:
            # Inference only: return the per-head predictions.
            return logits_dict
        # Training: sum the losses from each head that has labels.
        total_loss = 0.0
        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]
            # Standard token-classification masking: positions with
            # label == -100 (padding / special tokens) do not contribute.
            active_mask = (label_ids != -100).view(-1)  # (bs * seq_len,)
            # Guard against a head whose labels are all -100 in this batch:
            # the mean over an empty tensor would be NaN and poison the
            # whole loss.
            if not active_mask.any():
                continue
            active_logits = logits.view(-1, logits.shape[-1])[active_mask]
            active_labels = label_ids.view(-1)[active_mask]
            total_loss = total_loss + self.loss_fct(active_logits, active_labels)
        # return (loss, predictions)
        return total_loss, logits_dict