File size: 996 Bytes
c5bcbe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn
from transformers import AutoModel, DebertaV2Tokenizer, AutoConfig

class SentiNetTransformer(nn.Module):
    """Sentiment scorer: a Hugging Face Transformer encoder plus a small MLP head.

    The encoder architecture is read from the config found at ``model_path``.
    By default the encoder is built with ``AutoModel.from_config``, which does
    NOT load pretrained weights -- the backbone starts randomly initialised.
    That is only appropriate when a complete fine-tuned ``state_dict`` is loaded
    into this module afterwards. Pass ``pretrained=True`` to load the
    checkpoint's encoder weights with ``AutoModel.from_pretrained`` instead
    (e.g. when starting fine-tuning).

    Args:
        model_path: Model id or local directory holding the encoder config
            (and weights, when ``pretrained=True``).
        fc_dropout: Dropout probability applied after the hidden MLP layer.
        pretrained: If True, load encoder weights from ``model_path``;
            otherwise build the encoder from its config only (the original
            behaviour).
    """

    def __init__(
        self,
        model_path: str,
        fc_dropout: float = 0.1,
        *,
        pretrained: bool = False,
    ):
        super().__init__()
        config = AutoConfig.from_pretrained(model_path)
        if pretrained:
            self.transformer = AutoModel.from_pretrained(model_path, config=config)
        else:
            # Architecture only, random weights. Presumably a fine-tuned
            # state_dict is loaded into the whole module later -- confirm with caller.
            self.transformer = AutoModel.from_config(config)
        hidden_dim = self.transformer.config.hidden_size

        # Hidden layer H -> H with ReLU and dropout, then a single-unit output layer.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(fc_dropout),
        )
        self.output = nn.Linear(hidden_dim, 1)

    def forward(self, encodings: dict) -> torch.Tensor:
        """Return one unnormalised score per input sequence.

        Args:
            encodings: Keyword arguments forwarded verbatim to the encoder
                (presumably tokenizer output such as ``input_ids`` and
                ``attention_mask`` -- confirm with caller).

        Returns:
            Tensor of shape ``(N, 1)``. No activation (e.g. sigmoid) is applied.
        """
        transformer_outputs = self.transformer(**encodings)  # last_hidden_state: (N, L, H)
        # Use the first token's hidden state (the CLS position) as the sequence summary.
        cls_embedding = transformer_outputs.last_hidden_state[:, 0, :]  # (N, H)
        x = self.fc(cls_embedding)  # (N, H)
        return self.output(x)  # (N, 1)