Spaces:
Sleeping
Sleeping
File size: 598 Bytes
f3b11f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | import torch.nn as nn
from models.transformer.encode_decode.clones import clones
from models.transformer.encode_decode.layer_norm import LayerNorm
class Decoder(nn.Module):
    """Generic N-layer decoder with masking.

    Passes the input through a stack of ``N`` decoder layers and applies a
    final ``LayerNorm`` to the result.

    Args:
        layer: Prototype decoder layer. It is replicated ``N`` times via
            ``clones`` and must expose a ``size`` attribute, which is used as
            the feature dimension of the final ``LayerNorm``.
        N: Number of layers in the stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        # NOTE(review): presumably ``clones`` returns N independent copies
        # (e.g. an nn.ModuleList of deep copies) so layers don't share
        # weights -- confirm in models.transformer.encode_decode.clones.
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run ``x`` through every decoder layer, then normalize.

        Each layer is called as ``layer(x, memory, src_mask, tgt_mask)`` and
        its output becomes the next layer's ``x``; ``memory`` and both masks
        are passed unchanged to every layer.

        Args:
            x: Decoder input. Assumed to be a tensor whose last dimension is
                ``layer.size`` -- TODO confirm against callers.
            memory: Encoder output forwarded to each layer.
            src_mask: Mask for ``memory``, forwarded to each layer.
            tgt_mask: Mask for ``x``, forwarded to each layer.

        Returns:
            The output of the last layer after the final ``LayerNorm``.
        """
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)
|