Zolisa commited on
Commit
d432f40
·
verified ·
1 Parent(s): 42afbda

Upload models/transformer_imdb.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/transformer_imdb.py +24 -0
models/transformer_imdb.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ class TransformerClassifier(nn.Module):
6
+ def __init__(self, vocab_size, embed_dim=128, num_heads=8, num_layers=2, num_classes=2, max_len=256):
7
+ super(TransformerClassifier, self).__init__()
8
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
9
+ self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
10
+
11
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
12
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
13
+
14
+ self.fc = nn.Linear(embed_dim, num_classes)
15
+
16
+ def forward(self, x):
17
+ # x: (batch, seq_len)
18
+ seq_len = x.size(1)
19
+ x = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
20
+ x = self.transformer(x)
21
+ # Global Average Pooling over the sequence
22
+ x = x.mean(dim=1)
23
+ x = self.fc(x)
24
+ return x