import torch
import torch.nn as nn


class TransformerClassifier(nn.Module):
    """Transformer encoder for sequence classification with learned positional embeddings."""

    def __init__(self, vocab_size, embed_dim=128, num_heads=8, num_layers=2, num_classes=2, max_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional encoding: one trainable vector per position up to max_len
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x: (batch, seq_len) of token ids
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
        x = self.transformer(x)
        # Global average pooling over the sequence dimension: (batch, embed_dim)
        x = x.mean(dim=1)
        # Project the pooled representation to class logits: (batch, num_classes)
        x = self.fc(x)
        return x
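

# Minimal usage sketch (not part of the original listing): instantiate the
# classifier and push a batch of random token ids through it to check the
# output shape. The vocabulary size, batch size, and sequence length below
# are assumed placeholders, not values from the source.
if __name__ == "__main__":
    model = TransformerClassifier(vocab_size=10000)
    tokens = torch.randint(0, 10000, (4, 64))  # 4 sequences, 64 tokens each
    logits = model(tokens)
    print(logits.shape)  # torch.Size([4, 2]), i.e. (batch, num_classes)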