Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| from models.transformer.encode_decode.clones import clones | |
| from models.transformer.encode_decode.layer_norm import LayerNorm | |
class Decoder(nn.Module):
    """Generic N-layer decoder with masking.

    Holds ``N`` copies of ``layer`` (built by ``clones``) and applies a final
    ``LayerNorm`` to the output of the last copy.

    Args:
        layer: Decoder layer module to replicate. Must expose a ``size``
            attribute; it is passed to ``LayerNorm``.
        N: Number of copies of ``layer`` to stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        # NOTE(review): presumably `clones` returns N independent deep copies
        # wrapped in an nn.ModuleList -- confirm against its implementation.
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run ``x`` through every decoder layer in order, then normalize.

        Args:
            x: Decoder input fed to the first layer; each layer's output is
                fed to the next.
            memory: Passed unchanged to every layer.
            src_mask: Passed unchanged to every layer.
            tgt_mask: Passed unchanged to every layer.

        Returns:
            ``self.norm`` applied to the last layer's output, or to ``x``
            itself when there are no layers.
        """
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)