Spaces:
Sleeping
Sleeping
| # src/model.py | |
| import torch | |
| import torch.nn as nn | |
class TransformerModel(nn.Module):
    """Transformer-encoder language model.

    Token embeddings plus learned positional embeddings are fed through a
    stack of ``nn.TransformerEncoder`` layers and projected back to
    vocabulary logits.

    Args:
        vocab_size: size of the token vocabulary.
        embed_size: embedding / model dimension (``d_model``).
        num_heads: number of attention heads (must divide ``embed_size``).
        hidden_dim: feed-forward hidden dimension inside each encoder layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability, applied both to the summed embeddings
            and inside the encoder layers.
        max_seq_length: capacity of the learned positional-embedding table
            (previously hard-coded to 5000; the default preserves that).
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim,
                 num_layers, dropout=0.1, max_seq_length=5000):
        super().__init__()
        self.embed_size = embed_size
        self.max_seq_length = max_seq_length
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        # Learned (not sinusoidal) positional embeddings, one row per position.
        self.position_embedding = nn.Embedding(max_seq_length, embed_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        """Encode a batch of token ids into per-position vocabulary logits.

        Args:
            src: LongTensor of shape ``(batch, seq_len)`` with token ids.
            src_mask: optional ``(seq_len, seq_len)`` additive attention mask
                (e.g. from :meth:`generate_square_subsequent_mask`). ``None``
                means unrestricted (bidirectional) attention.

        Returns:
            FloatTensor of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_length`` (previously
                this surfaced as an opaque embedding index error).
        """
        batch_size, seq_length = src.size()
        if seq_length > self.max_seq_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_length "
                f"{self.max_seq_length}"
            )
        # expand() is a zero-copy view; the original repeat() materialized a
        # full (batch, seq_len) index tensor for no benefit.
        positions = torch.arange(seq_length, device=src.device)
        positions = positions.unsqueeze(0).expand(batch_size, seq_length)
        x = self.dropout(self.token_embedding(src) + self.position_embedding(positions))
        # nn.TransformerEncoder (batch_first=False default) expects
        # (seq_len, batch, embed_size).
        x = x.permute(1, 0, 2)
        encoded = self.transformer_encoder(x, src_mask)
        # Back to (batch, seq_len, embed_size) before the output projection.
        logits = self.fc_out(encoded.permute(1, 0, 2))
        return logits

    def generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` additive causal mask.

        Entries are ``0.0`` on and below the diagonal and ``-inf`` above it,
        so position ``i`` cannot attend to any position ``j > i``.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask