# Emo-Axis02 / modeling.py
# Uploaded by Subi003 — "Initial upload of EmoAxis model" (commit f2c9a1a, verified)
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig
class EmoAxis(PreTrainedModel):
    """Text-embedding model: mean-pooled, L2-normalised encoder output.

    Wraps a Hugging Face ``AutoModel`` backbone and turns the token states of
    its last hidden layer into one unit-length vector per input sequence.
    """

    config_class = AutoConfig

    def __init__(self, config):
        """Build the model and load its backbone encoder.

        Args:
            config: A Hugging Face config. ``config._name_or_path`` is passed
                to ``AutoModel.from_pretrained`` to create ``self.encoder``.
        """
        super().__init__(config)
        # NOTE(review): pretrained weights are loaded inside __init__, so
        # constructing the model reads from disk / the hub. Switching to
        # AutoModel.from_config would avoid that, but confirm how saved
        # checkpoints of this model are restored before changing it.
        self.encoder = AutoModel.from_pretrained(config._name_or_path)

    def forward(self, input_ids=None, attention_mask=None):
        """Encode a batch into L2-normalised mean-pooled embeddings.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Per-token mask, forwarded to the encoder and used
                as pooling weights: positions where it is 0 contribute
                nothing to the mean. If ``None``, every position is averaged
                (previously this case raised ``AttributeError``).

        Returns:
            The masked mean of the last hidden layer, L2-normalised along
            dim 1 — presumably shape ``(batch, hidden_size)``; confirm for
            the backbone in use.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # NOTE(review): for most backbones hidden_states[-1] equals
        # outputs.last_hidden_state; requesting every hidden state costs
        # memory, but verify equivalence before switching.
        last_hidden_state = outputs.hidden_states[-1]

        if attention_mask is None:
            # No padding information: weight every position equally. float32
            # matches the dtype `.float()` produces in the masked branch.
            mask = torch.ones(
                last_hidden_state.shape[:-1],
                dtype=torch.float32,
                device=last_hidden_state.device,
            ).unsqueeze(-1)
        else:
            # Trailing singleton dim lets the mask broadcast over the hidden
            # features of each token.
            mask = attention_mask.unsqueeze(-1).float()

        # Sum unmasked token vectors and divide by the number of unmasked
        # tokens; the clamp avoids division by zero for an all-masked row.
        summed = (last_hidden_state * mask).sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        text_emb = F.normalize(summed / token_counts, p=2, dim=1)
        return text_emb