compare-docs / model.py
anhtunguyen98's picture
Upload model.py with huggingface_hub
640344a verified
import torch
import torch.nn as nn
from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import AutoTokenizer, TrainingArguments, Trainer, DataCollatorWithPadding, XLMRobertaPreTrainedModel, XLMRobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate
import numpy as np
class HierarchicalXLMRoberta(XLMRobertaPreTrainedModel):
    """XLM-RoBERTa encoder with two chained classification heads.

    The level-1 head classifies from the encoder's pooled output. The
    level-2 head classifies from the pooled output concatenated with the
    level-1 logits, so the level-2 prediction is conditioned on level 1.
    """

    def __init__(self, config, num_labels_level1, num_labels_level2):
        """Build the encoder and both classification heads.

        Args:
            config: Model configuration; ``config.hidden_size`` sets the
                input width of the heads.
            num_labels_level1: Number of output classes of the level-1 head.
            num_labels_level2: Number of output classes of the level-2 head.
        """
        super().__init__(config)
        self.num_labels_level1 = num_labels_level1
        self.num_labels_level2 = num_labels_level2

        self.roberta = XLMRobertaModel(config)
        self.classifier_level1 = nn.Linear(config.hidden_size, num_labels_level1)
        # Level2 classifier takes concatenated pooled_output + level1_logits
        self.classifier_level2 = nn.Linear(
            config.hidden_size + num_labels_level1, num_labels_level2
        )

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        level1_encoded=None,
        level2_encoded=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and both heads, optionally computing a joint loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds, output_attentions,
            output_hidden_states: Forwarded unchanged to the encoder.
            level1_encoded: Level-1 targets; flattened with ``view(-1)``
                before the loss (presumably integer class indices, as
                ``nn.CrossEntropyLoss`` expects -- confirm with the caller).
            level2_encoded: Level-2 targets, handled the same way.
            return_dict: Whether to return a ``SequenceClassifierOutput``;
                defaults to ``self.config.use_return_dict`` when ``None``.

        Returns:
            If ``return_dict`` is truthy, a ``SequenceClassifierOutput`` whose
            ``logits`` field is the tuple ``(logits_level1, logits_level2)``.
            Otherwise a tuple ``(logits_level1, logits_level2, *extras)``
            where ``extras`` is ``outputs[2:]`` of the encoder, prefixed with
            the loss when one was computed. The loss is only computed when
            both ``level1_encoded`` and ``level2_encoded`` are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index positionally instead of using ``.pooler_output``: with
        # return_dict=False the encoder returns a plain tuple, on which
        # attribute access would raise AttributeError. Position 1 is the
        # pooled output in both the tuple and the ModelOutput form.
        pooled_output = outputs[1]

        logits_level1 = self.classifier_level1(pooled_output)

        # Concatenate pooled_output with level1 logits for level2
        combined_for_level2 = torch.cat([pooled_output, logits_level1], dim=-1)
        logits_level2 = self.classifier_level2(combined_for_level2)

        loss = None
        if level1_encoded is not None and level2_encoded is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss_level1 = loss_fct(
                logits_level1.view(-1, self.num_labels_level1),
                level1_encoded.view(-1),
            )
            loss_level2 = loss_fct(
                logits_level2.view(-1, self.num_labels_level2),
                level2_encoded.view(-1),
            )
            # Joint loss (you can weight them if needed, e.g., 0.5 * each)
            loss = loss_level1 + loss_level2

        if not return_dict:
            output = (logits_level1, logits_level2) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=(logits_level1, logits_level2),
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )