| # Copyright (c) ByteDance, Inc. and its affiliates. | |
| # Copyright (c) Chutong Meng | |
| # | |
| # This source code is licensed under the CC BY-NC license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # Based on AudioDec (https://github.com/facebookresearch/AudioDec) | |
| import torch.nn as nn | |
| from repcodec.layers.vq_module import ResidualVQ | |
class Quantizer(nn.Module):
    """Adapter around a ``ResidualVQ`` codebook that handles axis swapping.

    Every method that takes a latent ``z`` swaps its axes 1 and 2 before
    handing it to the codebook. ``forward`` and ``inference`` swap the
    quantized result back; ``encode`` does not.

    NOTE(review): presumably ``z`` is laid out as (batch, channels, time) and
    the codebook expects (batch, time, channels) -- confirm against callers.
    """

    def __init__(
            self,
            code_dim: int,
            codebook_num: int,
            codebook_size: int,
    ):
        """Build the underlying residual vector quantizer.

        Args:
            code_dim: Forwarded to ``ResidualVQ`` as ``dim``.
            codebook_num: Forwarded to ``ResidualVQ`` as ``num_quantizers``.
            codebook_size: Forwarded to ``ResidualVQ`` as ``codebook_size``.
        """
        super().__init__()
        vq_kwargs = dict(
            dim=code_dim,
            num_quantizers=codebook_num,
            codebook_size=codebook_size,
        )
        self.codebook = ResidualVQ(**vq_kwargs)

    def initial(self):
        """Delegate to the codebook's own ``initial()`` routine."""
        self.codebook.initial()

    def forward(self, z):
        """Quantize ``z`` and return ``(zq, vqloss, perplexity)``.

        ``z`` has axes 1 and 2 swapped before the codebook call, and the
        quantized output is swapped back so it matches the layout of ``z``.
        The other two values are passed through from the codebook unchanged.
        """
        swapped = z.transpose(2, 1)
        quantized, vqloss, perplexity = self.codebook(swapped)
        return quantized.transpose(2, 1), vqloss, perplexity

    def inference(self, z):
        """Quantize ``z`` via ``forward_index`` and return ``(zq, indices)``.

        The quantized output is swapped back to the layout of ``z``; the
        indices are returned exactly as the codebook produced them.
        """
        swapped = z.transpose(2, 1)
        quantized, indices = self.codebook.forward_index(swapped)
        return quantized.transpose(2, 1), indices

    def encode(self, z):
        """Quantize ``z`` with ``flatten_idx=True`` and return ``(zq, indices)``.

        Unlike ``inference``, the quantized output is returned in the
        codebook's layout (axes are NOT swapped back).
        """
        swapped = z.transpose(2, 1)
        return self.codebook.forward_index(swapped, flatten_idx=True)[0:2] \
            if False else self._encode_pair(swapped)

    def _encode_pair(self, swapped):
        """Unpack the two-element result of ``forward_index`` for ``encode``."""
        quantized, indices = self.codebook.forward_index(swapped, flatten_idx=True)
        return quantized, indices

    def decode(self, indices):
        """Return whatever the codebook's ``lookup`` yields for ``indices``."""
        return self.codebook.lookup(indices)