|
|
from typing import List, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
from vector_quantize_pytorch import VectorQuantize as torchVQ |
|
|
|
|
|
|
|
|
def sample_vectors(samples, num):
    """Pick `num` rows from `samples` at random.

    Sampling is without replacement when enough rows exist, otherwise with
    replacement.  The result is always cast to float32.
    """
    available = samples.shape[0]
    dev = samples.device
    if available >= num:
        chosen = torch.randperm(available, device=dev)[:num]
    else:
        # Fewer rows than requested: draw with replacement.
        chosen = torch.randint(0, available, (num,), device=dev)
    return samples[chosen].float()
|
|
|
|
|
|
|
|
def ema_inplace(moving_avg, new, decay):
    """In-place exponential moving average: avg = decay * avg + (1 - decay) * new."""
    alpha = 1 - decay
    moving_avg.data.mul_(decay).add_(new.float(), alpha=alpha)
|
|
|
|
|
|
|
|
def kmeans(samples, num_clusters, num_iters=10):
    """Lloyd's k-means in float32.

    Parameters
    ----------
    samples : Tensor[N x D]
        Points to cluster (any float dtype; cast to float32 internally).
    num_clusters : int
        Number of cluster centers.
    num_iters : int
        Number of Lloyd iterations.

    Returns
    -------
    (means, bins) : Tensor[num_clusters x D], Tensor[num_clusters]
        Final cluster centers and the float count of points assigned to each
        center under the final assignment.
    """
    # Fixes over the original: the dead `dim, _ = samples.shape[-1],
    # torch.float32` assignment is gone, the distance expression is no longer
    # duplicated, and `samples` is cast to float32 once instead of per use.
    dim = samples.shape[-1]
    samples = samples.float()
    means = sample_vectors(samples, num_clusters).float()

    def neg_sq_dists(centers):
        # Negative squared Euclidean distance, so argmax gives the nearest
        # center: -(|x|^2 - 2 x.c + |c|^2).
        return -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * samples @ centers.t()
            + centers.t().pow(2).sum(0, keepdim=True)
        )

    for _ in range(num_iters):
        buckets = neg_sq_dists(means).max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid division by zero for empty clusters.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        # Empty clusters keep their previous center.
        means = torch.where(zero_mask[..., None], means, new_means)

    # Final assignment counts under the converged centers.
    buckets = neg_sq_dists(means).max(dim=-1).indices
    bins = torch.bincount(buckets, minlength=num_clusters).float()

    return means, bins
