# bias-detection-api / modeling_roberta_multitask.py
# Provenance (Hugging Face Hub page header): uploaded by PreranTej,
# "Update modeling_roberta_multitask.py", commit 358e141 (verified).
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaPreTrainedModel
class RobertaMultiTask(RobertaPreTrainedModel):
    """RoBERTa encoder with two jointly trained heads.

    Heads:
        * Sequence classification: a linear layer over the pooled output,
          producing ``config.num_labels`` logits per example.
        * Span tagging: a linear layer over every token's hidden state,
          producing 2 logits per token (binary inside/outside tagging —
          presumably marking biased spans; confirm against the dataset).
    """

    # Weight of the auxiliary span loss relative to the classification loss.
    SPAN_LOSS_WEIGHT = 0.3

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Token-level head: 2 logits per token (binary span tagging).
        self.span_classifier = nn.Linear(config.hidden_size, 2)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        # Accepted for API compatibility with BERT-style callers; RoBERTa has
        # no segment embeddings, so it is deliberately not forwarded.
        token_type_ids=None,
        labels=None,
        span_labels=None,
    ):
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: 1 for real tokens, 0 for padding.
            token_type_ids: Ignored (RoBERTa does not use segment ids).
            labels: Optional sequence-level targets, shape (batch,).
            span_labels: Optional per-token targets, shape (batch, seq_len);
                positions set to -100 are excluded from the span loss.

        Returns:
            dict with keys "loss" (None when no labels were given),
            "logits" (batch, num_labels) and "span_logits" (batch, seq_len, 2).
        """
        outputs = self.roberta(input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        pooled_output = self.dropout(outputs.pooler_output)

        logits = self.classifier(pooled_output)              # (batch, num_labels)
        span_logits = self.span_classifier(sequence_output)  # (batch, seq_len, 2)

        # Fix: the original returned loss=None unless BOTH labels AND
        # span_labels were provided, which silently disabled training when
        # only one supervision signal was available. Each term is now
        # computed independently; with both present the total is unchanged
        # (cls_loss + SPAN_LOSS_WEIGHT * span_loss).
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.num_labels),
                labels.view(-1),
            )
        if span_labels is not None:
            span_loss = nn.CrossEntropyLoss(ignore_index=-100)(
                span_logits.view(-1, 2),
                span_labels.view(-1),
            )
            weighted = self.SPAN_LOSS_WEIGHT * span_loss
            loss = weighted if loss is None else loss + weighted

        return {
            "loss": loss,
            "logits": logits,
            "span_logits": span_logits,
        }