| | from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel |
| | import torch.nn as nn |
| |
|
| |
|
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head label metadata.

    Attributes:
        label_maps: dict mapping each label type to its id<->name mapping.
        num_labels_dict: dict mapping each label type to its class count.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Fall back to empty dicts so downstream iteration is always safe.
        self.label_maps = label_maps if label_maps else {}
        self.num_labels_dict = num_labels_dict if num_labels_dict else {}

    def to_dict(self):
        """Serialize the config, making sure the head metadata is included."""
        serialized = super().to_dict()
        serialized.update(
            label_maps=self.label_maps,
            num_labels_dict=self.num_labels_dict,
        )
        return serialized
| |
|
| |
|
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one token-classification head per label type.

    One ``nn.Linear`` head is created for every entry in
    ``config.num_labels_dict``; each head produces per-token logits over
    that label type's classes.
    """

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)

        # One linear classification head per label type, keyed by name.
        self.classifiers = nn.ModuleDict(
            {
                label_name: nn.Linear(config.hidden_size, n_labels)
                for label_name, n_labels in config.num_labels_dict.items()
            }
        )

        # Standard HF weight init / final setup hook.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """Run the encoder and all classification heads.

        Args:
            input_ids: token ids, shape (batch, seq_len).
            attention_mask: optional attention mask, shape (batch, seq_len).
            token_type_ids: optional segment ids, shape (batch, seq_len).
            labels_dict: optional dict of {label_name: (batch, seq_len)}
                label-id tensors. Positions with id -100 are ignored in the
                loss (CrossEntropyLoss's default ``ignore_index``).
            **kwargs: forwarded to the underlying DeBERTa encoder.

        Returns:
            If ``labels_dict`` is given: ``(total_loss, logits_dict)`` where
            ``total_loss`` is the sum of per-head CE losses, or ``None`` when
            no head had any supervised (non ``-100``) position — previously
            this path returned a plain float 0.0 (not backprop-able) or NaN.
            Otherwise just ``logits_dict``, a dict of
            {label_name: (batch, seq_len, n_labels)} tensors.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )
        # (batch, seq_len, hidden_size)
        sequence_output = outputs.last_hidden_state

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            return logits_dict

        # CrossEntropyLoss already ignores -100 targets (ignore_index=-100)
        # and averages over the remaining positions, so no manual masking
        # is needed.
        loss_fct = nn.CrossEntropyLoss()
        total_loss = None
        for label_name, logits in logits_dict.items():
            label_ids = labels_dict.get(label_name)
            if label_ids is None:
                continue
            # Skip heads whose targets are all -100: CE over zero active
            # elements yields NaN and would poison the summed loss.
            if not (label_ids != -100).any():
                continue
            # reshape (not view) tolerates non-contiguous label tensors.
            loss = loss_fct(
                logits.reshape(-1, logits.shape[-1]),
                label_ids.reshape(-1),
            )
            total_loss = loss if total_loss is None else total_loss + loss

        return total_loss, logits_dict
| |
|