| | from typing import Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from torch.nn.utils import weight_norm |
| |
|
| | from .layers import WNConv1d |
| |
|
| |
|
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        """Create the projections and the codebook.

        Parameters
        ----------
        input_dim : int
            Number of channels of the tensors passed to ``forward``.
        codebook_size : int
            Number of entries in the codebook.
        codebook_dim : int
            Dimensionality of each codebook entry (the lookup space).
        """
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # kernel_size=1 projections into and out of the codebook space
        # (WNConv1d comes from the package's ``layers`` module).
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z: torch.Tensor):
        """Quantize the input tensor with the codebook and return the
        corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input (after ``out_proj``)
        Tensor[B]
            Per-sample commitment loss to train encoder to predict vectors
            closer to codebook entries
        Tensor[B]
            Per-sample codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x codebook_dim x T]
            Projected latents (continuous representation of input before
            quantization)
        """
        z_e = self.in_proj(z)
        z_q, indices = self.decode_latents(z_e)

        # Both losses are averaged over channels and time only, so one value
        # per batch item is returned; the caller decides how to reduce them.
        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        # Straight-through estimator: the forward value is z_q, while the
        # gradient flows to z_e as if quantization were the identity.
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id: torch.Tensor):
        """Look up codebook vectors for integer indices (adds a trailing
        ``codebook_dim`` axis)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id: torch.Tensor):
        """Look up codebook vectors for ``Tensor[B x T]`` indices and return
        them channel-first as ``Tensor[B x codebook_dim x T]``."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents: torch.Tensor):
        """Find the nearest codebook entry for every time step.

        Parameters
        ----------
        latents : Tensor[B x codebook_dim x T]
            Projected (unquantized) latents.

        Returns
        -------
        Tensor[B x codebook_dim x T]
            Selected (un-normalized) codebook vectors
        Tensor[B x T]
            Indices of the selected codebook entries
        """
        batch, dim, steps = latents.shape
        # (B, D, T) -> (B*T, D): one row per time step of every batch item.
        encodings = latents.transpose(1, 2).reshape(batch * steps, dim)
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook (ViT-VQGAN); normalization is
        # used for the lookup only, the returned vectors are un-normalized.
        encodings = F.normalize(encodings, dim=1)
        codebook = F.normalize(codebook, dim=1)

        # Squared euclidean distance between every encoding and every code.
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = dist.argmin(dim=1).view(batch, steps)
        z_q = self.decode_code(indices)
        return z_q, indices
| |
|
| |
|
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        """Build a stack of ``n_codebooks`` vector quantizers.

        Parameters
        ----------
        input_dim : int
            Number of channels of the tensors passed to ``forward``.
        n_codebooks : int
            Number of quantizers in the residual stack.
        codebook_size : int
            Number of entries in every codebook.
        codebook_dim : int or list
            Codebook dimensionality; an int is shared by all quantizers,
            a list gives one value per quantizer.
        quantizer_dropout : float
            Fraction of each training batch for which a random number of
            quantizers is used.
        """
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim] * n_codebooks

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z: torch.Tensor, n_quantizers: Union[int, None] = None):
        """Quantize the input tensor with a stack of residual codebooks.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use (defaults to ``self.n_codebooks``).
            Only honoured in eval mode: in training mode this argument is
            ignored and every quantizer is run, with a random number of them
            contributing for the first ``int(B * quantizer_dropout)`` batch
            items.

        Returns
        -------
        tuple
            A 5-tuple ``(z, codes, latents, commitment_loss, codebook_loss)``:

            z : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each quantizer that was run
                (quantized discrete representation of input)
            latents : Tensor[B x sum(codebook_dim) x T]
                Projected latents of the quantizers that were run, concatenated
                along the channel axis (continuous representation of input
                before quantization)
            commitment_loss : scalar Tensor
                Commitment loss to train encoder to predict vectors closer to
                codebook entries, summed over quantizers
            codebook_loss : scalar Tensor
                Codebook loss to update the codebook, summed over quantizers

        Raises
        ------
        ValueError
            If ``n_quantizers`` is smaller than 1 in eval mode.
        """
        batch_size = z.shape[0]
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Per-item quantizer count: n_codebooks + 1 means "use all";
            # the first n_dropout items get a random count in [1, n_codebooks].
            n_quantizers = torch.ones((batch_size,)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (batch_size,))
            n_dropout = int(batch_size * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        elif n_quantizers < 1:
            # Without this check the loop exits immediately and torch.stack
            # fails on an empty list with an unrelated-looking error.
            raise ValueError(
                f"n_quantizers must be at least 1 in eval mode, got {n_quantizers}"
            )

        for i, quantizer in enumerate(self.quantizers):
            if not self.training and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((batch_size,), fill_value=i, device=z.device)
                < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            # The residual is updated with the unmasked output.
            residual = residual - z_q_i

            # Sum losses
            commitment_loss = commitment_loss + (commitment_loss_i * mask).mean()
            codebook_loss = codebook_loss + (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation.

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input; the first ``N``
            quantizers are used.

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B x sum(codebook_dim) x T]
            Codebook vectors of each quantizer, concatenated along the
            channel axis
        Tensor[B x N x T]
            The ``codes`` argument, returned unchanged

        Raises
        ------
        ValueError
            If ``N`` is 0 or exceeds the number of quantizers.
        """
        n_codebooks = codes.shape[1]
        if not 0 < n_codebooks <= len(self.quantizers):
            raise ValueError(
                f"codes has {n_codebooks} codebooks, expected between 1 and "
                f"{len(self.quantizers)}"
            )

        z_q = 0.0
        z_p = []
        # zip stops after the N code slices, so only the first N quantizers run.
        for quantizer, codes_i in zip(self.quantizers, codes.unbind(dim=1)):
            z_p_i = quantizer.decode_code(codes_i)
            z_p.append(z_p_i)
            z_q = z_q + quantizer.out_proj(z_p_i)
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x C x T]
            Continuous representation of input after projection, i.e. the
            per-quantizer latents concatenated along the channel axis. Only
            the leading quantizers whose channels fit completely in ``C`` are
            used; any remaining channels are ignored.

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x C' x T]
            Quantized representation of latent space (concatenated codebook
            vectors of the quantizers that were used)
        Tensor[B x N x T]
            Codebook indices of the quantizers that were used

        Raises
        ------
        ValueError
            If ``latents`` does not contain the channels of at least one
            quantizer.
        """
        # dims[i]:dims[i + 1] is the channel slice owned by quantizer i.
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # dims is non-decreasing, so the number of boundaries that fit in the
        # channel axis, minus the leading 0, is the number of usable quantizers.
        n_codebooks = int((dims <= latents.shape[1]).sum()) - 1
        if n_codebooks < 1:
            raise ValueError(
                f"latents has {latents.shape[1]} channels, which is not enough "
                "for a single quantizer"
            )

        z_q = 0
        z_p = []
        codes = []
        for i in range(n_codebooks):
            j, k = int(dims[i]), int(dims[i + 1])
            quantizer = self.quantizers[i]
            z_p_i, codes_i = quantizer.decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)
            z_q = z_q + quantizer.out_proj(z_p_i)

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: push a random batch through the quantizer stack and print
    # the shape of the concatenated projected latents.
    rvq = ResidualVectorQuantize(quantizer_dropout=1.0)
    x = torch.randn(16, 512, 80)
    # forward returns a tuple, not a dict, so unpack it instead of indexing
    # by key (which raised TypeError).
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
| |
|