# multi-classifier / multi_head_model.py
# Author: veryfansome — commit c5081c8 ("feat: working end-to-end"), 2.98 kB
from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel
import torch.nn as nn
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 configuration extended with per-head label metadata.

    Args:
        label_maps: Extra label metadata stored on the config and emitted by
            ``to_dict``. Its structure is not interpreted in this module
            (presumably head name -> label vocabulary; confirm against the
            training code). Any falsy value becomes a fresh empty dict.
        num_labels_dict: Mapping of head name -> number of output labels.
            ``MultiHeadModel`` builds one linear head per entry. Any falsy
            value becomes a fresh empty dict.
        **kwargs: Forwarded unchanged to ``DebertaV2Config``.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # `value if value else {}` is the same test as `value or {}`: None (or
        # any other falsy value) is swapped for a new, unshared dict.
        self.label_maps = label_maps if label_maps else {}
        self.num_labels_dict = num_labels_dict if num_labels_dict else {}

    def to_dict(self):
        """Serialize via the parent class, then (re)write the multi-head fields."""
        serialized = super().to_dict()
        serialized.update(
            label_maps=self.label_maps,
            num_labels_dict=self.num_labels_dict,
        )
        return serialized
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder shared by several token-classification heads.

    One ``nn.Linear(hidden_size, n_labels)`` head is created for every
    ``label_name -> n_labels`` entry in ``config.num_labels_dict``. Every head
    is applied to the encoder's ``last_hidden_state``.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()
        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)
        # Initialize newly added weights
        self.post_init()

    @staticmethod
    def _token_ce_loss(logits, label_ids, ignore_index=-100):
        """Return the mean cross-entropy over tokens whose label != ``ignore_index``.

        Args:
            logits: Per-token scores whose last dimension is the label count.
            label_ids: Integer labels with one entry per logits row after
                flattening all leading dimensions.
            ignore_index: Label value marking positions excluded from the loss.

        Returns:
            A scalar tensor. When no position is active, this is a zero that
            stays attached to the autograd graph instead of the NaN a mean
            over an empty selection would produce.
        """
        flat_logits = logits.reshape(-1, logits.shape[-1])
        # reshape (not view) so non-contiguous label tensors also work.
        flat_labels = label_ids.reshape(-1)
        active = flat_labels != ignore_index
        if not active.any():
            # Multiplying by zero (instead of new_zeros) keeps the head in the
            # graph, so backward() works even if every head is empty.
            return (flat_logits * 0.0).sum()
        return nn.functional.cross_entropy(flat_logits[active], flat_labels[active])

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """Run the shared encoder and every classification head.

        Args:
            input_ids, attention_mask, token_type_ids: Passed through to
                ``DebertaV2Model``, together with any extra ``**kwargs``.
            labels_dict: Optional ``{label_name: label_ids}`` with integer
                label ids per token (documented as ``(batch_size, seq_len)``).
                Positions equal to -100 are ignored. Names without a matching
                head are ignored, and heads without an entry add no loss.

        Returns:
            If ``labels_dict`` is None: ``logits_dict``, mapping each head
            name to its logits ``(batch_size, seq_len, n_labels)``.
            Otherwise: ``(total_loss, logits_dict)``. ``total_loss`` is the
            sum, over labelled heads, of each head's mean cross-entropy on its
            non -100 tokens. A head with no active tokens contributes 0, not
            NaN. If no head is labelled, ``total_loss`` is a zero tensor.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            # Inference: just return predictions.
            return logits_dict

        # Sum the losses from each labelled head. Start from a tensor (not
        # 0.0) so the returned loss is always a tensor.
        total_loss = sequence_output.new_zeros(())
        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            total_loss = total_loss + self._token_ce_loss(logits, labels_dict[label_name])
        return total_loss, logits_dict