import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaPreTrainedModel, RobertaModel
from .configuration_emoaxis import EmoAxisConfig


class EmoAxis(RobertaPreTrainedModel):
    """RoBERTa encoder with masked mean pooling, L2-normalized text
    embeddings, and an MLP classification head."""

    config_class = EmoAxisConfig

    def __init__(self, config):
        super().__init__(config)
        # Encoder only: pooling is done manually in forward(), so the
        # built-in pooler layer is not needed.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Classification head: hidden_size -> 512 -> num_classes.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.25),
            nn.Linear(512, config.num_classes),
        )
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Attend over every token when no mask is provided, so the mean
        # pooling below does not fail on a None mask.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        # Masked mean pooling over the sequence dimension; the clamp guards
        # against division by zero for all-padding rows.
        mask = attention_mask.unsqueeze(-1).float()
        text_emb = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # L2-normalize so embeddings lie on the unit hypersphere.
        text_emb = F.normalize(text_emb, p=2, dim=1)
        logits = self.mlp(text_emb)
        return text_emb, logits
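

# ---------------------------------------------------------------------------
# Minimal usage sketch (a hedged illustration, not part of the original
# model code). The checkpoint id "your-org/emoaxis" is a placeholder, and we
# assume EmoAxisConfig defines `num_classes`; adjust both to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "your-org/emoaxis"  # hypothetical checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = EmoAxis.from_pretrained(checkpoint)
    model.eval()

    batch = tokenizer(
        ["I can't believe this happened!", "What a calm, quiet morning."],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        text_emb, logits = model(**batch)

    # text_emb: (batch, hidden_size) unit-norm sentence embeddings.
    # logits:   (batch, num_classes) unnormalized class scores.
    print(text_emb.shape, logits.softmax(dim=-1))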