# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.models.transformer import Linear
class StackedEmbedding(nn.Embedding):
    """Embedding module that supports stacked units -> single embedding"""

    def __init__(self, num_embeddings, embed_dim, padding_idx, num_stacked=1):
        super().__init__(num_embeddings, embed_dim, padding_idx)
        # follow transformer.Embedding
        nn.init.normal_(self.weight, mean=0, std=embed_dim**-0.5)
        nn.init.constant_(self.weight[padding_idx], 0)

        self.offset = (
            4  # skip <bos>, <pad>, <eos>, <unk>, specific to fairseq dictionary
        )
        self.vocab_size = num_embeddings - self.offset
        self.num_stacked = num_stacked

        if self.num_stacked > 1:
            # project the concatenation of num_stacked unit embeddings back to embed_dim
            self.project_in_dim = Linear(embed_dim * num_stacked, embed_dim, bias=False)
    def forward(self, input):
        if self.num_stacked == 1:
            return super().forward(input)

        # expand input indices: a combined index encodes num_stacked units as
        # base-`vocab_size` digits (shifted by `offset`), so peel off one digit
        # per iteration, least significant first
        mask = input >= self.offset  # special symbols (< offset) pass through unchanged
        stacked_input = []
        cum_input = input.new_zeros(input.shape)
        for i in range(1, self.num_stacked + 1):
            div = pow(self.vocab_size, i)
            next_input = torch.remainder(input - self.offset - cum_input, div)
            cum_input += next_input
            next_input = torch.floor_divide(next_input, div // self.vocab_size)
            stacked_input.append((next_input + self.offset) * mask + input * ~mask)

        # embed the recovered units (in reversed digit order), concatenate along
        # the feature dimension, then project back down to embed_dim
        stacked_input = torch.stack(stacked_input[::-1], dim=2)
        embed = super().forward(stacked_input).view(input.size(0), input.size(1), -1)
        embed = self.project_in_dim(embed)
        return embed
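
# --------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the fairseq source). It assumes
# fairseq and torch are installed; the unit vocabulary size, stack size, and
# `pack_units` helper below are hypothetical, chosen only to show how a
# combined index relates to its stacked units.
#
# A combined index packs num_stacked unit ids (dictionary space, i.e. already
# shifted by `offset`) as base-`vocab_size` digits:
#     combined = offset + sum_i (u_i - offset) * vocab_size ** i
# StackedEmbedding.forward inverts this expansion, embeds each recovered unit,
# and projects the concatenated embeddings back down to embed_dim.

if __name__ == "__main__":
    OFFSET = 4       # <bos>, <pad>, <eos>, <unk> in a fairseq dictionary
    UNIT_VOCAB = 10  # hypothetical: 10 units, dictionary ids 4..13
    NUM_STACKED = 2

    def pack_units(unit_ids, vocab_size=UNIT_VOCAB, offset=OFFSET):
        # inverse of the digit expansion in StackedEmbedding.forward
        return offset + sum(
            (u - offset) * vocab_size**i for i, u in enumerate(unit_ids)
        )

    # units with dictionary ids 5 and 7 -> digits 1 and 3 -> 4 + 1 + 3 * 10 = 35
    token = pack_units([5, 7])
    assert token == 35

    emb = StackedEmbedding(
        num_embeddings=UNIT_VOCAB + OFFSET,  # table only needs single units + specials
        embed_dim=8,
        padding_idx=1,
        num_stacked=NUM_STACKED,
    )
    # combined indices may exceed num_embeddings; they are decoded before lookup
    batch = torch.tensor([[token, pack_units([13, 4])], [2, 1]])  # 2 = <eos>, 1 = <pad>
    out = emb(batch)
    print(out.shape)  # torch.Size([2, 2, 8])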