| | from typing import List, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from vector_quantize_pytorch import VectorQuantize as torchVQ |
| |
|
| |
|
def sample_vectors(samples, num):
    """Draw `num` rows from `samples` and return them as float32.

    Uses a random permutation (no duplicates) when enough rows exist,
    otherwise samples indices with replacement.
    """
    pool_size, device = samples.shape[0], samples.device
    picker = (
        torch.randperm(pool_size, device=device)[:num]
        if pool_size >= num
        else torch.randint(0, pool_size, (num,), device=device)
    )
    return samples[picker].float()
| |
|
| |
|
def ema_inplace(moving_avg, new, decay):
    """Update exponential moving average in-place.

    Computes ``moving_avg <- decay * moving_avg + (1 - decay) * new``.
    """
    contribution = new.float()
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(contribution, alpha=1 - decay)
| |
|
| |
|
def _neg_sq_distances(samples, means):
    """Negative squared Euclidean distances between every sample row and
    every mean row; shape (n_samples, n_clusters). Negated so that argmax
    selects the nearest cluster."""
    return -(
        samples.pow(2).sum(1, keepdim=True)
        - 2 * samples @ means.t()
        + means.t().pow(2).sum(0, keepdim=True)
    )


def kmeans(samples, num_clusters, num_iters=10):
    """Run Lloyd's k-means on `samples` in float32.

    Parameters
    ----------
    samples : Tensor (n, dim)
    num_clusters : number of centroids
    num_iters : number of update iterations

    Returns
    -------
    means : Tensor (num_clusters, dim), float32 centroids
    bins : Tensor (num_clusters,), float32 final cluster populations
    """
    # Convert once up front instead of re-converting in every iteration.
    samples = samples.float()
    dim = samples.shape[-1]
    means = sample_vectors(samples, num_clusters).float()

    for _ in range(num_iters):
        buckets = _neg_sq_distances(samples, means).max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Clamp empty bins to 1 to avoid division by zero; their centroids
        # are restored from the previous iteration below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32)
        new_means.scatter_add_(
            0, buckets.unsqueeze(1).expand(-1, dim), samples
        )
        new_means = new_means / bins_min_clamped[..., None]
        # Empty clusters keep their old centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    # Final assignment with the converged centroids to report populations.
    buckets = _neg_sq_distances(samples, means).max(dim=-1).indices
    bins = torch.bincount(buckets, minlength=num_clusters).float()

    return means, bins
