import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig


class EmoAxis(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        # Load the backbone encoder named in the config (e.g. a BERT-style model).
        self.encoder = AutoModel.from_pretrained(config._name_or_path)

    def forward(self, input_ids=None, attention_mask=None):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Final-layer token representations: (batch, seq_len, hidden_size).
        last_hidden_state = outputs.hidden_states[-1]
        # Mean-pool over real tokens only: broadcast the attention mask over the
        # hidden dimension, sum the masked states, and divide by the token count
        # (clamped to avoid division by zero on fully padded rows).
        mask = attention_mask.unsqueeze(-1).float()
        text_emb = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # L2-normalize so each embedding lies on the unit sphere (cosine-ready).
        text_emb = F.normalize(text_emb, p=2, dim=1)
        return text_emb
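A minimal usage sketch of the class above. The checkpoint name `sentence-transformers/all-MiniLM-L6-v2` is an assumption for illustration; any encoder-style Hub model with a matching tokenizer should work the same way:

```python
from transformers import AutoConfig, AutoTokenizer

# Hypothetical checkpoint choice, not prescribed by the class itself.
name = "sentence-transformers/all-MiniLM-L6-v2"
config = AutoConfig.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
model = EmoAxis(config).eval()

enc = tokenizer(["I feel great today"], return_tensors="pt", padding=True)
with torch.no_grad():
    emb = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(emb.shape)  # (1, hidden_size); rows have unit L2 norm
```

Note that `forward` accepts only `input_ids` and `attention_mask`, so the tokenizer outputs are passed explicitly rather than splatted with `model(**enc)`, which would fail for tokenizers that also emit `token_type_ids`.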