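"""Multitask classification model built on a Hugging Face encoder backbone.

Wraps the encoder named by ``config.base_model_name`` with two linear heads on
the first-token ([CLS]-style) representation: a binary head and a
``num_ai_classes``-way multiclass head. When both label sets are provided, the
two cross-entropy losses are combined with a fixed 1.0 / 0.5 weighting.
"""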
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_suave_multitask import SuaveMultitaskConfig


@dataclass
class SuaveMultitaskOutput(ModelOutput):
    """Output bundling the joint loss, per-head logits, and optional encoder states."""

    loss: Optional[torch.FloatTensor] = None
    logits_binary: Optional[torch.FloatTensor] = None
    logits_multiclass: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class SuaveMultitaskModel(PreTrainedModel):
    config_class = SuaveMultitaskConfig
    base_model_prefix = "encoder"

    def __init__(self, config: SuaveMultitaskConfig):
        super().__init__(config)
        # Build the backbone from its config alone: from_config gives randomly
        # initialized weights, so pretrained weights must be loaded separately
        # (e.g. via from_pretrained on the composite model).
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_config(base_config)
        hidden_size = self.encoder.config.hidden_size

        # Two heads share the encoder: a 2-way binary head and a
        # num_ai_classes-way multiclass head, both fed the same pooled vector.
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier_binary = nn.Linear(hidden_size, 2)
        self.classifier_multiclass = nn.Linear(hidden_size, config.num_ai_classes)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_binary=None,
        labels_multiclass=None,
        **kwargs,
    ):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=kwargs.get("output_hidden_states", False),
            output_attentions=kwargs.get("output_attentions", False),
        )

        # Pool with the first token's final hidden state ([CLS] for BERT-style
        # encoders), then apply dropout before both heads.
        pooled = outputs.last_hidden_state[:, 0]
        pooled = self.dropout(pooled)

        logits_binary = self.classifier_binary(pooled)
        logits_multiclass = self.classifier_multiclass(pooled)

        loss = None
        if labels_binary is not None and labels_multiclass is not None:
            loss_binary = nn.CrossEntropyLoss()(logits_binary, labels_binary)
            # ignore_index=-1 skips examples that carry no multiclass label.
            loss_multiclass = nn.CrossEntropyLoss(ignore_index=-1)(
                logits_multiclass, labels_multiclass
            )
            # Fixed weighting: the binary task is primary, the multiclass
            # task contributes as a half-weight auxiliary objective.
            loss = loss_binary + 0.5 * loss_multiclass

        return SuaveMultitaskOutput(
            loss=loss,
            logits_binary=logits_binary,
            logits_multiclass=logits_multiclass,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
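

if __name__ == "__main__":
    # Minimal smoke-test sketch. Run as a module (e.g.
    # `python -m <package>.modeling_suave_multitask`) so the relative import
    # above resolves. The config values below are illustrative assumptions;
    # the real defaults live in SuaveMultitaskConfig.
    from transformers import AutoTokenizer

    config = SuaveMultitaskConfig(
        base_model_name="bert-base-uncased",  # assumed backbone for the demo
        num_ai_classes=5,  # assumed number of multiclass labels
        classifier_dropout=0.1,  # assumed dropout rate
    )
    model = SuaveMultitaskModel(config)

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    batch = tokenizer(["An example document."], return_tensors="pt")

    with torch.no_grad():
        out = model(**batch)
    print(out.logits_binary.shape)  # torch.Size([1, 2])
    print(out.logits_multiclass.shape)  # torch.Size([1, config.num_ai_classes])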