"""Multitask text-classification model: a shared transformer encoder with a
binary head and a multiclass head trained jointly."""

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_suave_multitask import SuaveMultitaskConfig


@dataclass
class SuaveMultitaskOutput(ModelOutput):
    # loss: weighted sum of the per-task cross-entropy losses; None when no
    # labels were supplied to forward().
    loss: Optional[torch.FloatTensor] = None
    # logits_binary: (batch, 2) scores from the binary head.
    logits_binary: Optional[torch.FloatTensor] = None
    # logits_multiclass: (batch, num_ai_classes) scores from the multiclass head.
    logits_multiclass: Optional[torch.FloatTensor] = None
    # hidden_states / attentions: passed through from the encoder when the
    # corresponding output_* flags are set; otherwise None.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class SuaveMultitaskModel(PreTrainedModel):
    """Shared encoder + two classification heads.

    The encoder architecture is resolved from ``config.base_model_name`` and
    built with randomly initialized weights (``AutoModel.from_config``);
    pretrained weights are expected to be loaded via ``from_pretrained`` on
    this wrapper. Both heads read the first-token ("CLS-style") pooled
    representation — assumes a BERT-like encoder whose first token summarizes
    the sequence; TODO confirm for non-BERT base models.
    """

    config_class = SuaveMultitaskConfig
    base_model_prefix = "encoder"

    def __init__(self, config: SuaveMultitaskConfig):
        super().__init__(config)
        # NOTE(review): from_pretrained here fetches only the *config* for the
        # base model (architecture hyperparameters), not its weights.
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_config(base_config)
        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier_binary = nn.Linear(hidden_size, 2)
        self.classifier_multiclass = nn.Linear(hidden_size, config.num_ai_classes)
        # Loss modules are stateless; build them once rather than per forward.
        self._loss_binary = nn.CrossEntropyLoss()
        # ignore_index=-1 lets examples without a multiclass label be masked out.
        self._loss_multiclass = nn.CrossEntropyLoss(ignore_index=-1)
        # Weight on the multiclass loss term. Configurable via the config;
        # defaults to the previously hard-coded 0.5 for backward compatibility.
        self.multiclass_loss_weight = getattr(config, "multiclass_loss_weight", 0.5)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_binary=None,
        labels_multiclass=None,
        **kwargs,
    ):
        """Run the encoder and both heads.

        Args:
            input_ids: token ids, (batch, seq_len).
            attention_mask: padding mask, (batch, seq_len).
            labels_binary: optional (batch,) class ids in {0, 1}.
            labels_multiclass: optional (batch,) class ids; -1 entries are
                ignored by the multiclass loss.
            **kwargs: only ``output_hidden_states`` and ``output_attentions``
                are honored; other keys are ignored.

        Returns:
            SuaveMultitaskOutput. ``loss`` is the sum of the available task
            losses (multiclass term scaled by ``multiclass_loss_weight``);
            it is None when neither label tensor is provided.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=kwargs.get("output_hidden_states", False),
            output_attentions=kwargs.get("output_attentions", False),
        )
        # First-token pooling (see class docstring for the assumption).
        pooled = outputs.last_hidden_state[:, 0]
        pooled = self.dropout(pooled)

        logits_binary = self.classifier_binary(pooled)
        logits_multiclass = self.classifier_multiclass(pooled)

        # Each task contributes its loss independently, so single-task
        # fine-tuning (only one label tensor supplied) still produces a loss.
        # When both are supplied the result equals the original
        # binary + 0.5 * multiclass combination (at the default weight).
        loss = None
        if labels_binary is not None:
            loss = self._loss_binary(logits_binary, labels_binary)
        if labels_multiclass is not None:
            mc_loss = self.multiclass_loss_weight * self._loss_multiclass(
                logits_multiclass, labels_multiclass
            )
            loss = mc_loss if loss is None else loss + mc_loss

        return SuaveMultitaskOutput(
            loss=loss,
            logits_binary=logits_binary,
            logits_multiclass=logits_multiclass,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )