patrickott1 committed on
Commit
0feb130
·
verified ·
1 Parent(s): 3bd4e24

upload py

Browse files
Files changed (1) hide show
  1. modeling_multitask.py +79 -0
modeling_multitask.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
class MultiTaskConfig(PretrainedConfig):
    """Configuration for the two-headed BARThez classifier.

    Holds the backbone checkpoint name and the label counts for the two
    classification heads ("type" and "priorite"), plus the backbone's own
    configuration object loaded through ``AutoConfig``.
    """

    model_type = "barthez-multitask"

    def __init__(
        self,
        base_model_name="moussaKam/barthez",
        num_labels_type=3,
        num_labels_priorite=3,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.num_labels_priorite = num_labels_priorite
        self.num_labels_type = num_labels_type
        self.base_model_name = base_model_name

        # Load the base model's configuration so downstream code can read
        # architecture details (e.g. hidden size) from this config object.
        # NOTE(review): this performs a hub/network lookup every time the
        # config is constructed (including deserialization) — confirm that
        # is acceptable for offline use.
        self.base_config = AutoConfig.from_pretrained(base_model_name)
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from transformers import BartPreTrainedModel, BartModel
25
+ from dataclasses import dataclass
26
+ from transformers.modeling_outputs import ModelOutput
27
+ from typing import Optional
28
+
29
+
@dataclass
class MultiTaskOutput(ModelOutput):
    """Output container returned by ``MultiTaskModel.forward``.

    Attributes:
        loss: Sum of the two cross-entropy losses, or ``None`` when labels
            were not supplied.
        logits_type: Raw scores of the "type" head (batch, num_labels_type).
        logits_priorite: Raw scores of the "priorite" head
            (batch, num_labels_priorite).
    """

    # All fields default to None, so they must be annotated Optional
    # (the originals declared plain torch.FloatTensor with a None default,
    # which contradicts the annotation).
    loss: Optional[torch.FloatTensor] = None
    logits_type: Optional[torch.FloatTensor] = None
    logits_priorite: Optional[torch.FloatTensor] = None
class MultiTaskModel(BartPreTrainedModel):
    """BART backbone with two sequence-level classification heads.

    One linear head scores the "type" label and another scores the
    "priorite" label; both read a mean-pooled summary of the backbone's
    last hidden state.
    """

    def __init__(self, config):
        super().__init__(config)

        # Shared BART backbone; the two heads are plain linear probes on it.
        self.model = BartModel(config)
        dim = config.d_model
        self.classifier_type = nn.Linear(dim, config.num_labels_type)
        self.classifier_priorite = nn.Linear(dim, config.num_labels_priorite)

        self.loss_fct = nn.CrossEntropyLoss()

        # Standard HF weight-initialization hook.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_type=None,
        labels_priorite=None,
    ):
        """Run the backbone and both heads.

        Returns a ``MultiTaskOutput``; ``loss`` is populated only when BOTH
        label tensors are provided, as the sum of the two CE losses.
        """
        backbone_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Sequence-level summary: unweighted mean over the time axis.
        # NOTE(review): padding positions are included in this mean —
        # confirm this matches how the checkpoint was trained before
        # switching to an attention-mask-weighted mean.
        pooled = backbone_out.last_hidden_state.mean(dim=1)

        logits_type = self.classifier_type(pooled)
        logits_priorite = self.classifier_priorite(pooled)

        loss = None
        if labels_type is not None and labels_priorite is not None:
            loss = (
                self.loss_fct(logits_type, labels_type)
                + self.loss_fct(logits_priorite, labels_priorite)
            )

        return MultiTaskOutput(
            loss=loss,
            logits_type=logits_type,
            logits_priorite=logits_priorite,
        )