File size: 1,626 Bytes
1cb894c
 
 
 
358e141
1cb894c
 
 
 
 
 
 
 
 
 
358e141
 
 
 
 
 
 
 
 
 
 
 
1cb894c
 
358e141
 
 
 
1cb894c
 
358e141
 
 
 
1cb894c
358e141
 
 
1cb894c
358e141
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaPreTrainedModel


class RobertaMultiTask(RobertaPreTrainedModel):
    """RoBERTa encoder with two heads trained jointly.

    * ``classifier`` maps the pooled encoder output to ``config.num_labels``
      sequence-level logits.
    * ``span_classifier`` maps every position of the last hidden state to
      2 logits.

    The combined training loss is ``cls_loss + span_loss_weight * span_loss``.
    ``span_loss_weight`` is read from ``config.span_loss_weight`` when the
    config defines it and defaults to 0.3 otherwise.
    """

    # Target value that the span loss skips.
    _SPAN_IGNORE_INDEX = -100

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Optional config field; 0.3 keeps existing configs/checkpoints
        # behaving exactly as before.
        self.span_loss_weight = getattr(config, "span_loss_weight", 0.3)
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.span_classifier = nn.Linear(config.hidden_size, 2)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        span_labels=None
    ):
        """Run the encoder and both heads, optionally computing the joint loss.

        Args:
            input_ids: Token ids passed positionally to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            token_type_ids: Accepted so callers may pass it, but it is NOT
                forwarded to the encoder.
            labels: Sequence-level class targets; flattened with ``view(-1)``.
            span_labels: Per-position targets in ``{0, 1}``; positions equal
                to -100 are excluded from the span loss.

        Returns:
            A dict with keys ``"loss"``, ``"logits"`` and ``"span_logits"``.
            ``"loss"`` is ``None`` unless BOTH ``labels`` and ``span_labels``
            are supplied.
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        pooled_output = self.dropout(outputs.pooler_output)

        logits = self.classifier(pooled_output)
        span_logits = self.span_classifier(sequence_output)

        loss = None
        if labels is not None and span_labels is not None:
            cls_loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.num_labels),
                labels.view(-1)
            )

            # Mean reduction with ignore_index divides by the number of
            # non-ignored targets, which is 0/0 = NaN when every span label
            # in the batch is ignored. Sum and divide by a count clamped to
            # >= 1 instead: identical whenever at least one label is valid,
            # and a clean 0 (still attached to the graph) otherwise.
            flat_span_labels = span_labels.view(-1)
            span_loss_sum = nn.CrossEntropyLoss(
                ignore_index=self._SPAN_IGNORE_INDEX,
                reduction="sum",
            )(
                span_logits.view(-1, 2),
                flat_span_labels
            )
            valid_count = (
                (flat_span_labels != self._SPAN_IGNORE_INDEX).sum().clamp(min=1)
            )
            span_loss = span_loss_sum / valid_count

            loss = cls_loss + self.span_loss_weight * span_loss

        return {
            "loss":        loss,
            "logits":      logits,
            "span_logits": span_logits
        }