# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.models.transformer import Linear


class StackedEmbedding(nn.Embedding):
    """Embedding module that supports stacked units -> single embedding.

    With ``num_stacked == 1`` this behaves like a plain ``nn.Embedding`` (with
    transformer-style initialisation). With ``num_stacked > 1`` every input
    index ``>= offset`` is decomposed into ``num_stacked`` base-``vocab_size``
    digits (most significant first). Each digit is shifted back by ``offset``
    and embedded separately. The concatenated embeddings are then projected
    back to ``embed_dim`` by a bias-free linear layer. Indices below ``offset``
    (special symbols) are embedded as-is in every stacked slot.

    Args:
        num_embeddings: size of the embedding table, including the ``offset``
            leading special symbols.
        embed_dim: embedding dimension.
        padding_idx: index whose embedding row is zero-initialised; may be
            ``None`` (as in ``nn.Embedding``).
        num_stacked: number of units packed into each input index (>= 1).

    Raises:
        ValueError: if ``num_stacked < 1``, or if ``num_stacked > 1`` and the
            table has no entries beyond the special symbols.
    """

    def __init__(self, num_embeddings, embed_dim, padding_idx, num_stacked=1):
        if num_stacked < 1:
            raise ValueError(f"num_stacked must be >= 1, got {num_stacked}")
        super().__init__(num_embeddings, embed_dim, padding_idx)
        # follow transformer.Embedding
        nn.init.normal_(self.weight, mean=0, std=embed_dim**-0.5)
        # Use the normalised self.padding_idx set by nn.Embedding. When it is
        # None, weight[None] would be a view of the *whole* table and zero it.
        if self.padding_idx is not None:
            nn.init.constant_(self.weight[self.padding_idx], 0)

        # skip <bos>, <pad>, <eos>, <unk>: the first four entries of a fairseq
        # dictionary (specific to fairseq dictionary)
        self.offset = 4
        self.vocab_size = num_embeddings - self.offset
        self.num_stacked = num_stacked

        if self.num_stacked > 1:
            if self.vocab_size < 1:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must exceed the "
                    f"{self.offset} special symbols when num_stacked > 1"
                )
            self.project_in_dim = Linear(
                embed_dim * num_stacked, embed_dim, bias=False
            )

    def forward(self, input):
        """Embed a tensor of (possibly stacked) unit indices.

        Args:
            input: integer index tensor of any shape.

        Returns:
            Tensor of shape ``(*input.shape, embed_dim)``.
        """
        if self.num_stacked == 1:
            return super().forward(input)

        # Special symbols (index < offset) are not decomposed; each stacked
        # slot simply repeats the original index for them.
        mask = input >= self.offset
        units = input - self.offset

        # Base-vocab_size digits of `units`, most significant digit first.
        # Values computed at masked (special-symbol) positions are discarded
        # by the torch.where below, so negative `units` there are harmless.
        digits = [
            torch.remainder(
                torch.div(units, self.vocab_size**i, rounding_mode="floor"),
                self.vocab_size,
            )
            for i in reversed(range(self.num_stacked))
        ]
        stacked_input = torch.stack(digits, dim=-1) + self.offset
        stacked_input = torch.where(
            mask.unsqueeze(-1), stacked_input, input.unsqueeze(-1)
        )

        # (*input.shape, num_stacked, embed_dim)
        #   -> (*input.shape, num_stacked * embed_dim)
        embed = super().forward(stacked_input).flatten(start_dim=-2)
        return self.project_in_dim(embed)