Yixuan Li
add fairseq folder
85ba398
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from fairseq.models.transformer import Linear
class StackedEmbedding(nn.Embedding):
    """Embedding module that supports stacked units -> single embedding.

    When ``num_stacked > 1``, every input index at or above ``offset`` is read
    as a base-``vocab_size`` number with ``num_stacked`` digits, i.e. a group
    of stacked units. Each digit is looked up separately (shifted back by
    ``offset``). The digit embeddings are concatenated most-significant first
    and then projected back to ``embed_dim``. Indices below ``offset``
    (special symbols) are repeated unchanged in every digit slot.

    Args:
        num_embeddings: size of the per-unit table, including the ``offset``
            special symbols.
        embed_dim: embedding dimension.
        padding_idx: row that is re-zeroed after initialisation; ``None``
            means no padding row.
        num_stacked: number of units packed into one input index.
    """

    def __init__(self, num_embeddings, embed_dim, padding_idx, num_stacked=1):
        super().__init__(num_embeddings, embed_dim, padding_idx)
        # follow transformer.Embedding
        nn.init.normal_(self.weight, mean=0, std=embed_dim**-0.5)
        # `self.weight[None]` would address the *whole* table, so only zero a
        # row when a padding index was actually given.
        if self.padding_idx is not None:
            nn.init.constant_(self.weight[self.padding_idx], 0)

        self.offset = (
            4  # skip <bos>, <pad>, <eos>, <unk>, specific to fairseq dictionary
        )
        self.vocab_size = num_embeddings - self.offset
        self.num_stacked = num_stacked

        if self.num_stacked > 1:
            self.project_in_dim = Linear(embed_dim * num_stacked, embed_dim, bias=False)

    def forward(self, input):
        """Embed an integer index tensor of any shape.

        Returns a tensor of shape ``(*input.shape, embed_dim)``.
        """
        if self.num_stacked == 1:
            return super().forward(input)

        # Expand every index into num_stacked digits, most significant first.
        mask = input >= self.offset
        shifted = input - self.offset
        digits = []
        for i in range(self.num_stacked - 1, -1, -1):
            place = pow(self.vocab_size, i)
            digit = torch.remainder(
                torch.div(shifted, place, rounding_mode="floor"), self.vocab_size
            )
            # Special symbols (index < offset) are kept as-is in every slot.
            digits.append(torch.where(mask, digit + self.offset, input))

        stacked_input = torch.stack(digits, dim=-1)
        # (..., num_stacked, embed_dim) -> (..., num_stacked * embed_dim)
        embed = super().forward(stacked_input).flatten(start_dim=-2)
        return self.project_in_dim(embed)