import torch
from torch import nn
import math

class RelativePositionalEncoding(nn.Module):
    """Sinusoidal encodings indexed by relative offset, from -max_len to +max_len."""
    def __init__(self, emb_size: int, max_len: int = 5000):
        super(RelativePositionalEncoding, self).__init__()
        self.emb_size = emb_size
        self.max_len = max_len
        # Relative offsets -max_len .. +max_len; row max_len corresponds to offset 0.
        relative_positions = torch.arange(-max_len, max_len + 1, dtype=torch.long)
        # Standard sinusoidal frequencies 1 / 10000^(2i / emb_size) for the even dimensions.
        scales = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        relative_positional_encodings = torch.zeros((2 * max_len + 1, emb_size))
        relative_positional_encodings[:, 0::2] = torch.sin(relative_positions.unsqueeze(-1) * scales)
        relative_positional_encodings[:, 1::2] = torch.cos(relative_positions.unsqueeze(-1) * scales)
        # Registered as a buffer so it moves with the module (e.g. .to(device)) but is not trained.
        self.register_buffer('relative_positional_encodings', relative_positional_encodings)

    def forward(self, length: int):
        # Row max_len is offset 0; return rows for offsets -(length - 1) .. 0,
        # one row per position in a sequence of the given length.
        center_pos = self.max_len
        return self.relative_positional_encodings[center_pos - length + 1 : center_pos + 1]

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

class TransformerModelRelative(nn.Module):
    def __init__(self, num_tokens_en, num_tokens_fr, embed_size, nhead, dim_feedforward, max_seq_length):
        super(TransformerModelRelative, self).__init__()
        self.embed_size = embed_size
        self.src_tok_emb = TokenEmbedding(num_tokens_en, embed_size)
        self.tgt_tok_emb = TokenEmbedding(num_tokens_fr, embed_size)
        self.positional_encoding = RelativePositionalEncoding(embed_size, max_len=max_seq_length)

        # batch_first=True so the encoder/decoder take (batch, seq, emb) tensors, matching
        # the use of src.size(1) / tgt.size(1) as the sequence length in encode()/decode().
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)
        self.generator = nn.Linear(embed_size, num_tokens_fr)

    def encode(self, src, src_mask):
        # src: (batch, src_len) token ids; src_mask is applied as the key padding mask.
        src_emb = self.src_tok_emb(src) + self.positional_encoding(src.size(1))
        return self.transformer_encoder(src_emb, src_key_padding_mask=src_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask):
        # tgt: (batch, tgt_len) token ids; tgt_mask is the causal attention mask.
        tgt_emb = self.tgt_tok_emb(tgt) + self.positional_encoding(tgt.size(1))
        return self.transformer_decoder(tgt_emb, memory, tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        memory = self.encode(src, src_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, tgt_padding_mask)
        # Project decoder states to target-vocabulary logits: (batch, tgt_len, num_tokens_fr).
        return self.generator(output)

    def generate_square_subsequent_mask(self, sz):
        # Causal mask: 0.0 on and below the diagonal, -inf above, so each position
        # may attend only to itself and earlier positions.
        mask = torch.triu(torch.ones(sz, sz)).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
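
# A minimal smoke-test sketch, not part of the original module: the hyperparameters and
# tensor sizes below are illustrative assumptions, chosen only to show the expected
# batch-first input shapes and how the causal target mask is wired into forward().
if __name__ == "__main__":
    model = TransformerModelRelative(
        num_tokens_en=1000, num_tokens_fr=1200,
        embed_size=128, nhead=8, dim_feedforward=512, max_seq_length=64,
    )
    src = torch.randint(0, 1000, (2, 10))   # (batch, src_len) source token ids
    tgt = torch.randint(0, 1200, (2, 12))   # (batch, tgt_len) target token ids
    tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))
    logits = model(src, tgt, tgt_mask=tgt_mask)
    print(logits.shape)  # expected: torch.Size([2, 12, 1200])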