import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaPreTrainedModel, RobertaModel

from .configuration_emoaxis import EmoAxisConfig


class EmoAxis(RobertaPreTrainedModel):
    config_class = EmoAxisConfig

    def __init__(self, config):
        super().__init__(config)
        # RoBERTa encoder without the default pooling head; pooling is done manually in forward().
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Classification head mapping the pooled sentence embedding to class logits.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.25),
            nn.Linear(512, config.num_classes),
        )
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_hidden_state = outputs.hidden_states[-1]
        # Mean-pool token embeddings over non-padding positions.
        mask = attention_mask.unsqueeze(-1).float()
        text_emb = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # L2-normalize so sentence embeddings lie on the unit hypersphere.
        text_emb = F.normalize(text_emb, p=2, dim=1)
        logits = self.mlp(text_emb)
        return text_emb, logits
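

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of instantiating EmoAxis and running a forward pass on dummy
# inputs. It assumes EmoAxisConfig accepts a `num_classes` argument (matching the
# head above) and that a roberta-base-compatible tokenizer is appropriate; the
# config values and checkpoint name are assumptions, adjust them to your setup.
if __name__ == "__main__":
    from transformers import RobertaTokenizerFast

    # Hypothetical config: randomly initialized weights, 7 output classes assumed.
    config = EmoAxisConfig(num_classes=7)
    model = EmoAxis(config)
    model.eval()

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    batch = tokenizer(
        ["I can't believe this happened!", "What a calm, quiet evening."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        text_emb, logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    # text_emb: (batch, hidden_size) unit-norm sentence embeddings
    # logits:   (batch, num_classes) unnormalized class scores
    probs = logits.softmax(dim=-1)
    print(text_emb.shape, probs.shape)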