File size: 820 Bytes
f2c9a1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig

class EmoAxis(PreTrainedModel):
    """Sentence-embedding wrapper around a pretrained encoder.

    ``forward`` runs the wrapped encoder, mean-pools the final hidden states
    over the sequence dimension using the attention mask as weights, and
    returns the L2-normalized result.
    """

    # NOTE(review): ``AutoConfig`` is a factory rather than a concrete config
    # class; kept as-is because callers may rely on it — confirm it is intended.
    config_class = AutoConfig

    def __init__(self, config):
        """Build the model and load the wrapped encoder.

        Args:
            config: Model configuration. Its ``_name_or_path`` attribute is
                passed to ``AutoModel.from_pretrained`` to load the encoder.
        """
        super().__init__(config)
        # NOTE(review): this loads the encoder from ``config._name_or_path``
        # at construction time, so that value must resolve to a loadable
        # checkpoint — verify against how the config is created.
        self.encoder = AutoModel.from_pretrained(config._name_or_path)

    def forward(self, input_ids=None, attention_mask=None):
        """Return one L2-normalized, mean-pooled embedding per input sequence.

        Args:
            input_ids: Token ids forwarded unchanged to the encoder.
            attention_mask: Optional mask forwarded to the encoder and used as
                pooling weights (nonzero = position contributes to the mean).
                When ``None``, every position is pooled with equal weight.
                Presumably shaped ``(batch, seq_len)`` — TODO confirm.

        Returns:
            Tensor of pooled embeddings, normalized to unit L2 norm along
            dimension 1.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_hidden_state = outputs.hidden_states[-1]

        if attention_mask is None:
            # The signature advertises ``attention_mask=None``; without this
            # fallback the ``unsqueeze`` below would fail on ``None``.
            # Assumes hidden states are (batch, seq_len, hidden) — the pooling
            # over dim=1 below relies on the same layout.
            attention_mask = torch.ones(
                last_hidden_state.shape[:2],
                dtype=torch.long,
                device=last_hidden_state.device,
            )

        # Add a trailing axis so the mask broadcasts across the hidden axis.
        mask = attention_mask.unsqueeze(-1).float()
        summed = (last_hidden_state * mask).sum(dim=1)
        # Clamp the denominator so a fully masked row does not divide by zero.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        text_emb = summed / counts
        text_emb = F.normalize(text_emb, p=2, dim=1)
        return text_emb