| from typing import Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from torch.nn.utils import weight_norm |
|
|
| from dac.nn.layers import WNConv1d |
|
|
class VectorQuantizeLegacy(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization

    This variant has no input/output projection: the nearest-neighbour
    search runs directly in the input space, so ``input_dim`` is also the
    size of every codebook vector.
    """

    def __init__(self, input_dim: int, codebook_size: int):
        super().__init__()
        self.codebook_size = codebook_size
        # One learnable vector of size ``input_dim`` per code.
        self.codebook = nn.Embedding(codebook_size, input_dim)

    @staticmethod
    def _masked_mean(per_frame, z_mask):
        """Return sum(per_frame * z_mask) / sum(z_mask)."""
        return (per_frame * z_mask).sum() / z_mask.sum()

    def forward(self, z, z_mask=None):
        """Quantize the input tensor using a fixed codebook and return
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Continuous input; D must equal ``input_dim``.
        z_mask : Tensor, optional
            Frame weights for the losses. It multiplies a per-frame error
            obtained by averaging over dim 1, so it is presumably [B x T]
            (TODO confirm with callers). When omitted the losses are plain
            means over all elements.

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input (forward value is
            the selected codebook vectors; gradient passes straight through
            to ``z``)
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            The unmodified input ``z`` (no projection is applied here)
        Tensor[]
            Commitment loss to train encoder to predict vectors closer to
            codebook entries
        Tensor[]
            Codebook loss to update the codebook

        Note that this return order differs from ``VectorQuantize.forward``.
        """
        z_e = z
        z_q, indices = self.decode_latents(z_e)

        if z_mask is not None:
            # Average over the channel dim first, then a mask-weighted mean
            # over the remaining (batch, frame) positions.
            commitment_loss = self._masked_mean(
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1), z_mask
            )
            codebook_loss = self._masked_mean(
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1), z_mask
            )
        else:
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            codebook_loss = F.mse_loss(z_q, z_e.detach())

        # Straight-through estimator, applied with or without a mask: the
        # forward value is the code vector, the backward pass treats the
        # quantization step as the identity.
        z_q = z_e + (z_q - z_e).detach()

        return z_q, indices, z_e, commitment_loss, codebook_loss

    def embed_code(self, embed_id):
        """Look up raw codebook rows for integer ids (appends a D axis)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Map [B x T] ids to [B x D x T] codebook vectors."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Find the nearest codebook entry for every frame of ``latents``.

        Frames and codebook vectors are both l2-normalized before the
        search, so "nearest" means highest cosine similarity. The returned
        vectors are the raw (unnormalized) codebook rows.

        Parameters
        ----------
        latents : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Codebook vectors of the selected codes
        Tensor[B x T]
            Selected code indices
        """
        batch, dim, frames = latents.shape
        # (B, D, T) -> (B*T, D), same row order as einops "b d t -> (b t) d".
        encodings = latents.transpose(1, 2).reshape(batch * frames, dim)
        codebook = self.codebook.weight

        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Squared euclidean distance ||e||^2 - 2 e.c + ||c||^2 for every
        # (frame, code) pair.
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = (-dist).max(1)[1].reshape(batch, frames)
        z_q = self.decode_code(indices)
        return z_q, indices
|
|
class VectorQuantize(nn.Module):
    """Single-codebook vector quantizer with factorized, l2-normalized codes.

    Structured like Karpathy's repo
    (https://github.com/karpathy/deep-vector-quantization), plus two tricks
    from Improved VQGAN (https://arxiv.org/pdf/2110.04627.pdf):

    1. Factorized codes: the input is projected down to ``codebook_dim``
       channels before the nearest-neighbour lookup, for better codebook
       usage.
    2. l2-normalized codes: inputs and codebook vectors are normalized for
       the lookup, which turns euclidean distance into cosine similarity
       and improves training stability.
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # kernel_size=1 projections into / out of the lookup space.
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    @staticmethod
    def _masked_frame_mean(elementwise_err, z_mask):
        """Average over dim 1, then take a z_mask-weighted mean of the rest."""
        per_frame = elementwise_err.mean(1)
        weighted_total = (per_frame * z_mask).sum()
        return weighted_total / z_mask.sum()

    def forward(self, z, z_mask=None):
        """Quantize ``z`` against the codebook.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Continuous input with D == ``input_dim`` channels.
        z_mask : Tensor, optional
            Frame weights for the losses; it multiplies a per-frame error
            averaged over dim 1, so it is presumably [B x T] (TODO confirm).
            With a mask, each loss is sum(err * z_mask) / sum(z_mask);
            without one it is a plain mean over all elements.

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation projected back to ``input_dim``
            channels (straight-through gradient to ``z``)
        Tensor[]
            Commitment loss (pulls the projected input toward its code)
        Tensor[]
            Codebook loss (pulls the code toward the projected input)
        Tensor[B x T]
            Selected codebook indices
        Tensor[B x codebook_dim x T]
            Projected latents before quantization
        """
        projected = self.in_proj(z)
        quantized, code_ids = self.decode_latents(projected)

        if z_mask is None:
            commitment_loss = F.mse_loss(projected, quantized.detach())
            codebook_loss = F.mse_loss(quantized, projected.detach())
        else:
            commitment_loss = self._masked_frame_mean(
                F.mse_loss(projected, quantized.detach(), reduction="none"), z_mask
            )
            codebook_loss = self._masked_frame_mean(
                F.mse_loss(quantized, projected.detach(), reduction="none"), z_mask
            )

        # Straight-through estimator: forward value is the code vector,
        # backward pass treats quantization as the identity.
        quantized = projected + (quantized - projected).detach()
        restored = self.out_proj(quantized)

        return restored, commitment_loss, codebook_loss, code_ids, projected

    def embed_code(self, embed_id):
        """Look up raw codebook rows for integer ids (appends a channel axis)."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Map [B x T] ids to [B x codebook_dim x T] codebook vectors."""
        vectors = self.embed_code(embed_id)
        return vectors.transpose(1, 2)

    def decode_latents(self, latents):
        """Pick the codebook entry with the highest cosine similarity per frame.

        Returns the raw (unnormalized) code vectors as [B x C x T] together
        with the [B x T] code ids.
        """
        n_batch, n_chan, n_frames = latents.shape
        # (B, C, T) -> (B*T, C): one row per frame.
        rows = latents.transpose(1, 2).reshape(n_batch * n_frames, n_chan)
        rows = F.normalize(rows)
        table = F.normalize(self.codebook.weight)

        # ||r||^2 - 2 r.c + ||c||^2 for every (frame, code) pair.
        sq_dist = (
            rows.pow(2).sum(1, keepdim=True)
            - 2 * rows @ table.t()
            + table.pow(2).sum(1, keepdim=True).t()
        )
        code_ids = (-sq_dist).max(1)[1].reshape(n_batch, n_frames)
        return self.decode_code(code_ids), code_ids
|
|
|
|
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312

    A stack of ``n_codebooks`` ``VectorQuantize`` layers: each layer
    quantizes the residual left over by the previous layers, and the
    quantized outputs are summed.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            # Same lookup dimension for every codebook.
            codebook_dim = [codebook_dim] * n_codebooks

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        # Fraction of each training batch that gets a random, reduced number
        # of quantizers (see ``forward``).
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: Union[int, None] = None):
        """Quantize ``z`` with the stack of residual codebooks.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            Number of quantizers to use in eval mode (defaults to all
            ``n_codebooks``). In training mode this argument is ignored:
            every batch item uses all codebooks except the first
            ``int(B * quantizer_dropout)`` items, which each use a random
            count in ``[1, n_codebooks]`` (quantizer dropout).

        Returns
        -------
        tuple
            ``(z_q, codes, latents, commitment_loss, codebook_loss)`` where

            z_q : Tensor[B x D x T]
                Sum of the (dropout-masked) quantizer outputs
            codes : Tensor[B x N x T]
                Codebook indices of the N quantizers that ran (all of them in
                training mode, ``n_quantizers`` in eval mode)
            latents : Tensor[B x (sum of their codebook_dim) x T]
                Concatenated projected latents before quantization
            commitment_loss : Tensor[]
                Commitment loss summed over quantizers
            codebook_loss : Tensor[]
                Codebook loss summed over quantizers

        Raises
        ------
        ValueError
            In eval mode, if ``n_quantizers`` is less than 1.
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # n_codebooks + 1 keeps every per-item mask below True; the first
            # n_dropout items are then given a random count in [1, n_codebooks].
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)
        elif n_quantizers < 1:
            raise ValueError(f"n_quantizers must be >= 1, got {n_quantizers}")

        for i, quantizer in enumerate(self.quantizers):
            if not self.training and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Per-item mask: True where this item still uses quantizer i.
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            # The residual is updated for every item, masked or not.
            residual = residual - z_q_i

            # NOTE(review): VectorQuantize.forward returns scalar (already
            # batch-averaged) losses, so this only scales each loss by the
            # fraction of items using quantizer i; dropped items' errors are
            # still included. TODO confirm this is intended.
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation.

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Codebook indices for the first N quantizers

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B x (sum of codebook_dim) x T]
            Concatenated codebook vectors of the first N quantizers
        Tensor[B x N x T]
            The input ``codes``, unchanged

        Raises
        ------
        ValueError
            If N is 0 or exceeds the number of quantizers.
        """
        n_codebooks = codes.shape[1]
        if not 1 <= n_codebooks <= len(self.quantizers):
            raise ValueError(
                f"codes must cover 1..{len(self.quantizers)} codebooks, "
                f"got {n_codebooks}"
            )

        z_q = 0.0
        z_p = []
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)
            z_q = z_q + self.quantizers[i].out_proj(z_p_i)
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the continuous
        representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x C x T]
            Concatenated projected latents (as returned by ``forward``).
            Only the leading channels that form whole codebook slices are
            used; any trailing channels are ignored.

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x (sum of codebook_dim) x T]
            Quantized representation of latent space (codebook vectors)
        Tensor[B x N x T]
            Codebook indices for the N codebooks used

        Raises
        ------
        ValueError
            If ``latents`` does not contain even one full codebook slice.
        """
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # dims is non-decreasing, so the number of entries <= C minus one is
        # the number of whole codebook slices that fit in C channels.
        n_codebooks = int((dims <= latents.shape[1]).sum()) - 1
        if n_codebooks < 1:
            raise ValueError(
                f"latents with {latents.shape[1]} channels do not cover a "
                "single codebook"
            )

        z_q = 0
        z_p = []
        codes = []
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)
            z_q = z_q + self.quantizers[i].out_proj(z_p_i)

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
|
|
|
if __name__ == "__main__":
    # Smoke test. A fresh nn.Module is in training mode, so quantizer dropout
    # is active; 1.0 applies it to every item in the batch.
    rvq = ResidualVectorQuantize(quantizer_dropout=1.0)
    x = torch.randn(16, 512, 80)
    # forward returns a tuple, not a dict.
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
|
|