# model-ticket-ti / modeling_multitask.py
# (Hub upload by patrickott1 — commit d93ed3a, "Add modeling_multitask")
from transformers import PretrainedConfig, AutoConfig
class MultiTaskConfig(PretrainedConfig):
    """Configuration for a BARThez-based multi-task classifier with two
    heads: ``type`` and ``priorite`` (priority)."""

    model_type = "barthez-multitask"

    def __init__(
        self,
        base_model_name="moussaKam/barthez",
        num_labels_type=3,
        num_labels_priorite=3,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Hub id of the backbone checkpoint this multitask model wraps.
        self.base_model_name = base_model_name
        # Number of classes for the "type" classification head.
        self.num_labels_type = num_labels_type
        # Number of classes for the "priorite" classification head.
        self.num_labels_priorite = num_labels_priorite
        # Load the base model's config.
        # NOTE(review): this hits the Hub/local cache on EVERY instantiation
        # of the config (including deserialization via from_pretrained), so
        # it will fail in fully offline environments and always overrides any
        # serialized base_config — confirm this is intended.
        self.base_config = AutoConfig.from_pretrained(base_model_name)
import torch
import torch.nn as nn
from transformers import BartPreTrainedModel, BartModel
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
from typing import Optional
@dataclass
class MultiTaskOutput(ModelOutput):
    """Output container for the two classification heads of MultiTaskModel."""

    # Combined loss (type loss + priorite loss); None when labels were not given.
    loss: Optional[torch.FloatTensor] = None
    # Raw scores of the "type" head — produced by a Linear(d_model, num_labels_type),
    # so shape is (batch, num_labels_type).
    logits_type: Optional[torch.FloatTensor] = None
    # Raw scores of the "priorite" head, shape (batch, num_labels_priorite).
    logits_priorite: Optional[torch.FloatTensor] = None
class MultiTaskModel(BartPreTrainedModel):
    """BART backbone with two linear classification heads ("type" and
    "priorite") applied to a mean-pooled sequence representation.

    The interface of ``__init__`` and ``forward`` is unchanged from the
    original; only the pooling now ignores padding positions.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = BartModel(config)
        hidden_size = config.d_model
        # One independent linear head per task, both fed the same pooled vector.
        self.classifier_type = nn.Linear(hidden_size, config.num_labels_type)
        self.classifier_priorite = nn.Linear(hidden_size, config.num_labels_priorite)
        self.loss_fct = nn.CrossEntropyLoss()
        # Standard transformers weight init / tying hook.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_type=None,
        labels_priorite=None,
    ):
        """Run the backbone and both heads.

        Args:
            input_ids: token ids, shape (batch, seq).
            attention_mask: 1 for real tokens, 0 for padding; optional.
            labels_type / labels_priorite: class indices for each task.
                The loss is computed only when BOTH are provided.

        Returns:
            MultiTaskOutput with ``loss`` (or None) and the two logit tensors.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state  # (batch, seq, d_model)
        if attention_mask is not None:
            # FIX: the original used hidden.mean(dim=1), which averages over
            # padding positions too and dilutes the representation for short
            # sequences in a padded batch. Use a masked mean instead.
            # NOTE(review): last_hidden_state comes from BART's decoder, whose
            # inputs are the encoder inputs shifted right, so the encoder
            # attention_mask aligns only approximately (off by one token) —
            # confirm pooling the decoder states is the intended design.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled_output = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            # No mask given: fall back to the original plain mean.
            pooled_output = hidden.mean(dim=1)
        logits_type = self.classifier_type(pooled_output)
        logits_priorite = self.classifier_priorite(pooled_output)

        loss = None
        if labels_type is not None and labels_priorite is not None:
            # Simple unweighted sum of the two cross-entropy losses.
            loss_type = self.loss_fct(logits_type, labels_type)
            loss_priorite = self.loss_fct(logits_priorite, labels_priorite)
            loss = loss_type + loss_priorite

        return MultiTaskOutput(
            loss=loss,
            logits_type=logits_type,
            logits_priorite=logits_priorite,
        )