| import torch | |
| import torch.nn as nn | |
| from .vq import VectorQuantizer | |
class ResidualVectorQuantizer(nn.Module):
    """Residual vector quantizer: a cascade of VQ stages, each quantizing
    the residual left over by the previous stage.

    References:
        SoundStream: An End-to-End Neural Audio Codec
        https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, n_e_list, e_dim, sk_epsilons,
                 kmeans_init=False, kmeans_iters=100, sk_iters=100,):
        """
        Args:
            n_e_list: codebook size for each quantizer stage; its length
                determines the number of residual stages.
            e_dim: embedding dimension shared by all codebooks.
            sk_epsilons: per-stage Sinkhorn epsilon; must have the same
                length as ``n_e_list``.
            kmeans_init: whether each stage initializes its codebook with
                k-means.
            kmeans_iters: number of k-means iterations used at init.
            sk_iters: number of Sinkhorn iterations per stage.

        Raises:
            ValueError: if ``n_e_list`` and ``sk_epsilons`` differ in length
                (otherwise ``zip`` would silently drop stages).
        """
        super().__init__()
        if len(n_e_list) != len(sk_epsilons):
            raise ValueError(
                f"n_e_list and sk_epsilons must have the same length, "
                f"got {len(n_e_list)} and {len(sk_epsilons)}"
            )
        self.n_e_list = n_e_list
        self.e_dim = e_dim
        self.num_quantizers = len(n_e_list)
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilons = sk_epsilons
        self.sk_iters = sk_iters
        # One VQ stage per (codebook size, epsilon) pair; all share e_dim
        # and the k-means / Sinkhorn hyper-parameters.
        self.vq_layers = nn.ModuleList([
            VectorQuantizer(n_e, e_dim,
                            kmeans_init=self.kmeans_init,
                            kmeans_iters=self.kmeans_iters,
                            sk_epsilon=sk_epsilon,
                            sk_iters=sk_iters)
            for n_e, sk_epsilon in zip(n_e_list, sk_epsilons)
        ])

    def get_codebook(self):
        """Return all stage codebooks stacked along a new leading dim.

        Shape is (num_quantizers, n_e, e_dim), assuming every stage uses
        the same codebook size — TODO confirm when n_e_list is non-uniform.
        """
        all_codebook = []
        for quantizer in self.vq_layers:
            all_codebook.append(quantizer.get_codebook())
        return torch.stack(all_codebook)

    def forward(self, x, use_sk=True):
        """Quantize ``x`` through all residual stages.

        Each stage quantizes the residual remaining after the previous
        stages; the final reconstruction is the sum of stage outputs.

        Args:
            x: input embeddings to quantize.
            use_sk: forwarded to each stage (Sinkhorn assignment toggle).

        Returns:
            Tuple of (x_q, mean_losses, all_indices):
            x_q: sum of per-stage quantized outputs (reconstruction of x).
            mean_losses: mean of the per-stage VQ losses (scalar tensor).
            all_indices: per-stage code indices stacked along the last dim.
        """
        all_losses = []
        all_indices = []
        x_q = 0
        residual = x
        for quantizer in self.vq_layers:
            x_res, loss, indices = quantizer(residual, use_sk=use_sk)
            # Next stage sees what this stage failed to capture.
            residual = residual - x_res
            x_q = x_q + x_res
            all_losses.append(loss)
            all_indices.append(indices)
        mean_losses = torch.stack(all_losses).mean()
        all_indices = torch.stack(all_indices, dim=-1)
        return x_q, mean_losses, all_indices