from transformers import PretrainedConfig, AutoConfig

class MultiTaskConfig(PretrainedConfig):
    model_type = "barthez-multitask"

    def __init__(
        self,
        base_model_name="moussaKam/barthez",
        num_labels_type=3,
        num_labels_priorite=3,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.base_model_name = base_model_name
        self.num_labels_type = num_labels_type
        self.num_labels_priorite = num_labels_priorite

        # load the base model's configuration
        self.base_config = AutoConfig.from_pretrained(base_model_name)

import torch
import torch.nn as nn
from transformers import BartPreTrainedModel, BartModel
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
from typing import Optional


@dataclass
class MultiTaskOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits_type: Optional[torch.FloatTensor] = None
    logits_priorite: Optional[torch.FloatTensor] = None


class MultiTaskModel(BartPreTrainedModel):
    # register the custom config so save_pretrained / from_pretrained round-trip it
    config_class = MultiTaskConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = BartModel(config)
        hidden_size = config.d_model

        # one linear classification head per task
        self.classifier_type = nn.Linear(hidden_size, config.num_labels_type)
        self.classifier_priorite = nn.Linear(hidden_size, config.num_labels_priorite)

        self.loss_fct = nn.CrossEntropyLoss()

        # initialise weights (Hugging Face convention)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_type=None,
        labels_priorite=None,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # mean-pool the last hidden states over the sequence dimension
        # (note: padding positions are included in the mean; masking them
        # out with attention_mask would be more precise)
        pooled_output = outputs.last_hidden_state.mean(dim=1)

        logits_type = self.classifier_type(pooled_output)
        logits_priorite = self.classifier_priorite(pooled_output)

        loss = None
        if labels_type is not None and labels_priorite is not None:
            loss_type = self.loss_fct(logits_type, labels_type)
            loss_priorite = self.loss_fct(logits_priorite, labels_priorite)
            loss = loss_type + loss_priorite

        return MultiTaskOutput(
            loss=loss,
            logits_type=logits_type,
            logits_priorite=logits_priorite,
        )
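

# --- Usage sketch (illustrative, not part of the model code) ---
# A minimal sketch of the mean-pooling + two-head pattern used in
# MultiTaskModel.forward, with random tensors standing in for the
# transformer's hidden states so it runs without downloading
# "moussaKam/barthez". The shapes and label counts below are assumptions
# matching the defaults above (3 labels per task).
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    batch, seq_len, hidden = 2, 8, 16

    # stand-in for outputs.last_hidden_state: (batch, seq_len, hidden)
    hidden_states = torch.randn(batch, seq_len, hidden)

    # mean-pool over the sequence dimension -> one vector per example
    pooled = hidden_states.mean(dim=1)  # (batch, hidden)

    # two independent linear heads, one per task
    classifier_type = nn.Linear(hidden, 3)
    classifier_priorite = nn.Linear(hidden, 3)
    logits_type = classifier_type(pooled)          # (batch, 3)
    logits_priorite = classifier_priorite(pooled)  # (batch, 3)

    # joint loss: sum of the two cross-entropies, as in forward()
    loss_fct = nn.CrossEntropyLoss()
    labels_type = torch.tensor([0, 2])
    labels_priorite = torch.tensor([1, 1])
    loss = loss_fct(logits_type, labels_type) + loss_fct(
        logits_priorite, labels_priorite
    )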