from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

import torch.nn as nn

class CustomLEDForResultsIdModel(LEDPreTrainedModel):
    def __init__(self, config: LEDConfig, checkpoint=None):
        super().__init__(config)
        self.num_labels = config.num_labels
        print(f"Configs: num_labels={config.num_labels}, dropout={config.dropout}")

        # Use only the LED encoder: load pretrained weights when a checkpoint
        # is given, otherwise initialise from the config alone.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()

        # Token-classification head on top of the encoder hidden states.
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(self.led.config.d_model, self.num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                global_attention_mask=None, return_loss=True):
        # Run the LED encoder; global_attention_mask marks tokens that should
        # attend globally (e.g. the leading <s> token).
        outputs = self.led(input_ids=input_ids,
                           attention_mask=attention_mask,
                           global_attention_mask=global_attention_mask)

        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        # Compute a per-token cross-entropy loss when labels are provided.
        loss = None
        if labels is not None and return_loss:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(loss=loss, logits=logits)
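

# --- Minimal usage sketch (not part of the original snippet) ----------------
# Assumptions: the base checkpoint "allenai/led-base-16384", num_labels=3, and
# the dummy labels below are placeholders purely for illustration.
if __name__ == "__main__":
    import torch
    from transformers import LEDTokenizerFast

    checkpoint = "allenai/led-base-16384"
    config = LEDConfig.from_pretrained(checkpoint, num_labels=3)
    tokenizer = LEDTokenizerFast.from_pretrained(checkpoint)
    model = CustomLEDForResultsIdModel(config, checkpoint=checkpoint)

    enc = tokenizer("The treatment significantly reduced symptoms.", return_tensors="pt")
    # LED expects global attention on at least the first (<s>) token.
    global_attention_mask = torch.zeros_like(enc["input_ids"])
    global_attention_mask[:, 0] = 1

    labels = torch.zeros_like(enc["input_ids"])  # dummy per-token labels
    out = model(input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                global_attention_mask=global_attention_mask,
                labels=labels)
    print(out.loss, out.logits.shape)  # logits: (batch, seq_len, num_labels)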