import torch
import torch.nn as nn

from .vq import VectorQuantizer


class ResidualVectorQuantizer(nn.Module):
    """Stack of vector quantizers applied successively to residuals.

    Layer ``i`` receives the input minus the sum of the outputs of layers
    ``0..i-1``. The returned reconstruction is the sum of every layer's
    output.

    References:
        SoundStream: An End-to-End Neural Audio Codec
        https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, n_e_list, e_dim, sk_epsilons,
                 kmeans_init=False, kmeans_iters=100, sk_iters=100):
        """Build one ``VectorQuantizer`` per entry of ``n_e_list``.

        Args:
            n_e_list: Codebook size (``n_e``) for each layer, in order.
            e_dim: Embedding dimension passed to every layer.
            sk_epsilons: Per-layer value forwarded as ``sk_epsilon``; must
                have the same length as ``n_e_list``.
            kmeans_init: Forwarded unchanged to every ``VectorQuantizer``.
            kmeans_iters: Forwarded unchanged to every ``VectorQuantizer``.
            sk_iters: Forwarded unchanged to every ``VectorQuantizer``.

        Raises:
            ValueError: If ``n_e_list`` and ``sk_epsilons`` differ in length.
        """
        super().__init__()
        # zip() would silently drop the surplus entries on a length mismatch,
        # leaving num_quantizers out of sync with the layers actually built.
        if len(sk_epsilons) != len(n_e_list):
            raise ValueError(
                f"sk_epsilons has {len(sk_epsilons)} entries but n_e_list "
                f"has {len(n_e_list)}; one sk_epsilon is required per layer"
            )

        self.n_e_list = n_e_list
        self.e_dim = e_dim
        self.num_quantizers = len(n_e_list)
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilons = sk_epsilons
        self.sk_iters = sk_iters
        self.vq_layers = nn.ModuleList([
            VectorQuantizer(
                n_e,
                e_dim,
                kmeans_init=self.kmeans_init,
                kmeans_iters=self.kmeans_iters,
                sk_epsilon=sk_epsilon,
                sk_iters=sk_iters,
            )
            for n_e, sk_epsilon in zip(n_e_list, sk_epsilons)
        ])

    def get_codebook(self):
        """Return every layer's codebook stacked along a new leading dim.

        NOTE(review): ``torch.stack`` requires all codebooks to have the same
        shape. That presumably fails when ``n_e_list`` holds differing sizes;
        confirm against ``VectorQuantizer.get_codebook``.
        """
        return torch.stack(
            [quantizer.get_codebook() for quantizer in self.vq_layers]
        )

    def forward(self, x, use_sk=True):
        """Quantize ``x`` through every layer, each on the remaining residual.

        Args:
            x: Input passed to the first quantizer.
            use_sk: Forwarded to each quantizer call.

        Returns:
            Tuple ``(x_q, mean_losses, all_indices)``:

            - ``x_q``: the sum of all layer outputs.
            - ``mean_losses``: the mean of the per-layer losses. This assumes
              each loss is a same-shaped tensor (likely a scalar) — TODO
              confirm.
            - ``all_indices``: the per-layer indices stacked along a new last
              dimension, one slot per layer.
        """
        all_losses = []
        all_indices = []
        x_q = 0
        residual = x
        for quantizer in self.vq_layers:
            x_res, loss, indices = quantizer(residual, use_sk=use_sk)
            # The next layer only sees what this layer did not reconstruct.
            residual = residual - x_res
            x_q = x_q + x_res
            all_losses.append(loss)
            all_indices.append(indices)

        mean_losses = torch.stack(all_losses).mean()
        all_indices = torch.stack(all_indices, dim=-1)
        return x_q, mean_losses, all_indices