|
|
|
|
|
|
|
|
class VectorQuantize(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
input_dim, |
|
|
codebook_size, |
|
|
codebook_dim, |
|
|
commitment=1.0, |
|
|
decay=0.99, |
|
|
epsilon=1e-5, |
|
|
threshold_ema_dead=2, |
|
|
kmeans_init=True, |
|
|
kmeans_iters=10, |
|
|
rotation_trick=False, |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__() |
|
|
self.input_dim = input_dim |
|
|
self.codebook_size = codebook_size |
|
|
self.codebook_dim = codebook_dim |
|
|
self.commitment = commitment |
|
|
self.decay = decay |
|
|
self.epsilon = epsilon |
|
|
self.threshold_ema_dead = threshold_ema_dead |
|
|
self.kmeans_init = kmeans_init |
|
|
self.kmeans_iters = kmeans_iters |
|
|
self.rotation_trick = rotation_trick |
|
|
|
|
|
if self.input_dim != self.codebook_dim: |
|
|
self.in_project = nn.Linear(input_dim, codebook_dim) |
|
|
self.out_project = nn.Linear(codebook_dim, input_dim) |
|
|
else: |
|
|
self.in_project = nn.Identity() |
|
|
self.out_project = nn.Identity() |
|
|
|
|
|
|
|
|
init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y) |
|
|
self.register_buffer( |
|
|
"codebook", init_fn(codebook_size, codebook_dim).float() |
|
|
) |
|
|
self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) |
|
|
self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) |
|
|
self.register_buffer("embed_avg", self.codebook.clone().float()) |
|
|
|
|
|
def ema_update(self, encodings, embed_onehot): |
|
|
|
|
|
"""Update codebook using EMA""" |
|
|
encodings = encodings.float() |
|
|
embed_onehot = embed_onehot.float() |
|
|
cluster_size_new = embed_onehot.sum(0) |
|
|
embed_sum = encodings.t() @ embed_onehot |
|
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM) |
|
|
dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM) |
|
|
|
|
|
ema_inplace(self.cluster_size, cluster_size_new, self.decay) |
|
|
ema_inplace(self.embed_avg, embed_sum.t(), self.decay) |
|
|
|
|
|
|
|
|
cluster_size = (self.cluster_size + self.epsilon) / ( |
|
|
self.cluster_size.sum() + self.codebook_size * self.epsilon |
|
|
) |
|
|
cluster_size = cluster_size * self.cluster_size.sum() |
|
|
self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) |
|
|
|
|
|
def replace_dead_codes(self, encodings): |
|
|
|
|
|
"""Replace dead codes with random samples from current batch""" |
|
|
if self.threshold_ema_dead == 0: |
|
|
return |
|
|
|
|
|
dead_mask = self.cluster_size < self.threshold_ema_dead |
|
|
if dead_mask.any(): |
|
|
if dist.is_initialized() and dist.get_rank() == 0: |
|
|
samples = sample_vectors(encodings.float(), self.codebook_size) |
|
|
print(f"Replace {dead_mask.sum().item()} dead codes") |
|
|
else: |
|
|
samples = torch.zeros_like(self.codebook).float() |
|
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
dist.broadcast(samples, src=0) |
|
|
|
|
|
self.codebook[dead_mask] = samples[: dead_mask.sum()].to(self.codebook.dtype) |
|
|
|
|
|
def init_codebook(self, encodings): |
|
|
|
|
|
"""Initialize codebook with k-means and update cluster_size""" |
|
|
if self.inited.item(): |
|
|
return |
|
|
|
|
|
if dist.is_initialized() and dist.get_rank() == 0: |
|
|
embed, cluster_sizes = kmeans( |
|
|
encodings.float(), self.codebook_size, self.kmeans_iters |
|
|
) |
|
|
else: |
|
|
embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float() |
|
|
cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32) |
|
|
|
|
|
|
|
|
if dist.is_initialized(): |
|
|
dist.broadcast(embed, src=0) |
|
|
dist.broadcast(cluster_sizes, src=0) |
|
|
|
|
|
self.codebook.copy_(embed) |
|
|
self.embed_avg.copy_(embed.clone()) |
|
|
self.cluster_size.copy_(cluster_sizes.float()) |
|
|
self.inited.fill_(True) |
|
|
|
|
|
def forward(self, z): |
|
|
self = self.to(torch.float32) |
|
|
z = z.float() |
|
|
z_e = self.in_project(z).float() |
|
|
|
|
|
|
|
|
encodings = rearrange(z_e, "b t d -> (b t) d").float() |
|
|
|
|
|
|
|
|
if self.kmeans_init and not self.inited.item(): |
|
|
self.init_codebook(encodings) |
|
|
|
|
|
dist = ( |
|
|
encodings.pow(2).sum(1, keepdim=True) |
|
|
- 2 * encodings @ self.codebook.float().t() |
|
|
+ self.codebook.float().pow(2).sum(1, keepdim=True).t() |
|
|
) |
|
|
indices = (-dist).max(1)[1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
indices = rearrange(indices, "(b t) -> b t", b=z.size(0)) |
|
|
z_q = self.decode_code(indices).float() |
|
|
commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment |
|
|
|
|
|
if self.training and torch.is_grad_enabled(): |
|
|
embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float() |
|
|
self.ema_update(encodings, embed_onehot) |
|
|
self.replace_dead_codes(encodings) |
|
|
|
|
|
z_q = (z_q - z_e).detach() + z_e |
|
|
z_q = self.out_project(z_q).float() |
|
|
|
|
|
return ( |
|
|
z_q, |
|
|
commit_loss, |
|
|
torch.tensor(0.0, device=z.device, dtype=torch.float32), |
|
|
indices, |
|
|
z_e, |
|
|
) |
|
|
|
|
|
def decode_code(self, embed_id): |
|
|
return F.embedding(embed_id, self.codebook).float() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualVectorQuantize(nn.Module):
    """Residual vector quantization: a cascade of `VectorQuantize` layers,
    each quantizing the residual left by the previous one (SoundStream/DAC
    style), with per-sample quantizer dropout during training."""

    def __init__(
        self,
        dim: int = 256,
        n_codebooks: int = 4,
        codebook_size: int = 512,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.25,
        commitment: float = 0.25,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead: int = 2,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        rotation_trick: bool = False,
    ):
        super().__init__()
        # A single int is broadcast to every codebook.
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim[i],
                    commitment=commitment,
                    decay=decay,
                    epsilon=epsilon,
                    threshold_ema_dead=threshold_ema_dead,
                    kmeans_init=kmeans_init,
                    kmeans_iters=kmeans_iters,
                    rotation_trick=rotation_trick,
                )
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantize the input with up to `n` cascaded codebooks and return
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor
            Input features (passed straight to the inner `VectorQuantize`,
            which rearranges as "b t d" — so (B, T, D) is expected here;
            the original docstring's B x D x T claim did not match the code).
        n_quantizers : int, optional
            Number of quantizers to use (ex: for quantizer dropout).
            Ignored in training mode when `self.quantizer_dropout` is set,
            where a random per-sample count is drawn instead.

        Returns
        -------
        tuple
            (z_q, codes, latents, commitment_loss, codebook_loss) —
            z_q : quantized continuous representation (same shape as `z`);
            codes : Tensor[B x T x N] codebook indices, stacked on the last dim;
            latents : concatenated pre-quantization projections;
            commitment_loss : scalar encoder commitment loss;
            codebook_loss : scalar (always 0 here; EMA-trained codebooks).
        """
        z_q, residual = 0, z
        commitment_loss, codebook_loss = 0, 0

        codebook_indices, latents = [], []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Quantizer dropout: the first n_dropout samples of the batch use
            # a random number of quantizers; the rest use all of them
            # (n_codebooks + 1 makes the mask below all-True for them).
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # In eval mode n_quantizers is a plain int cap.
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)

            # Sample b contributes quantizer i only when i < n_quantizers[b].
            mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=-1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation.

        Parameters
        ----------
        codes : Tensor[B x T x N]
            Codebook indices as stacked by `forward` (last dim indexes the
            codebook).

        Returns
        -------
        tuple
            (z_q, z_p, codes) — summed quantized representation, the
            concatenated per-codebook latents, and the codes passed in.
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[-1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[..., i])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_project(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=-1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the continuous
        representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection.

        Returns
        -------
        tuple
            (z_q, z_p, codes) — quantized full-projected representation,
            concatenated latent-space quantizations, and stacked codes.
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Use only the codebooks whose latent slices fit in the input.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            # NOTE(review): `decode_latents` is not defined on VectorQuantize
            # in this file — confirm it exists upstream before relying on this.
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            # FIX: the original called `out_proj`, but VectorQuantize defines
            # `out_project` — this raised AttributeError on every call.
            z_q_i = self.quantizers[i].out_project(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
|
|
|
|
|
|
|
class IndependentVectorQuantize(nn.Module):
    """Applies one independent `vector_quantize_pytorch.VectorQuantize` per
    head along dim 1 of the input.

    `forward` expects x of shape (B, num_codebooks, dim) (see the assert and
    the `__main__` usage); each slice x[:, i:i+1] is quantized by its own VQ
    module and the results are concatenated back.
    """

    def __init__(self, num_codebooks: int = 1, **kwargs):
        super().__init__()
        # One independent quantizer per head; kwargs forwarded verbatim.
        self.vector_quantizers = nn.ModuleList(
            [torchVQ(**kwargs) for _ in range(num_codebooks)]
        )
        self.num_codebooks = num_codebooks
        self.codebook_size = self.vector_quantizers[0].codebook_size

    @property
    def ema_update(self):
        # Delegates to the wrapped quantizers' `ema_update` attributes.
        return [vq.ema_update for vq in self.vector_quantizers]

    @property
    def codebook(self):
        # Stack of per-head codebooks along a new leading dim.
        return torch.stack([vq.codebook for vq in self.vector_quantizers], dim=0)

    @codebook.setter
    def codebook(self, codes: List[torch.Tensor]):
        # FIX: the original first checked `self.separate_codebook_per_head`,
        # an attribute never set on this class, so every setter call raised
        # AttributeError.  The assert already guarantees one codebook per
        # head, so that branch was removed.
        assert len(codes) == self.num_codebooks, "Number of codebooks must match"
        for i, code in enumerate(codes):
            self.vector_quantizers[i].codebook.copy_(code)

    def get_codes_from_indices(self, indices: torch.Tensor):
        """Gather per-head code vectors; `indices` last dim indexes the heads."""
        codes = [
            self.vector_quantizers[i].get_codes_from_indices(indices[..., i : i + 1])
            for i in range(self.num_codebooks)
        ]
        return torch.cat(codes, dim=-2)

    def get_output_from_indices(self, indices: torch.Tensor):
        """Reconstruct per-head outputs; `indices` last dim indexes the heads."""
        outputs = [
            self.vector_quantizers[i].get_output_from_indices(indices[..., i : i + 1])
            for i in range(self.num_codebooks)
        ]
        return torch.cat(outputs, dim=-2)

    def update_in_place_optimizer(self):
        """Run each wrapped quantizer's in-place codebook optimizer step."""
        for i in range(self.num_codebooks):
            self.vector_quantizers[i].update_in_place_optimizer()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Quantize each head independently.

        Returns (quantized, indices, commit_loss) where `quantized` matches
        the input shape, `indices` concatenates per-head indices on the last
        dim, and `commit_loss` is the mean of the per-head commitment losses.
        """
        assert x.shape[1] == self.num_codebooks
        quantized, indices, commit_losses = list(), list(), 0
        for i in range(self.num_codebooks):
            quantized_i, indices_i, commit_loss_i = self.vector_quantizers[i](x[:, i : i + 1])
            quantized.append(quantized_i)
            indices.append(indices_i)
            commit_losses += commit_loss_i
        quantized = torch.cat(quantized, dim=-2)
        indices = torch.cat(indices, dim=-1)
        return quantized, indices, commit_losses / self.num_codebooks
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: 16 independent 2048-entry codebooks over 256-dim features.
    quantizer = IndependentVectorQuantize(
        num_codebooks=16,
        dim=256,
        codebook_size=2048,
        decay=0.8,
        commitment_weight=1.0,
    )

    features = torch.randn(1, 16, 256)
    quantized, indices, commit_loss = quantizer(features)
|
|
|