# multi-modal-pytorch-app/models/transformer_imdb.py
# Uploaded by Zolisa via huggingface_hub (commit 52b1c1a).
import torch
import torch.nn as nn
import math
class TransformerClassifier(nn.Module):
    """Transformer-encoder sequence classifier with a learned positional table.

    Token ids are embedded, summed with a learned positional embedding, passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks, averaged over the
    sequence dimension and projected to ``num_classes`` logits.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Model width (``d_model``); PyTorch's multi-head attention
            requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Size of the output logit vector.
        max_len: Longest sequence the positional table can cover.
    """

    def __init__(self, vocab_size, embed_dim=128, num_heads=8, num_layers=2,
                 num_classes=2, max_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned (not sinusoidal) positional table, zero-initialised and
        # trained jointly with the rest of the model.
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        # batch_first=True: the encoder consumes (batch, seq_len, embed_dim).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x, padding_mask=None):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= max_len``.
            padding_mask: Optional bool tensor of shape ``(batch, seq_len)``
                in which ``True`` marks padding positions (the PyTorch
                ``src_key_padding_mask`` convention). When given, padded
                positions are ignored as attention keys and excluded from the
                average pooling. When ``None`` (default) every position is
                treated as a real token.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(1)
        # Fail early with a readable message; otherwise the positional slice
        # below is shorter than the input and the addition fails to broadcast.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len}; "
                "truncate the input or build the model with a larger max_len"
            )
        hidden = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
        hidden = self.transformer(hidden, src_key_padding_mask=padding_mask)
        if padding_mask is None:
            # Global average pooling over the sequence dimension.
            pooled = hidden.mean(dim=1)
        else:
            # Masked mean: only real tokens contribute. clamp avoids 0/0 for a
            # row that is entirely padding.
            keep = (~padding_mask).unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return self.fc(pooled)