import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model, DebertaV2PreTrainedModel

class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head label metadata.

    label_maps:      optional per-head mapping between label ids and label strings,
                     kept for decoding predictions (not used by the model itself).
    num_labels_dict: { head_name: number_of_classes }, used to size each classifier head.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        self.label_maps = label_maps or {}
        self.num_labels_dict = num_labels_dict or {}

    def to_dict(self):
        # Re-attach the extra fields so they are written to config.json by
        # save_pretrained() and restored by from_pretrained().
        output = super().to_dict()
        output["label_maps"] = self.label_maps
        output["num_labels_dict"] = self.num_labels_dict
        return output
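

# Serialization sketch (illustrative only; "./multi-head-config" is a hypothetical
# local path): because to_dict() re-attaches the extra fields, they survive the
# usual config round-trip. Left as comments so importing this module has no
# side effects:
#
#     config = MultiHeadModelConfig(num_labels_dict={"ner": 9})
#     config.save_pretrained("./multi-head-config")    # writes config.json
#     reloaded = MultiHeadModelConfig.from_pretrained("./multi-head-config")
#     assert reloaded.num_labels_dict == {"ner": 9}
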
class MultiHeadModel(DebertaV2PreTrainedModel):
    """A shared DeBERTa-v2 encoder with one token-classification head per label set."""

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # One linear classifier per head, each projecting the shared hidden states
        # into that head's label space.
        self.classifiers = nn.ModuleDict()
        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)
        # Initialize newly added weights.
        self.post_init()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs,
    ):
        """
        labels_dict: a dict of { label_name: (batch_size, seq_len) } tensors of label ids,
        with -100 marking positions to ignore. If provided, the summed cross-entropy
        loss across heads is returned alongside the per-head logits.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Run every classifier head over the shared token representations.
        logits_dict = {}
        for label_name, classifier in self.classifiers.items():
            logits_dict[label_name] = classifier(sequence_output)

        total_loss = None
        loss_dict = {}  # per-head scalar losses, kept for optional logging; not returned
        if labels_dict is not None:
            # Sum the token-classification losses from each head.
            loss_fct = nn.CrossEntropyLoss()
            total_loss = 0.0
            for label_name, logits in logits_dict.items():
                if label_name not in labels_dict:
                    continue
                label_ids = labels_dict[label_name]
                # Typical token-classification masking: ignore positions labeled -100
                # (padding, subword continuations, etc.).
                active_mask = label_ids != -100  # (batch_size, seq_len)
                # Flatten to (num_active, num_classes) and (num_active,) before the loss.
                active_logits = logits.view(-1, logits.shape[-1])[active_mask.view(-1)]
                active_labels = label_ids.view(-1)[active_mask.view(-1)]
                loss = loss_fct(active_logits, active_labels)
                loss_dict[label_name] = loss.item()
                total_loss += loss
        if labels_dict is not None:
            # Training mode: return (summed loss, per-head predictions).
            return total_loss, logits_dict
        # No labels: just return the per-head predictions.
        return logits_dict
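

if __name__ == "__main__":
    # Minimal smoke test, not part of the original model code: it builds a tiny,
    # randomly initialized backbone so it runs offline without a checkpoint.
    # All sizes and head names ("upos", "ner") below are illustrative assumptions.
    import torch

    config = MultiHeadModelConfig(
        vocab_size=128,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=64,
        num_labels_dict={"upos": 18, "ner": 9},
    )
    model = MultiHeadModel(config)
    model.eval()

    batch_size, seq_len = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # In real data, -100 would mark positions excluded from the loss
    # (padding, subword continuations, etc.); here all positions are labeled.
    labels_dict = {
        "upos": torch.randint(0, 18, (batch_size, seq_len)),
        "ner": torch.randint(0, 9, (batch_size, seq_len)),
    }

    with torch.no_grad():
        loss, logits_dict = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels_dict=labels_dict,
        )
    print("summed loss:", float(loss))
    print({name: tuple(logits.shape) for name, logits in logits_dict.items()})
    # Expected logits shapes: (2, 8, 18) for "upos" and (2, 8, 9) for "ner".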