# Credit to: lucidrains vector_quantization
import typing as tp

from einops import rearrange, repeat
import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(val: tp.Optional[tp.Any]) -> bool:
    """Return True when ``val`` is not None."""
    return val is not None


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return ``val`` unless it is None, in which case return ``d``."""
    return val if exists(val) else d


def l2norm(t):
    """L2-normalize ``t`` along its last dimension."""
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay: float):
    """Update ``moving_avg`` in place to ``decay * moving_avg + (1 - decay) * new``.

    The update is applied through ``moving_avg.data``.
    """
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Return ``(x + epsilon) / (x.sum() + n_categories * epsilon)``."""
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    """Return a new tensor of the given shape filled by ``kaiming_uniform_``."""
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    """Pick ``num`` rows of ``samples`` at random.

    When at least ``num`` rows are available they are drawn without
    replacement (``randperm``); otherwise row indices are drawn with
    replacement (``randint``).
    """
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run k-means on a 2-D ``(n, d)`` tensor of samples.

    Args:
        samples: Tensor of shape ``(n, d)``.
        num_clusters (int): Number of centroids.
        num_iters (int): Number of assignment/update rounds.

    Returns:
        Tuple ``(means, bins)``: the ``(num_clusters, d)`` centroids and the
        number of samples assigned to each centroid during the last round.
        With ``num_iters == 0`` the centroids are the randomly sampled
        initial ones and ``bins`` is all zeros.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)
    # Defined up front so that ``num_iters == 0`` still returns a valid
    # count tensor instead of failing on an unbound name.
    bins = torch.zeros(num_clusters, dtype=torch.long, device=samples.device)

    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        # Negative squared distance, so the max picks the nearest centroid.
        dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for clusters that received no samples.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


