from transformers import PretrainedConfig, AutoConfig
class MultiTaskConfig(PretrainedConfig):
    """HF-style config for a BARThez model with two classification heads.

    Attributes:
        base_model_name: Hub id of the underlying BARThez checkpoint.
        num_labels_type: number of classes for the "type" head.
        num_labels_priorite: number of classes for the "priorite" head.
        base_config: the resolved config of the base model.
    """

    model_type = "barthez-multitask"

    def __init__(
        self,
        base_model_name="moussaKam/barthez",
        num_labels_type=3,
        num_labels_priorite=3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_labels_type = num_labels_type
        self.num_labels_priorite = num_labels_priorite
        # Load the base model's config (cache/network hit at construction time).
        # NOTE(review): storing a full config object on the config may not
        # round-trip through to_json_string/save_pretrained — confirm before
        # relying on serialization.
        self.base_config = AutoConfig.from_pretrained(base_model_name)
import torch
import torch.nn as nn
from transformers import BartPreTrainedModel, BartModel
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
from typing import Optional
@dataclass
class MultiTaskOutput(ModelOutput):
    """Output container for ``MultiTaskModel``.

    All fields default to ``None`` so that ``ModelOutput`` can omit absent
    entries (e.g. ``loss`` at inference time when no labels are given).
    """

    # Combined loss (type + priorite); None when labels are not provided.
    loss: Optional[torch.FloatTensor] = None
    # Scores of the "type" head; produced as (batch, num_labels_type).
    logits_type: Optional[torch.FloatTensor] = None
    # Scores of the "priorite" head; produced as (batch, num_labels_priorite).
    logits_priorite: Optional[torch.FloatTensor] = None
class MultiTaskModel(BartPreTrainedModel):
    """BART backbone with two linear classification heads (type + priorite).

    Both heads are fed from a mean-pooled sentence representation of the
    decoder's last hidden state.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = BartModel(config)
        hidden_size = config.d_model
        # One linear head per task, both reading the same pooled vector.
        self.classifier_type = nn.Linear(hidden_size, config.num_labels_type)
        self.classifier_priorite = nn.Linear(hidden_size, config.num_labels_priorite)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_type=None,
        labels_priorite=None,
    ):
        """Run the backbone, pool, classify, and optionally compute the loss.

        Args:
            input_ids: token ids, (batch, seq).
            attention_mask: 1 for real tokens, 0 for padding; optional.
            labels_type / labels_priorite: class indices per example; the
                combined loss is computed only when BOTH are provided.

        Returns:
            MultiTaskOutput with ``loss`` (or None) and the two logits tensors.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state
        if attention_mask is not None:
            # Fix: exclude padding positions from the pooling. A plain
            # .mean(dim=1) lets pad-token hidden states dilute the sentence
            # vector. (The mask describes the encoder input; the decoder
            # sequence here is a same-length shifted copy of it — TODO
            # confirm this alignment is acceptable for pooling.)
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled_output = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            # Backward-compatible path: unmasked mean, as before.
            pooled_output = hidden.mean(dim=1)
        logits_type = self.classifier_type(pooled_output)
        logits_priorite = self.classifier_priorite(pooled_output)
        loss = None
        if labels_type is not None and labels_priorite is not None:
            loss_type = self.loss_fct(logits_type, labels_type)
            loss_priorite = self.loss_fct(logits_priorite, labels_priorite)
            # Equal-weight sum of both task losses.
            loss = loss_type + loss_priorite
        return MultiTaskOutput(
            loss=loss,
            logits_type=logits_type,
            logits_priorite=logits_priorite,
        )