File size: 966 Bytes
f3b11f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.nn as nn

from models.transformer.encode_decode.clones import clones
from models.transformer.encode_decode.sublayer_connection import SublayerConnection

class DecoderLayer(nn.Module):
    """A single decoder layer built from three wrapped sub-layers.

    The sub-layers run in this order:
      1. attention of the decoder input over itself (masked by ``tgt_mask``),
      2. attention of the decoder input over the encoder ``memory``
         (masked by ``src_mask``),
      3. the position-wise ``feed_forward`` callable.

    Each sub-layer is routed through its own ``SublayerConnection`` from
    ``self.sublayer``. See Figure 1 (right) for the wiring.

    Args:
        size: feature size; kept on ``self.size`` and passed to every
            ``SublayerConnection`` (presumably the model width -- confirm).
        self_attn: callable invoked as ``self_attn(x, x, x, tgt_mask)``.
        src_attn: callable invoked as ``src_attn(x, memory, memory, src_mask)``.
        feed_forward: callable invoked on the output of the second sub-layer.
        dropout: forwarded unchanged to each ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        # nn.Module registers child modules in assignment order, so this
        # sequence fixes the state_dict / parameters() ordering. Keep it.
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # Three wrappers, used as indices 0, 1, 2 in forward().
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Pass ``x`` through self-attention, source attention and feed-forward.

        Args:
            x: input to this decoder layer.
            memory: encoder output, supplied as both key and value to
                ``src_attn``.
            src_mask: mask handed to ``src_attn``.
            tgt_mask: mask handed to ``self_attn``.

        Returns:
            Whatever the third ``SublayerConnection`` returns.
        """

        def _attend_self(h):
            return self.self_attn(h, h, h, tgt_mask)

        def _attend_memory(h):
            return self.src_attn(h, memory, memory, src_mask)

        first, second, third = self.sublayer[0], self.sublayer[1], self.sublayer[2]

        x = first(x, _attend_self)
        x = second(x, _attend_memory)
        return third(x, self.feed_forward)