| | from dataclasses import dataclass
|
| | from typing import Optional, Tuple
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from transformers import AutoConfig, AutoModel, PreTrainedModel
|
| | from transformers.modeling_outputs import ModelOutput
|
| |
|
| | from .configuration_suave_multitask import SuaveMultitaskConfig
|
| |
|
| |
|
@dataclass
class SuaveMultitaskOutput(ModelOutput):
    """Output container for the multitask model, following the Hugging Face
    ``ModelOutput`` convention: every field is optional and defaults to
    ``None`` when not computed or not requested.
    """

    # Combined training loss; only populated when labels were passed to
    # forward() (see SuaveMultitaskModel.forward).
    loss: Optional[torch.FloatTensor] = None
    # Raw scores from the 2-way head (nn.Linear(hidden_size, 2)).
    logits_binary: Optional[torch.FloatTensor] = None
    # Raw scores from the num_ai_classes-way head.
    logits_multiclass: Optional[torch.FloatTensor] = None
    # Encoder hidden states; present when output_hidden_states=True is requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Encoder attention maps; present when output_attentions=True is requested.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| |
|
| |
|
class SuaveMultitaskModel(PreTrainedModel):
    """Transformer encoder with two classification heads sharing one backbone.

    A binary head (2 classes) and a multi-class head (``config.num_ai_classes``
    classes) both read the dropout-regularized first-token ([CLS]) embedding.
    """

    config_class = SuaveMultitaskConfig
    base_model_prefix = "encoder"

    def __init__(self, config: SuaveMultitaskConfig):
        """Build the encoder from its architecture config plus the two heads.

        Note: ``AutoModel.from_config`` creates the backbone with random
        weights; pretrained weights are expected to arrive through
        ``from_pretrained`` on this wrapper (standard HF pattern).
        """
        super().__init__(config)
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_config(base_config)
        hidden_size = self.encoder.config.hidden_size

        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier_binary = nn.Linear(hidden_size, 2)
        self.classifier_multiclass = nn.Linear(hidden_size, config.num_ai_classes)

        # HF hook: weight init + any final wiring.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_binary=None,
        labels_multiclass=None,
        **kwargs,
    ):
        """Run the encoder and both heads; optionally compute a joint loss.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Padding mask for the encoder.
            labels_binary: Targets for the 2-way head (class indices).
            labels_multiclass: Targets for the multi-class head; entries equal
                to -1 are ignored in the loss.
            **kwargs: ``output_hidden_states`` / ``output_attentions`` are
                forwarded to the encoder (default False).

        Returns:
            SuaveMultitaskOutput with logits for both heads and, when at least
            one label set was supplied, the (weighted) combined loss.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=kwargs.get("output_hidden_states", False),
            output_attentions=kwargs.get("output_attentions", False),
        )

        # First-token ([CLS]) pooling, then dropout before both heads.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])

        logits_binary = self.classifier_binary(pooled)
        logits_multiclass = self.classifier_multiclass(pooled)

        # Fix: the original required BOTH label sets to produce a loss, so
        # single-task batches silently trained with loss=None. Each head's
        # loss is now computed independently; behavior when both label sets
        # are present is unchanged (loss_binary + weight * loss_multiclass).
        loss = None
        if labels_binary is not None:
            loss = nn.CrossEntropyLoss()(logits_binary, labels_binary)
        if labels_multiclass is not None:
            # Generalized: weight is configurable via the optional config
            # attribute ``multiclass_loss_weight``; defaults to the original
            # hard-coded 0.5. ignore_index=-1 skips unlabeled examples.
            weight = getattr(self.config, "multiclass_loss_weight", 0.5)
            loss_multiclass = nn.CrossEntropyLoss(ignore_index=-1)(
                logits_multiclass, labels_multiclass
            )
            loss = weight * loss_multiclass if loss is None else loss + weight * loss_multiclass

        return SuaveMultitaskOutput(
            loss=loss,
            logits_binary=logits_binary,
            logits_multiclass=logits_multiclass,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| |
|