| |
|
| |
|
class VectorQuantize(nn.Module):
    """EMA vector quantizer (VQ-VAE style).

    Projects inputs to `codebook_dim` when it differs from `input_dim`,
    snaps each vector to its nearest codebook entry, and maintains the
    codebook with exponential moving averages rather than gradients.

    Parameters
    ----------
    input_dim : dimensionality of incoming features
    codebook_size : number of codebook entries
    codebook_dim : dimensionality of codebook entries
    commitment : weight of the commitment loss
    decay : EMA decay for codebook updates
    epsilon : Laplace-smoothing constant for cluster sizes
    threshold_ema_dead : EMA cluster size below which a code is considered
        dead and replaced from the current batch (0 disables replacement)
    kmeans_init : initialize the codebook via k-means on the first batch
    kmeans_iters : k-means iterations for that initialization
    rotation_trick : stored for interface compatibility; unused here
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=1.0,
        decay=0.99,
        epsilon=1e-5,
        threshold_ema_dead=2,
        kmeans_init=True,
        kmeans_iters=10,
        rotation_trick=False,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.decay = decay
        self.epsilon = epsilon
        self.threshold_ema_dead = threshold_ema_dead
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.rotation_trick = rotation_trick

        if self.input_dim != self.codebook_dim:
            self.in_project = nn.Linear(input_dim, codebook_dim)
            self.out_project = nn.Linear(codebook_dim, input_dim)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        # With k-means init the codebook starts at zero and is filled on the
        # first training batch; otherwise it starts from a random normal.
        init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
        self.register_buffer(
            "codebook", init_fn(codebook_size, codebook_dim).float()
        )
        self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool))
        self.register_buffer("cluster_size", torch.zeros(codebook_size).float())
        self.register_buffer("embed_avg", self.codebook.clone().float())

    @staticmethod
    def _is_primary():
        """True on the process that should compute shared state: rank 0 when
        torch.distributed is initialized, and always in single-process runs."""
        return (not dist.is_initialized()) or dist.get_rank() == 0

    def ema_update(self, encodings, embed_onehot):
        """Update codebook buffers from one batch using EMA statistics.

        encodings : (N, codebook_dim) flattened encoder outputs
        embed_onehot : (N, codebook_size) one-hot code assignments
        """
        encodings = encodings.float()
        embed_onehot = embed_onehot.float()
        cluster_size_new = embed_onehot.sum(0)
        embed_sum = encodings.t() @ embed_onehot

        # Aggregate batch statistics across workers before the EMA step.
        if dist.is_initialized():
            dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
            dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)

        ema_inplace(self.cluster_size, cluster_size_new, self.decay)
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)

        # Laplace-smooth the cluster sizes so empty clusters never divide
        # by zero when normalizing the running sums into centroids.
        cluster_size = (self.cluster_size + self.epsilon) / (
            self.cluster_size.sum() + self.codebook_size * self.epsilon
        )
        cluster_size = cluster_size * self.cluster_size.sum()
        self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))

    def replace_dead_codes(self, encodings):
        """Replace dead codes with random samples from current batch."""
        if self.threshold_ema_dead == 0:
            return

        dead_mask = self.cluster_size < self.threshold_ema_dead
        if dead_mask.any():
            # BUGFIX: the sampling branch previously ran only when
            # torch.distributed was initialized AND rank == 0, so
            # single-process runs replaced dead codes with zeros.
            if self._is_primary():
                samples = sample_vectors(encodings.float(), self.codebook_size)
                print(f"Replace {dead_mask.sum().item()} dead codes")
            else:
                samples = torch.zeros_like(self.codebook).float()

            # Non-primary ranks receive the replacement vectors from rank 0.
            if dist.is_initialized():
                dist.broadcast(samples, src=0)

            self.codebook[dead_mask] = samples[: dead_mask.sum()].to(self.codebook.dtype)

    def init_codebook(self, encodings):
        """Initialize codebook with k-means and update cluster_size."""
        if self.inited.item():
            return

        # BUGFIX: k-means previously ran only when torch.distributed was
        # initialized AND rank == 0, leaving a zero codebook (with `inited`
        # set True anyway) in single-process runs.
        if self._is_primary():
            embed, cluster_sizes = kmeans(
                encodings.float(), self.codebook_size, self.kmeans_iters
            )
        else:
            embed = torch.zeros(
                self.codebook_size, self.codebook_dim, device=encodings.device
            ).float()
            cluster_sizes = torch.zeros(
                self.codebook_size, device=encodings.device, dtype=torch.float32
            )

        if dist.is_initialized():
            dist.broadcast(embed, src=0)
            dist.broadcast(cluster_sizes, src=0)

        self.codebook.copy_(embed)
        self.embed_avg.copy_(embed.clone())
        self.cluster_size.copy_(cluster_sizes.float())
        self.inited.fill_(True)

    def forward(self, z):
        """Quantize `z` and return
        (z_q, commit_loss, codebook_loss, indices, z_e).

        z : Tensor — rearranged as "b t d", so assumes (batch, time, input_dim);
            TODO confirm against callers, the RVQ docstring says B x D x T.
        codebook_loss is always 0 here: the codebook is updated by EMA,
        not by a gradient loss.
        """
        self = self.to(torch.float32)
        z = z.float()
        z_e = self.in_project(z).float()

        encodings = rearrange(z_e, "b t d -> (b t) d").float()

        # Lazily initialize the codebook from the first batch seen.
        if self.kmeans_init and not self.inited.item():
            self.init_codebook(encodings)

        # Squared Euclidean distance to every codebook entry.
        # (Renamed from `dist`, which shadowed the torch.distributed import.)
        sq_dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ self.codebook.float().t()
            + self.codebook.float().pow(2).sum(1, keepdim=True).t()
        )
        indices = (-sq_dist).max(1)[1]

        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))
        z_q = self.decode_code(indices).float()
        # Commitment loss pulls encoder outputs toward the (frozen) codes.
        commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment

        if self.training and torch.is_grad_enabled():
            embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()
            self.ema_update(encodings, embed_onehot)
            self.replace_dead_codes(encodings)

        # Straight-through estimator: gradients flow to z_e unchanged.
        z_q = (z_q - z_e).detach() + z_e
        z_q = self.out_project(z_q).float()

        return (
            z_q,
            commit_loss,
            torch.tensor(0.0, device=z.device, dtype=torch.float32),
            indices,
            z_e,
        )

    def decode_code(self, embed_id):
        """Look up the codebook vectors for integer code indices."""
        return F.embedding(embed_id, self.codebook).float()
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| |
|
| | |
| |
|
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
class ResidualVectorQuantize(nn.Module):
    """Residual vector quantizer: a cascade of `n_codebooks` VectorQuantize
    layers, each quantizing the residual left by the previous layer, with
    optional quantizer dropout during training."""

    def __init__(
        self,
        dim: int = 256,
        n_codebooks: int = 4,
        codebook_size: int = 512,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.25,
        commitment: float = 0.25,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead: int = 2,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        rotation_trick: bool = False,
    ):
        super().__init__()
        # A single int applies the same codebook dim to every stage.
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim[i],
                    commitment=commitment,
                    decay=decay,
                    epsilon=epsilon,
                    threshold_ema_dead=threshold_ema_dead,
                    kmeans_init=kmeans_init,
                    kmeans_iters=kmeans_iters,
                    rotation_trick=rotation_trick,
                )
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantize the input tensor with a fixed set of `n` codebooks and
        return the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.

        Returns
        -------
        tuple
            z_q : Tensor
                Quantized continuous representation of input
            codes : Tensor
                Codebook indices per codebook, stacked along the last dim
            latents : Tensor
                Projected latents (continuous representation of input
                before quantization), concatenated along dim 1
            commitment_loss : Tensor[1]
                Commitment loss to train the encoder to predict vectors
                closer to codebook entries
            codebook_loss : Tensor[1]
                Always 0 here; the codebooks are updated by EMA
        """
        z_q, residual = 0, z
        commitment_loss, codebook_loss = 0, 0

        codebook_indices, latents = [], []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Quantizer dropout: a random fraction of the batch trains with
            # fewer codebooks; the rest uses all of them (n_codebooks + 1
            # keeps the `i < n_quantizers` mask true for every stage).
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # At inference, simply truncate the cascade.
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)

            # Mask out stages beyond each sample's quantizer budget.
            mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=-1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation.

        Parameters
        ----------
        codes : Tensor[B x T x N]
            Quantized discrete representation of input; the last dim indexes
            the codebooks (as stacked by `forward`)

        Returns
        -------
        tuple
            Summed quantized continuous representation, the concatenated
            per-codebook latents, and the codes passed in
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[-1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[..., i])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_project(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=-1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        # Cumulative offsets of each quantizer's slice in the latent dim.
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Use only as many codebooks as the provided latents cover.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            # NOTE(review): `decode_latents` is not defined on the local
            # VectorQuantize class — this call raises AttributeError as-is;
            # confirm whether the method was dropped during a refactor.
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            # BUGFIX: was `out_proj`; the quantizer defines `out_project`.
            z_q_i = self.quantizers[i].out_project(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
| |
|
| |
|
class IndependentVectorQuantize(nn.Module):
    """A bank of independent `vector_quantize_pytorch.VectorQuantize` modules.

    Dim 1 of the input indexes the codebook/head; each head is quantized by
    its own wrapped quantizer and the results are concatenated back.
    """

    def __init__(self, num_codebooks: int = 1, **kwargs):
        super().__init__()
        # kwargs are forwarded verbatim to every wrapped quantizer.
        self.vector_quantizers = nn.ModuleList([torchVQ(**kwargs) for _ in range(num_codebooks)])
        self.num_codebooks = num_codebooks
        self.codebook_size = self.vector_quantizers[0].codebook_size

    @property
    def ema_update(self):
        # Per-quantizer EMA settings, in codebook order.
        return [vq.ema_update for vq in self.vector_quantizers]

    @property
    def codebook(self):
        # Stack of the individual codebooks along a new leading dim.
        return torch.stack([vq.codebook for vq in self.vector_quantizers], dim=0)

    @codebook.setter
    def codebook(self, codes: List[torch.Tensor]):
        assert len(codes) == self.num_codebooks, "Number of codebooks must match"
        # BUGFIX: `separate_codebook_per_head` lives on the wrapped torchVQ
        # modules, not on this container (previously an AttributeError).
        # NOTE(review): presumably all heads share this flag — verify.
        if not self.vector_quantizers[0].separate_codebook_per_head:
            codes = rearrange(codes, "... -> 1 ...")

        for i, code in enumerate(codes):
            self.vector_quantizers[i].codebook.copy_(code)

    def get_codes_from_indices(self, indices: torch.Tensor):
        """Look up per-head code vectors for the given indices; last input dim
        indexes the head."""
        codes = list()
        for i in range(self.num_codebooks):
            codes.append(self.vector_quantizers[i].get_codes_from_indices(indices[..., i : i + 1]))
        return torch.cat(codes, dim=-2)

    def get_output_from_indices(self, indices: torch.Tensor):
        """Reconstruct each head's quantizer output from indices and
        concatenate across heads."""
        outputs = list()
        for i in range(self.num_codebooks):
            outputs.append(self.vector_quantizers[i].get_output_from_indices(indices[..., i : i + 1]))
        return torch.cat(outputs, dim=-2)

    def update_in_place_optimizer(self):
        """Forward the in-place optimizer step to every wrapped quantizer."""
        for i in range(self.num_codebooks):
            self.vector_quantizers[i].update_in_place_optimizer()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Quantize each head independently.

        x : Tensor with dim 1 equal to `num_codebooks`.
        Returns (quantized, indices, mean commitment loss across heads).
        """
        assert x.shape[1] == self.num_codebooks
        quantized, indices, commit_losses = list(), list(), 0
        for i in range(self.num_codebooks):
            quantized_i, indices_i, commit_loss_i = self.vector_quantizers[i](x[:, i : i + 1])
            quantized.append(quantized_i)
            indices.append(indices_i)
            commit_losses += commit_loss_i
        quantized = torch.cat(quantized, dim=-2)
        indices = torch.cat(indices, dim=-1)
        return quantized, indices, commit_losses / self.num_codebooks
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: 16 independent heads of dim 256, 2048 codes each.
    quantizer = IndependentVectorQuantize(
        num_codebooks=16,
        dim=256,
        codebook_size=2048,
        decay=0.8,
        commitment_weight=1.0,
    )

    dummy = torch.randn(1, 16, 256)
    quantized, indices, commit_loss = quantizer(dummy)
| |
|