| import torch |
| import torch.nn as nn |
| from datasets import load_dataset, load_from_disk, concatenate_datasets |
| from transformers import AutoTokenizer, TrainingArguments, Trainer, DataCollatorWithPadding, XLMRobertaPreTrainedModel, XLMRobertaModel |
| from transformers.modeling_outputs import SequenceClassifierOutput |
| import evaluate |
| import numpy as np |
|
|
|
|
|
|
class HierarchicalXLMRoberta(XLMRobertaPreTrainedModel):
    """XLM-RoBERTa encoder with two chained classification heads.

    The level-1 head classifies the encoder's pooled output. The level-2 head
    classifies the pooled output concatenated with the level-1 logits, so the
    level-2 prediction is conditioned on the level-1 scores. The level-1
    logits are not detached, so gradients from the level-2 loss also flow
    back through the level-1 head.

    Args:
        config: Model configuration forwarded to ``XLMRobertaModel``; its
            ``hidden_size`` sizes the input of both heads.
        num_labels_level1: Number of classes predicted by the level-1 head.
        num_labels_level2: Number of classes predicted by the level-2 head.
    """

    def __init__(self, config, num_labels_level1, num_labels_level2):
        super().__init__(config)
        self.num_labels_level1 = num_labels_level1
        self.num_labels_level2 = num_labels_level2
        self.roberta = XLMRobertaModel(config)
        self.classifier_level1 = nn.Linear(config.hidden_size, num_labels_level1)
        # Level-2 input is the pooled encoder output plus the level-1 logits.
        self.classifier_level2 = nn.Linear(
            config.hidden_size + num_labels_level1, num_labels_level2
        )
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        level1_encoded=None,
        level2_encoded=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and both classification heads.

        Encoder arguments (``input_ids`` through ``inputs_embeds``,
        ``output_attentions``, ``output_hidden_states``) are passed straight
        to ``XLMRobertaModel``.

        Args:
            level1_encoded: Target class indices for the level-1 head. They
                are flattened to one target per logits row.
            level2_encoded: Target class indices for the level-2 head. They
                are flattened the same way.
            return_dict: Return a ``SequenceClassifierOutput`` when true and a
                plain tuple when false. Defaults to
                ``config.use_return_dict``.

        Returns:
            When ``return_dict`` is true, a ``SequenceClassifierOutput`` whose
            ``logits`` is the tuple ``(logits_level1, logits_level2)``.
            Otherwise a tuple ``(loss?, logits_level1, logits_level2,
            *extra_encoder_outputs)``. ``loss`` is present only when both
            label tensors are given.

        NOTE(review): neither label parameter name contains "label". If this
        model is used with ``transformers.Trainer``, confirm that
        ``TrainingArguments.label_names`` is set to these names so that
        evaluation computes a loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index instead of using ``.pooler_output``: with return_dict=False the
        # encoder returns a plain tuple (last_hidden_state, pooler_output, ...),
        # which has no attributes. ModelOutput supports the same integer
        # indexing, so this works in both modes, consistent with the
        # ``outputs[2:]`` slice used below.
        pooled_output = outputs[1]

        logits_level1 = self.classifier_level1(pooled_output)

        # Condition the level-2 head on the level-1 scores.
        combined_for_level2 = torch.cat([pooled_output, logits_level1], dim=-1)
        logits_level2 = self.classifier_level2(combined_for_level2)

        loss = None
        if level1_encoded is not None and level2_encoded is not None:
            loss_fct = nn.CrossEntropyLoss()
            # reshape (not view) so non-contiguous label tensors also work.
            loss_level1 = loss_fct(
                logits_level1.reshape(-1, self.num_labels_level1),
                level1_encoded.reshape(-1),
            )
            loss_level2 = loss_fct(
                logits_level2.reshape(-1, self.num_labels_level2),
                level2_encoded.reshape(-1),
            )
            # Unweighted sum of the two per-level losses.
            loss = loss_level1 + loss_level2

        if not return_dict:
            # outputs[2:] holds any optional hidden states / attentions.
            output = (logits_level1, logits_level2) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=(logits_level1, logits_level2),
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )