import torch
import torch.nn as nn


class PositionalEmbedding(nn.Module):
    """
    Learned positional embedding, added to the token embeddings.

    Shapes:
        N: batch size
        L: seq len (max seq len of batch)
        E: embedding dim
        max_seq_len: max seq len across all samples

    Forward args:
        X: batch of semantic embeddings (N, L, E)
    """

    def __init__(self, emb_dim, max_seq_len, dropout_p=0.1):
        super().__init__()
        # full embedding matrix with shape (max_seq_len, emb_dim)
        self.pos_embedding = nn.Parameter(torch.randn(max_seq_len, emb_dim) * 0.01)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, X):
        # slice to the current batch's max sequence length
        emb_matrix = self.pos_embedding[:X.size(1)].unsqueeze(0)  # (1, L, E)
        return self.dropout(X + emb_matrix)  # (N, L, E)
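

# A minimal sanity check for PositionalEmbedding (a sketch: the helper name
# and the sizes below are illustrative assumptions, not part of the original).
def _check_positional_embedding():
    pos_emb = PositionalEmbedding(emb_dim=512, max_seq_len=128)
    x = torch.randn(2, 10, 512)  # (N=2, L=10, E=512)
    out = pos_emb(x)             # the same positions 0..9 are added to every sample
    assert out.shape == (2, 10, 512)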


class TransformerNMT(nn.Module):
    """
    Encoder-decoder Transformer for NMT with a shared source/target
    embedding and weight tying between the input embedding and the
    output projection.

    Forward args:
        src_ids: (N, S) token ids
        tgt_ids: (N, L) token ids
        src_key_padding_mask: (N, S) bool, True = PAD (ignored)
        tgt_key_padding_mask: (N, L) bool, True = PAD (ignored)
    """

    def __init__(self, vocab_size, max_seq_len, d_model=512, nhead=4,
                 num_encoder_layers=2, num_decoder_layers=2,
                 dim_feedforward=2048, dropout=0.1, padding_idx=0):
        super().__init__()
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.positional_embedding = PositionalEmbedding(d_model, max_seq_len)
        self.transformer = nn.Transformer(d_model, nhead,
                                          num_encoder_layers, num_decoder_layers,
                                          dim_feedforward, dropout,
                                          activation="relu", batch_first=True,
                                          norm_first=False, bias=True)
        self.output = nn.Linear(d_model, vocab_size, bias=False)
        # weight tying: the output projection reuses the embedding matrix
        self.output.weight = self.shared_embedding.weight

    def forward(self, src_ids, tgt_ids, src_key_padding_mask, tgt_key_padding_mask):
        src = self.positional_embedding(self.shared_embedding(src_ids))  # (N, S, E)
        tgt = self.positional_embedding(self.shared_embedding(tgt_ids))  # (N, L, E)
        # causal mask over target positions, built as bool (True = blocked)
        # so its dtype matches the bool key-padding masks
        L = tgt.size(1)
        causal_mask = torch.triu(
            torch.ones(L, L, dtype=torch.bool, device=tgt.device), diagonal=1
        )
        out = self.transformer(src=src, tgt=tgt,
                               src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=src_key_padding_mask,
                               tgt_mask=causal_mask)  # (N, L, E)
        # (N, vocab_size, L): class dim second, as nn.CrossEntropyLoss expects
        return self.output(out).transpose(-2, -1)
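

# Usage sketch (the vocab size, sequence lengths, and pad id below are
# illustrative assumptions). The logits come out as (N, vocab_size, L), which
# is the layout nn.CrossEntropyLoss expects for sequence targets; a real
# training step would additionally shift tgt_ids by one position for
# teacher forcing.
if __name__ == "__main__":
    model = TransformerNMT(vocab_size=1000, max_seq_len=64)
    src_ids = torch.randint(1, 1000, (2, 12))           # (N=2, S=12), no PADs
    tgt_ids = torch.randint(1, 1000, (2, 9))            # (N=2, L=9), no PADs
    src_pad = src_ids.eq(0)                             # (N, S) bool, True = PAD
    tgt_pad = tgt_ids.eq(0)                             # (N, L) bool, True = PAD
    logits = model(src_ids, tgt_ids, src_pad, tgt_pad)  # (N, vocab_size, L)
    loss = nn.CrossEntropyLoss(ignore_index=0)(logits, tgt_ids)
    print(logits.shape, loss.item())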