import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaPreTrainedModel


class RobertaMultiTask(RobertaPreTrainedModel):
    """RoBERTa encoder with two jointly trained heads: sequence-level
    classification and token-level span tagging."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Sentence-level head over the pooled representation.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Token-level head: binary in-span / out-of-span logits per token.
        self.span_classifier = nn.Linear(config.hidden_size, 2)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        span_labels=None,
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Per-token states feed the span head; the pooled state feeds
        # the sentence-level classifier.
        sequence_output = self.dropout(outputs.last_hidden_state)
        pooled_output = self.dropout(outputs.pooler_output)

        logits = self.classifier(pooled_output)  # (batch, num_labels)
        span_logits = self.span_classifier(sequence_output)  # (batch, seq_len, 2)

        loss = None
        if labels is not None and span_labels is not None:
            cls_loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.num_labels),
                labels.view(-1),
            )
            # ignore_index=-100 skips padding / special-token positions
            # in the span labels.
            span_loss = nn.CrossEntropyLoss(ignore_index=-100)(
                span_logits.view(-1, 2),
                span_labels.view(-1),
            )
            # Weighted sum: the span task is an auxiliary objective,
            # down-weighted by 0.3.
            loss = cls_loss + 0.3 * span_loss

        return {
            "loss": loss,
            "logits": logits,
            "span_logits": span_logits,
        }
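

# Minimal usage sketch, not part of the model code above: it assumes a
# standard "roberta-base" checkpoint, and the label count (3), the input
# sentence, and the all-zero span labels are hypothetical placeholders.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("roberta-base", num_labels=3)
    model = RobertaMultiTask.from_pretrained("roberta-base", config=config)
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    enc = tokenizer("A sample sentence.", return_tensors="pt")
    labels = torch.tensor([1])  # hypothetical sentence-level label
    # One span label per token; 0 = outside a span. Positions to exclude
    # from the loss would be set to -100 instead.
    span_labels = torch.zeros_like(enc["input_ids"])

    out = model(**enc, labels=labels, span_labels=span_labels)
    print(out["loss"].item(), out["logits"].shape, out["span_logits"].shape)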