File size: 996 Bytes
c5bcbe7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | import torch
import torch.nn as nn
from transformers import AutoModel, DebertaV2Tokenizer, AutoConfig
class SentiNetTransformer(nn.Module):
    """Transformer encoder topped with a feed-forward head that emits one score per input.

    The first-token (CLS) hidden state of the backbone is passed through a
    ``Linear -> ReLU -> Dropout`` block and then projected to a single value.

    NOTE(review): the backbone is instantiated with ``AutoModel.from_config``,
    which builds the architecture from the configuration only. Presumably the
    trained weights are restored by the caller afterwards (e.g. through
    ``load_state_dict``) -- confirm against the loading code.

    Args:
        model_path: Identifier or path handed to ``AutoConfig.from_pretrained``.
        fc_dropout: Dropout probability used inside the feed-forward head.
    """

    def __init__(self, model_path: str, fc_dropout: float = 0.1):
        super().__init__()

        # Backbone: resolve the configuration first, then build the model from it.
        backbone_config = AutoConfig.from_pretrained(model_path)
        self.transformer = AutoModel.from_config(backbone_config)

        # The head keeps the backbone's hidden width until the final projection.
        width = self.transformer.config.hidden_size
        head_layers = [
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(fc_dropout),
        ]
        self.fc = nn.Sequential(*head_layers)
        self.output = nn.Linear(width, 1)

    def forward(self, encodings: dict):
        """Score a batch of encoded inputs.

        Args:
            encodings: Keyword arguments forwarded unchanged to the backbone.

        Returns:
            The output of the final linear layer, one value per batch row: (N, 1).
        """
        # Per-token hidden states from the backbone: (N, L, H).
        hidden_states = self.transformer(**encodings).last_hidden_state
        # Keep only the first position (CLS token) of every sequence: (N, H).
        first_token = hidden_states[:, 0]
        return self.output(self.fc(first_token))