SuaveAI-Dectection-Multitask-Model-V1 / modeling_suave_multitask.py
DaJulster's picture
Upload folder using huggingface_hub
e90dc4c verified
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_suave_multitask import SuaveMultitaskConfig
@dataclass
class SuaveMultitaskOutput(ModelOutput):
    """Output container for the two-headed (binary + multiclass) classifier.

    Every field defaults to ``None``; the field order is part of the
    container's layout and must not be rearranged.
    """

    # Combined training loss; stays None when no loss was computed.
    loss: Optional[torch.FloatTensor] = None

    # Raw scores produced by the 2-way classification head.
    logits_binary: Optional[torch.FloatTensor] = None

    # Raw scores produced by the multiclass classification head.
    logits_multiclass: Optional[torch.FloatTensor] = None

    # Hidden states as reported by the underlying encoder, if any.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

    # Attention maps as reported by the underlying encoder, if any.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class SuaveMultitaskModel(PreTrainedModel):
    """Encoder with two classification heads fed by one pooled vector.

    The encoder architecture is resolved from ``config.base_model_name``.
    The hidden state at position 0 of the encoder's last layer is passed
    through dropout and then into:

    * ``classifier_binary``: a linear layer with 2 outputs, and
    * ``classifier_multiclass``: a linear layer with
      ``config.num_ai_classes`` outputs.
    """

    config_class = SuaveMultitaskConfig
    base_model_prefix = "encoder"

    def __init__(self, config: SuaveMultitaskConfig):
        """Build the encoder and both heads from ``config``.

        Args:
            config: Must provide ``base_model_name``, ``classifier_dropout``
                and ``num_ai_classes``.
        """
        super().__init__(config)

        # Only the base model's *configuration* is fetched here; the encoder
        # is then instantiated from that configuration (architecture only).
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_config(base_config)

        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier_binary = nn.Linear(hidden_size, 2)
        self.classifier_multiclass = nn.Linear(hidden_size, config.num_ai_classes)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_binary=None,
        labels_multiclass=None,
        **kwargs,
    ):
        """Run the encoder and both heads, optionally computing the loss.

        Args:
            input_ids: Forwarded to the encoder unchanged.
            attention_mask: Forwarded to the encoder unchanged.
            labels_binary: Targets for the 2-way head.
            labels_multiclass: Class-index targets for the multiclass head.
                Entries equal to ``-1`` are excluded from the loss.
            **kwargs: Only ``output_hidden_states`` and ``output_attentions``
                are read (both default to ``False``); any other key is
                accepted and ignored.

        Returns:
            SuaveMultitaskOutput: both sets of logits, the encoder's
            ``hidden_states`` / ``attentions``, and ``loss``. ``loss`` is
            ``None`` unless *both* label arguments are supplied; otherwise it
            is ``binary_loss + 0.5 * multiclass_loss``.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=kwargs.get("output_hidden_states", False),
            output_attentions=kwargs.get("output_attentions", False),
        )

        # Position 0 along the second axis is used as the pooled vector
        # (assumes a (batch, sequence, hidden) layout -- TODO confirm for
        # every supported base model).
        pooled = self.dropout(outputs.last_hidden_state[:, 0])

        logits_binary = self.classifier_binary(pooled)
        logits_multiclass = self.classifier_multiclass(pooled)

        loss = None
        if labels_binary is not None and labels_multiclass is not None:
            ignore_index = -1

            loss_binary = nn.CrossEntropyLoss()(logits_binary, labels_binary)

            # A mean-reduced cross entropy divides by the number of
            # non-ignored targets, so a batch whose multiclass labels are all
            # ``ignore_index`` would produce NaN and poison the total loss.
            # Summing and dividing by a count clamped to >= 1 gives the same
            # mean when any target is valid and exactly 0 (with zero
            # gradients) when none are.
            multiclass_sum = nn.CrossEntropyLoss(
                ignore_index=ignore_index, reduction="sum"
            )(logits_multiclass, labels_multiclass)
            num_valid = (labels_multiclass != ignore_index).sum().clamp(min=1)
            loss_multiclass = multiclass_sum / num_valid

            loss = loss_binary + 0.5 * loss_multiclass

        return SuaveMultitaskOutput(
            loss=loss,
            logits_binary=logits_binary,
            logits_multiclass=logits_multiclass,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )