import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional

from transformers import AutoConfig, BartModel, BartPreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import ModelOutput


class MultiTaskConfig(PretrainedConfig):
    """Configuration for a BARThez-based multi-task classification model.

    Holds the base checkpoint name plus the label counts for the two
    classification heads ("type" and "priorite").
    """

    model_type = "barthez-multitask"

    def __init__(
        self,
        base_model_name="moussaKam/barthez",
        num_labels_type=3,
        num_labels_priorite=3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_labels_type = num_labels_type
        self.num_labels_priorite = num_labels_priorite
        # Load the base model's config so its hyperparameters are available.
        # NOTE(review): this hits the Hugging Face hub (network/disk I/O) on
        # every construction, including from_pretrained round-trips — consider
        # loading lazily or caching if instantiation cost matters.
        self.base_config = AutoConfig.from_pretrained(base_model_name)


@dataclass
class MultiTaskOutput(ModelOutput):
    """Output container for the two-head model.

    Attributes:
        loss: Summed cross-entropy loss over both heads, or ``None`` when
            labels were not provided.
        logits_type: Raw scores for the "type" head, shape (batch, num_labels_type).
        logits_priorite: Raw scores for the "priorite" head,
            shape (batch, num_labels_priorite).
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional[...] because both default to None until forward() fills them in.
    logits_type: Optional[torch.FloatTensor] = None
    logits_priorite: Optional[torch.FloatTensor] = None


class MultiTaskModel(BartPreTrainedModel):
    """BART backbone with two linear classification heads.

    Mean-pools the decoder's last hidden state and feeds it to two
    independent linear classifiers. When both label sets are given, the
    loss is the unweighted sum of the two cross-entropy losses.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = BartModel(config)
        hidden_size = config.d_model
        self.classifier_type = nn.Linear(hidden_size, config.num_labels_type)
        self.classifier_priorite = nn.Linear(hidden_size, config.num_labels_priorite)
        self.loss_fct = nn.CrossEntropyLoss()
        # Standard HF hook: initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_type=None,
        labels_priorite=None,
    ):
        """Run the backbone and both heads.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Padding mask, shape (batch, seq_len).
            labels_type: Class indices for the "type" head (optional).
            labels_priorite: Class indices for the "priorite" head (optional).

        Returns:
            MultiTaskOutput with logits for both heads; ``loss`` is set only
            when BOTH label tensors are provided.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Mean-pool over the sequence dimension of the decoder's last
        # hidden state. NOTE(review): the pool averages over padding
        # positions too — masking with attention_mask before the mean may
        # be intended; confirm against training results.
        pooled_output = outputs.last_hidden_state.mean(dim=1)

        logits_type = self.classifier_type(pooled_output)
        logits_priorite = self.classifier_priorite(pooled_output)

        loss = None
        if labels_type is not None and labels_priorite is not None:
            loss_type = self.loss_fct(logits_type, labels_type)
            loss_priorite = self.loss_fct(logits_priorite, labels_priorite)
            loss = loss_type + loss_priorite

        return MultiTaskOutput(
            loss=loss,
            logits_type=logits_type,
            logits_priorite=logits_priorite,
        )