File size: 2,162 Bytes
e663138
 
738d2af
e663138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
from src.constants.tokens import PAD_ID

class TinyTransformer(nn.Module):
  def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=2, dim_feedforward=512, dropout=0.1):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_ID)
    self.pos_encoder = PositionalEncoding(d_model, dropout)
    encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
    self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
    self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
    self.out = nn.Linear(d_model, vocab_size)

  # Keep tensors in batch-first format
  def forward(self, src, tgt):
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(src.device).bool()
    src_emb = self.pos_encoder(self.embedding(src))
    tgt_emb = self.pos_encoder(self.embedding(tgt))
    # Create padding masks
    src_padding_mask = (src == PAD_ID).bool()
    tgt_padding_mask = (tgt == PAD_ID).bool()
    memory = self.encoder(src_emb, src_key_padding_mask=src_padding_mask)
    output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_padding_mask)
    return self.out(output)  # (batch, seq_len, vocab)

  def generate_src_mask(self, size):
    return torch.zeros((size, size), device='cpu').type(torch.bool)

class PositionalEncoding(nn.Module):
  def __init__(self, d_model, dropout=0.1, max_len=512):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(
      torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model)
    )
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even indices
    pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    x = x + self.pe[:, :x.size(1), :].to(x.device)
    return self.dropout(x)