| | from transformers import PretrainedConfig, AutoConfig
|
| |
|
class MultiTaskConfig(PretrainedConfig):
    """Configuration for the BARThez-based multi-task classifier.

    Stores the base encoder checkpoint name and the label counts for the
    two classification heads ("type" and "priorite").

    Args:
        base_model_name: Hub id or local path of the base BARThez checkpoint.
        num_labels_type: Number of classes for the "type" head.
        num_labels_priorite: Number of classes for the "priorite" head.
    """

    model_type = "barthez-multitask"

    def __init__(
        self,
        base_model_name="moussaKam/barthez",
        num_labels_type=3,
        num_labels_priorite=3,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.base_model_name = base_model_name
        self.num_labels_type = num_labels_type
        self.num_labels_priorite = num_labels_priorite

    @property
    def base_config(self):
        """Configuration of the base checkpoint, loaded lazily on first access.

        The original implementation called ``AutoConfig.from_pretrained`` in
        ``__init__``, which hits the Hub/disk on every instantiation
        (including when this config is deserialized) and stored a
        non-JSON-serializable object that breaks ``save_pretrained``.
        Loading lazily keeps ``config.base_config`` available to callers
        without those side effects.

        NOTE(review): once accessed, the cached value lives in ``__dict__``
        and may surface in ``to_dict`` — confirm against the serialization
        paths actually used by this project.
        """
        if getattr(self, "_base_config", None) is None:
            self._base_config = AutoConfig.from_pretrained(self.base_model_name)
        return self._base_config
|
| | import torch
|
| | import torch.nn as nn
|
| | from transformers import BartPreTrainedModel, BartModel
|
| | from dataclasses import dataclass
|
| | from transformers.modeling_outputs import ModelOutput
|
| | from typing import Optional
|
| |
|
| |
|
@dataclass
class MultiTaskOutput(ModelOutput):
    """Output container for :class:`MultiTaskModel`.

    Attributes:
        loss: Sum of the two cross-entropy losses; ``None`` when either
            label tensor was not supplied to ``forward``.
        logits_type: Raw scores of the "type" head,
            shape ``(batch, num_labels_type)``.
        logits_priorite: Raw scores of the "priorite" head,
            shape ``(batch, num_labels_priorite)``.
    """

    # All fields default to None (ModelOutput convention), so the logits
    # annotations must be Optional as well — the original annotated them
    # as plain FloatTensor while defaulting to None.
    loss: Optional[torch.FloatTensor] = None
    logits_type: Optional[torch.FloatTensor] = None
    logits_priorite: Optional[torch.FloatTensor] = None
|
| |
|
| |
|
class MultiTaskModel(BartPreTrainedModel):
    """BART backbone shared by two classification heads.

    Mean-pools the backbone's last hidden state into one vector per example
    and feeds it to two independent linear heads ("type" and "priorite").
    When both label tensors are provided, the loss is the sum of the two
    cross-entropy losses.
    """

    def __init__(self, config):
        super().__init__(config)

        self.model = BartModel(config)
        hidden_size = config.d_model

        # One independent linear head per task on top of the shared backbone.
        self.classifier_type = nn.Linear(hidden_size, config.num_labels_type)
        self.classifier_priorite = nn.Linear(hidden_size, config.num_labels_priorite)

        self.loss_fct = nn.CrossEntropyLoss()

        # Standard HF weight init / final setup.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_type=None,
        labels_priorite=None,
    ):
        """Run the backbone and both heads.

        Args:
            input_ids: Token ids, shape ``(batch, seq_len)``.
            attention_mask: 1 for real tokens, 0 for padding.
            labels_type: Class ids for the "type" head, shape ``(batch,)``.
            labels_priorite: Class ids for the "priorite" head, shape ``(batch,)``.

        Returns:
            MultiTaskOutput with the two logit tensors and, when both label
            tensors are given, the summed cross-entropy loss.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        hidden = outputs.last_hidden_state  # (batch, seq_len, d_model)

        # Mask-aware mean pooling. The original averaged over every position,
        # so padding tokens diluted the sentence representation for batches
        # with uneven lengths. Identical to the original for unpadded input.
        # NOTE(review): `hidden` is the decoder output while `attention_mask`
        # is the encoder mask; lengths match here because decoder inputs
        # default to the shifted `input_ids` — confirm alignment is intended.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled_output = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            pooled_output = hidden.mean(dim=1)

        logits_type = self.classifier_type(pooled_output)
        logits_priorite = self.classifier_priorite(pooled_output)

        # Joint loss only when both tasks are supervised.
        loss = None
        if labels_type is not None and labels_priorite is not None:
            loss_type = self.loss_fct(logits_type, labels_type)
            loss_priorite = self.loss_fct(logits_priorite, labels_priorite)
            loss = loss_type + loss_priorite

        return MultiTaskOutput(
            loss=loss,
            logits_type=logits_type,
            logits_priorite=logits_priorite,
        )