Spaces:
Sleeping
Sleeping
File size: 966 Bytes
f3b11f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import torch.nn as nn
from models.transformer.encode_decode.clones import clones
from models.transformer.encode_decode.sublayer_connection import SublayerConnection
class DecoderLayer(nn.Module):
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 3)
def forward(self, x, memory, src_mask, tgt_mask):
"Follow Figure 1 (right) for connections."
m = memory
x = self.sublayer[0](x, lambda x: self.self_attn(
x, x, x, tgt_mask))
x = self.sublayer[1](x, lambda x: self.src_attn(
x, m, m, src_mask))
return self.sublayer[2](x, self.feed_forward)
|