# multi-classifier / multi_head_model.py
# Author: veryfansome
# Commit 0dfbd20: "Big cleanup"
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
class FocalLoss(nn.Module):
    """
    Focal loss for multi-class classification.

    Per example the loss is ``alpha_t * (1 - p_t) ** gamma * CE``, where ``CE``
    is the unreduced cross-entropy and ``p_t = exp(-CE)`` is the predicted
    probability of the true class.

    Args:
        gamma: focusing parameter that re-weights hard vs. easy examples;
            ``gamma=0`` reduces to (alpha-weighted) cross-entropy.
        alpha: class weighting. A float is a uniform factor applied to every
            example. A 1D tensor of shape ``[num_classes]`` gives a per-class
            weight, looked up by each example's target index.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, gamma=2.0, alpha=1.0, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        # Kept as a plain attribute (not a registered buffer) so the module's
        # state_dict is unchanged; a tensor alpha is moved to the right device
        # in forward() instead.
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        Compute the focal loss.

        Args:
            logits: tensor of shape (N, C), where C is the number of classes.
            targets: tensor of shape (N,) with class indices in [0, C-1].

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, or a tensor of shape
            (N,) for ``'none'``. For N == 0, ``'mean'`` returns 0 instead of NaN.
        """
        # Unreduced cross-entropy, shape (N,).
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        # pt = exp(-CE) = predicted probability of the true class, shape (N,).
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if isinstance(self.alpha, torch.Tensor):
            # Move the per-class weights next to the targets so indexing works
            # when the module was built with CPU weights but runs on GPU.
            alpha_t = self.alpha.to(focal_loss.device)[targets]  # (N,)
            focal_loss = alpha_t * focal_loss
        else:
            focal_loss = self.alpha * focal_loss

        if self.reduction == 'mean':
            if focal_loss.numel() == 0:
                # mean() of an empty tensor is NaN, which would poison any sum
                # of losses; sum() yields 0 and stays attached to the graph.
                return focal_loss.sum()
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
class MultiHeadModelConfig(DebertaV2Config):
    """
    DeBERTa-v2 config extended with the metadata used by ``MultiHeadModel``.

    Args:
        label_maps: per-head label metadata, stored verbatim and included in
            ``to_dict()``. Its schema is defined by the caller (presumably
            head name -> label vocabulary; confirm against the training code).
        num_labels_dict: mapping of head name -> number of output labels;
            ``MultiHeadModel`` builds one classifier head per entry.
        **kwargs: forwarded unchanged to ``DebertaV2Config``.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # `or {}` gives every instance its own fresh dict when nothing (or an
        # empty mapping) is passed.
        self.label_maps = label_maps or {}
        self.num_labels_dict = num_labels_dict or {}

    def to_dict(self):
        """
        Serialize the config to a plain dict, including ``label_maps`` and
        ``num_labels_dict``.

        Both mappings are deep-copied so that mutating the returned dict can
        never modify this config instance.
        """
        output = super().to_dict()
        output["label_maps"] = copy.deepcopy(self.label_maps)
        output["num_labels_dict"] = copy.deepcopy(self.num_labels_dict)
        return output
class MultiHeadModel(DebertaV2PreTrainedModel):
    """
    DeBERTa-v2 encoder with one token-level classifier head per label type.

    A head is created for every ``(label_name, n_labels)`` entry in
    ``config.num_labels_dict``. Each head maps every token's hidden state to
    ``n_labels`` logits. The training loss is focal loss (gamma=2.0), summed
    over heads.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        self.loss_fct = FocalLoss(gamma=2.0, alpha=1.0, reduction='mean')

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            # Small feed-forward module for each head. Dropout: try 0.2 or 0.3
            # to see if overfitting reduces, if the dataset is small or has
            # noisy labels.
            self.classifiers[label_name] = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(hidden_size, hidden_size),
                nn.GELU(),
                nn.Linear(hidden_size, n_labels),
            )

        # Initialize newly added weights.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """
        Run the encoder and every classifier head.

        Args:
            input_ids, attention_mask, token_type_ids, **kwargs: forwarded
                unchanged to ``DebertaV2Model``.
            labels_dict: optional dict ``{label_name: label_ids}`` with
                ``label_ids`` of shape (batch_size, seq_len). Positions equal
                to -100 are ignored. Heads missing from the dict, or with no
                non-ignored position in the batch, contribute no loss.

        Returns:
            ``(total_loss, logits_dict)`` if ``labels_dict`` is given, where
            ``total_loss`` is a tensor holding the sum of the per-head focal
            losses. Otherwise just ``logits_dict``, which maps each head name
            to logits of shape (batch_size, seq_len, n_labels).
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            return logits_dict

        total_loss = None
        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]

            # Flatten to (batch_size * seq_len, ...) and keep only positions
            # whose label is not -100. reshape (not view) tolerates
            # non-contiguous label tensors.
            flat_labels = label_ids.reshape(-1)
            active = flat_labels != -100
            active_logits = logits.reshape(-1, logits.shape[-1])[active]
            active_labels = flat_labels[active]
            if active_labels.numel() == 0:
                # Nothing to learn from for this head in this batch; a mean
                # over zero elements would be NaN and poison the summed loss.
                continue

            loss = self.loss_fct(active_logits, active_labels)
            total_loss = loss if total_loss is None else total_loss + loss

        if total_loss is None:
            # No head had any labelled position. Return a zero tensor that is
            # still attached to the graph, so callers can call backward() on it.
            total_loss = sequence_output.sum() * 0.0
        return total_loss, logits_dict