import torch
import torch.nn as nn
from utils.utils import sequence_mask
import math
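
# Position-encoding layers (fixed sinusoidal and learned, variable-position based) and a
# Transformer encoder wrapper, all applied on top of pre-computed token embeddings.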

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position encoding (Vaswani et al., 2017) added to the input embeddings."""

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the [max_len, d_model] table: sin on even dims, cos on odd dims.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)   # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)                                # [1, max_len, d_model]
        # Registered as a buffer: moves with .to(device) but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
            x: [B, seq_len, d_model] (seq_len <= max_len)
            pe: [1, max_len, d_model], sliced to seq_len before being added
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

class LearnedPositionEncoding(nn.Module):
    """Learned position embedding added only at the token positions listed in var_pos."""

    def __init__(self, d_model, max_len=20):
        super(LearnedPositionEncoding, self).__init__()
        # Index 0 ends up used for non-variable / padded positions (see forward); 1..var_len mark variables.
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x, var_pos):
        """
            x: [B, seq_len, d_model]
            var_pos: [B, var_len] -- indices into the sequence dimension of x
        """
        # loc_mat[b, t] = k means token t of sample b is the k-th variable (0 = not a variable).
        loc_mat = torch.zeros(x.size(0), x.size(1), dtype=torch.int64, device=x.device)
        pos_id = torch.arange(1, var_pos.size(1) + 1, device=x.device).repeat(var_pos.size(0), 1)
        # Entries equal to the minimum index are treated as padded variable slots and mapped to id 0.
        pos_id[var_pos == var_pos.min()] = 0
        loc_mat.scatter_(1, var_pos, pos_id)

        x = x + self.embedding(loc_mat)

        return x

class TransformerEncoder(nn.Module):

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.2):
        super(TransformerEncoder, self).__init__()

        # The encoder layer uses the default batch_first=False, so forward() permutes to [T, B, d_model].
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        self.position = PositionalEncoding(d_model=d_model)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead
    
    def _reset_parameters(self):
        """
            Initiate parameters in the transformer model.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, len_src, emb_src):
        """
            len_src: [B] source lengths
            emb_src: [B, T, d_model] source embeddings
        """
        # True at padded positions (sequence_mask marks valid positions, hence the negation).
        src_key_padding_mask = ~sequence_mask(len_src)
        # add sinusoidal position encoding
        emb_src = self.position(emb_src)
        # encoder expects [T, B, d_model]; permute back to [B, T, d_model] on return
        memory = self.encoder(emb_src.permute(1, 0, 2), src_key_padding_mask=src_key_padding_mask)

        return memory.permute(1, 0, 2)
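
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes utils.utils.sequence_mask(lengths) returns a [B, max_len] bool mask
# that is True at valid (non-padded) positions, which is what the negation in
# TransformerEncoder.forward implies. All shapes below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, T, d_model = 2, 10, 256

    encoder = TransformerEncoder(d_model=d_model)
    emb_src = torch.randn(B, T, d_model)        # pre-computed token embeddings
    len_src = torch.tensor([10, 7])             # true lengths; sample 1 has 3 padded steps

    memory = encoder(len_src, emb_src)
    print(memory.shape)                         # expected: torch.Size([2, 10, 256])

    # LearnedPositionEncoding: sample 0 has variables at tokens 2 and 5; sample 1 has a
    # variable at token 3 plus a zero-padded slot (entries equal to var_pos.min() map to id 0).
    var_pe = LearnedPositionEncoding(d_model=d_model, max_len=20)
    var_pos = torch.tensor([[2, 5], [3, 0]])
    out = var_pe(emb_src, var_pos)
    print(out.shape)                            # expected: torch.Size([2, 10, 256])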