import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig


class EmoAxis(PreTrainedModel):
    """Text-embedding model: a pretrained encoder followed by masked mean pooling.

    The encoder is whatever ``AutoModel.from_pretrained`` resolves for
    ``config._name_or_path``. ``forward`` averages the encoder's final hidden
    states over the positions selected by the attention mask, then
    L2-normalises each pooled vector.
    """

    # NOTE(review): AutoConfig is used as the config class so any encoder
    # config can be loaded; confirm this works with the save/load paths used.
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): pretrained encoder weights are loaded here, at
        # construction time. When the model is restored via
        # ``EmoAxis.from_pretrained``, a checkpoint state dict is then loaded
        # on top of these, so the encoder is loaded twice. Confirm intended.
        self.encoder = AutoModel.from_pretrained(config._name_or_path)

    def forward(self, input_ids=None, attention_mask=None):
        """Encode a batch and return L2-normalised, mask-weighted mean embeddings.

        Args:
            input_ids: Token ids passed straight to the encoder; assumed to be
                shaped ``(batch, seq_len)`` -- confirm against callers.
            attention_mask: Per-position pooling weights with the same leading
                shape as the hidden states (expected 1 for tokens to pool,
                0 for padding). If ``None``, every position is pooled, and
                ``None`` is forwarded to the encoder unchanged.

        Returns:
            Tensor of pooled embeddings, L2-normalised along ``dim=1``.
        """
        # NOTE(review): all layers' hidden states are materialised only to
        # read the last one; ``outputs.last_hidden_state`` is usually the same
        # tensor for encoder models -- verify for this encoder before switching.
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Final layer's token representations, presumably (batch, seq_len, hidden).
        last_hidden_state = outputs.hidden_states[-1]

        if attention_mask is None:
            # The signature allows attention_mask=None; pool over every position
            # instead of failing on ``None.unsqueeze``.
            attention_mask = last_hidden_state.new_ones(last_hidden_state.shape[:2])

        # Trailing singleton axis lets the mask broadcast over the hidden dim.
        mask = attention_mask.unsqueeze(-1).float()
        summed = (last_hidden_state * mask).sum(dim=1)
        # clamp guards against division by zero for rows whose mask is all zeros.
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        text_emb = summed / token_counts
        return F.normalize(text_emb, p=2, dim=1)