def orthogonal_loss_fn(t):
    """Orthogonality penalty over the rows of ``t``.

    Computes the mean squared difference between the cosine-similarity
    matrix of the rows and the identity matrix.
    """
    # eq (2) from https://arxiv.org/abs/2112.00384
    n = t.shape[0]
    normed_codes = l2norm(t)
    identity = torch.eye(n, device=t.device)
    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean nearest-neighbour lookup and EMA updates.

    Args:
        dim (int): Dimension of each code vector.
        codebook_size (int): Number of code vectors.
        kmeans_init (bool): When true, the codebook starts at zeros and is
            initialized with k-means on the first batch passed to ``forward``;
            otherwise it is initialized with ``uniform_init``.
        kmeans_iters (int): Number of k-means iterations for that initialization.
        decay (float): Decay of the exponential moving averages.
        epsilon (float): Epsilon used in the Laplace smoothing of cluster sizes.
        threshold_ema_dead_code (int): Codes whose EMA cluster size falls
            below this value are replaced by random vectors of the current
            batch. ``0`` disables the replacement.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # State is stored as parameters; only `embed` has requires_grad left on.
        self.inited = nn.Parameter(torch.Tensor([not kmeans_init]), requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(codebook_size), requires_grad=False)
        self.embed = nn.Parameter(embed)
        self.embed_avg = nn.Parameter(embed.clone(), requires_grad=False)

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codebook with k-means on ``data`` if not done yet."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))

    def replace_(self, samples, mask):
        """Overwrite the codes selected by ``mask`` with random rows of ``samples``."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose EMA cluster size is below the dead-code threshold."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)

    def preprocess(self, x):
        """Flatten all leading dimensions of ``x`` into one: ``(..., d) -> (n, d)``."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return, for each row of the 2-D ``x``, the index of the nearest code."""
        embed = self.embed.t()
        # Negative squared Euclidean distance, expanded as
        # -(|x|^2 - 2 x.e + |e|^2), so the max picks the nearest code.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape flat indices to ``shape`` without its last dimension."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up the code vectors for ``embed_ind``."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Map ``x`` of shape ``(..., d)`` to code indices of shape ``(...)``."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map code indices back to code vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize ``x`` and, in training mode, update the codebook.

        Returns:
            Tuple ``(quantize, embed_ind)``: the looked-up code vectors and
            their indices. Both are computed before the EMA update below
            is applied.
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        channels_last (bool): Channels are the last dimension in the input tensors.
        commitment_weight (float): Weight for commitment loss.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Project in and out only when the codebook lives in another dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        )

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size
        self.channels_last = channels_last

    @property
    def codebook(self):
        """The code vectors of the underlying codebook."""
        return self._codebook.embed

    @property
    def inited(self):
        """Initialization flag of the underlying codebook."""
        return self._codebook.inited

    def _preprocess(self, x):
        """Move channels last (``b d n -> b n d``) unless the input already is."""
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x

    def _postprocess(self, quantize):
        """Undo ``_preprocess`` (``b n d -> b d n``) unless channels are last."""
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def encode(self, x):
        """Return the code indices for ``x``."""
        x = self._preprocess(x)
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Return the quantized vectors for the code indices ``embed_ind``."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize

    def forward(self, x):
        """Quantize ``x``.

        Returns:
            Tuple ``(quantize, embed_ind, loss)``: the quantized tensor in the
            input layout, the code indices, and a one-element loss tensor that
            accumulates the weighted commitment and orthogonal-regularization
            terms.
        """
        device = x.device
        x = self._preprocess(x)

        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)

        if self.training:
            # Straight-through: forward value is `quantize`, gradient flows to `x`.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        # NOTE(review): the `if self.training:` guard around the loss terms is
        # disabled in this file, so they are also computed in eval mode;
        # confirm this is intended.
        if self.commitment_weight > 0:
            commit_loss = F.mse_loss(quantize.detach(), x)
            loss = loss + commit_loss * self.commitment_weight

        if self.orthogonal_reg_weight > 0:
            codebook = self.codebook

            if self.orthogonal_reg_active_codes_only:
                # only calculate orthogonal loss for the activated codes for this batch
                unique_code_ids = torch.unique(embed_ind)
                codebook = codebook[unique_code_ids]

            num_codes = codebook.shape[0]
            if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
                codebook = codebook[rand_ids]

            orthogonal_reg_loss = orthogonal_loss_fn(codebook)
            loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)

        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Chain of ``VectorQuantization`` layers, each quantizing the residual
    left by the previous ones.

    Args:
        num_quantizers: Number of ``VectorQuantization`` layers to create.
        **kwargs: Forwarded unchanged to every ``VectorQuantization``.
    """

    def __init__(self, num_quantizers, **kwargs):
        super().__init__()
        # Each layer is stored as its own attribute `vq_layer_{i}`.
        for i in range(num_quantizers):
            vq_layer = VectorQuantization(**kwargs)
            setattr(self, f"vq_layer_{i}", vq_layer)
            # The same Parameter object is additionally registered on this
            # module under the top-level name `embed_{i}`.
            self.register_parameter(f"embed_{i}", vq_layer._codebook.embed)
        self.num_quantizers = num_quantizers

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize ``x`` with the first ``n_q`` layers.

        Args:
            x: Input tensor.
            n_q: Number of layers to use; ``None`` or ``0`` uses all of them.

        Returns:
            Tuple ``(quantized_out, out_indices, out_losses)``: the sum of the
            per-layer quantized outputs, and the per-layer indices and losses
            stacked along a new first dimension.
        """
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or self.num_quantizers

        for i in range(n_q):
            layer = getattr(self, f"vq_layer_{i}")
            quantized, indices, loss = layer(residual)
            # Each layer's output is detached before it is subtracted from
            # the residual and accumulated.
            quantized = quantized.detach()
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)

        if self.training:
            # Straight-through on the accumulated output.
            quantized_out = x + (quantized_out - x).detach()

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Return the stacked code indices of the first ``n_q`` layers for ``x``.

        ``None`` or ``0`` for ``n_q`` uses all layers.
        """
        residual = x
        all_indices = []
        n_q = n_q or self.num_quantizers
        for i in range(n_q):
            layer = getattr(self, f"vq_layer_{i}")
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded vectors of every layer's indices.

        ``q_indices`` is iterated along its first dimension; entry ``i`` is
        decoded by layer ``i``.
        """
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = getattr(self, f"vq_layer_{i}")
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out