from typing import List, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import VectorQuantize as torchVQ
def sample_vectors(samples, num):
# samples: (N, D), num_samples: N, feature dim: D
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices].float() # (num, D), ensure fp32
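# Illustrative sketch (not part of the module): drawing 4 rows from a (10, 3)
# matrix; sampling is without replacement when enough rows exist, otherwise
# with replacement.
#   feats = torch.randn(10, 3)
#   picked = sample_vectors(feats, 4)  # (4, 3), fp32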
def ema_inplace(moving_avg, new, decay):
# moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg
"""Update exponential moving average in-place"""
moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32
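# Worked example (illustrative): with decay=0.99, one step computes
# new_avg = 0.99 * old_avg + 0.01 * new, e.g.
#   avg = torch.tensor([1.0])
#   ema_inplace(avg, torch.tensor([2.0]), decay=0.99)  # avg becomes [1.01]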
def kmeans(samples, num_clusters, num_iters=10):
# samples: (N, D), N samples with D dimensions
    dim = samples.shape[-1]
means = sample_vectors(samples, num_clusters).float() # (num_clusters, D), ensure fp32
for _ in range(num_iters):
        # Negative squared euclidean distance to each centroid: (N, num_clusters), fp32
        dists = -(
            samples.float().pow(2).sum(1, keepdim=True)  # (N, 1)
            - 2 * samples.float() @ means.t()  # (N, num_clusters)
            + means.t().float().pow(2).sum(0, keepdim=True)  # (1, num_clusters)
        )
        buckets = dists.max(dim=-1).indices  # (N,) index of nearest centroid
bins = torch.bincount(buckets, minlength=num_clusters) # (num_clusters)
zero_mask = bins == 0 # (num_clusters)
bins_min_clamped = bins.masked_fill(zero_mask, 1) # (num_clusters)
new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32) # (num_clusters, D), ensure fp32
new_means.scatter_add_(
0, buckets.unsqueeze(1).expand(-1, dim), samples.float()
) # (num_clusters, D), ensure fp32
new_means = new_means / bins_min_clamped[..., None] # (num_clusters, D)
        means = torch.where(zero_mask[..., None], means, new_means)  # keep the old centroid for empty clusters
# Final cluster assignments for returning cluster sizes
dists = -(
samples.float().pow(2).sum(1, keepdim=True)
- 2 * samples.float() @ means.t()
+ means.t().float().pow(2).sum(0, keepdim=True)
) # (N, num_clusters), ensure fp32
buckets = dists.max(dim=-1).indices # (N)
bins = torch.bincount(buckets, minlength=num_clusters).float() # (num_clusters), ensure fp32
return means, bins # (num_clusters, D), (num_clusters)
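# Illustrative sketch (not part of the module): cluster 1024 random 8-d vectors
# into 16 centroids; `sizes` counts how many samples landed in each cluster.
#   feats = torch.randn(1024, 8)
#   centroids, sizes = kmeans(feats, num_clusters=16, num_iters=5)
#   # centroids: (16, 8) fp32; sizes: (16,) fp32 with sizes.sum() == 1024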
class VectorQuantize(nn.Module):
def __init__(
self,
input_dim,
codebook_size,
codebook_dim,
commitment=1.0,
decay=0.99, # EMA decay
epsilon=1e-5, # Laplace smoothing epsilon
threshold_ema_dead=2, # Dead code threshold
kmeans_init=True, # Use kmeans initialization
kmeans_iters=10, # Kmeans iterations
rotation_trick=False, # Use rotation trick
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
self.decay = decay
self.epsilon = epsilon
self.threshold_ema_dead = threshold_ema_dead
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.rotation_trick = rotation_trick
if self.input_dim != self.codebook_dim:
self.in_project = nn.Linear(input_dim, codebook_dim)
self.out_project = nn.Linear(codebook_dim, input_dim)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
# Initialize codebook and EMA buffers
        init_fn = torch.zeros if kmeans_init else torch.randn
self.register_buffer(
"codebook", init_fn(codebook_size, codebook_dim).float()
) # (codebook_size, D'), ensure fp32
self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) # (1)
self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) # (codebook_size), ensure fp32
self.register_buffer("embed_avg", self.codebook.clone().float()) # (codebook_size, D'), ensure fp32
def ema_update(self, encodings, embed_onehot):
# encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
"""Update codebook using EMA"""
encodings = encodings.float() # Ensure fp32
embed_onehot = embed_onehot.float() # Ensure fp32
cluster_size_new = embed_onehot.sum(0) # (codebook_size)
embed_sum = encodings.t() @ embed_onehot # (D', codebook_size)
# Distributed reduction
if dist.is_initialized():
dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
ema_inplace(self.cluster_size, cluster_size_new, self.decay) # (codebook_size)
ema_inplace(self.embed_avg, embed_sum.t(), self.decay) # (codebook_size, D')
# Laplace smoothing
cluster_size = (self.cluster_size + self.epsilon) / (
self.cluster_size.sum() + self.codebook_size * self.epsilon
) # (codebook_size)
cluster_size = cluster_size * self.cluster_size.sum() # (codebook_size)
self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) # (codebook_size, D')
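    # Worked example (illustrative): with codebook_size=4, epsilon=1e-5 and raw
    # EMA counts [100, 0, 1, 99], the smoothed count for the empty bin is
    # ~epsilon (1e-5) instead of exactly 0, so the division in the codebook
    # update above can never hit zero while active bins stay essentially unchanged.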
def replace_dead_codes(self, encodings):
# encodings: (B*T, D')
"""Replace dead codes with random samples from current batch"""
if self.threshold_ema_dead == 0:
return
dead_mask = self.cluster_size < self.threshold_ema_dead # (codebook_size)
if dead_mask.any():
            # Rank 0 (or the single process when not distributed) draws replacements
            if not dist.is_initialized() or dist.get_rank() == 0:
                samples = sample_vectors(encodings.float(), self.codebook_size)  # (codebook_size, D'), fp32
                print(f"Replacing {dead_mask.sum().item()} dead codes")
            else:
                samples = torch.zeros_like(self.codebook).float()  # placeholder, filled by broadcast
# Broadcast samples
if dist.is_initialized():
dist.broadcast(samples, src=0)
self.codebook[dead_mask] = samples[: dead_mask.sum()].to(self.codebook.dtype) # Update dead codes
def init_codebook(self, encodings):
# encodings: (B*T, D')
"""Initialize codebook with k-means and update cluster_size"""
if self.inited.item():
return
        # Rank 0 (or the single process when not distributed) runs k-means
        if not dist.is_initialized() or dist.get_rank() == 0:
            embed, cluster_sizes = kmeans(
                encodings.float(), self.codebook_size, self.kmeans_iters
            )  # (codebook_size, D'), (codebook_size,), fp32
        else:
            embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float()
            cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32)
# Broadcast results
if dist.is_initialized():
dist.broadcast(embed, src=0)
dist.broadcast(cluster_sizes, src=0)
self.codebook.copy_(embed) # (codebook_size, D')
self.embed_avg.copy_(embed.clone()) # (codebook_size, D')
self.cluster_size.copy_(cluster_sizes.float()) # (codebook_size)
self.inited.fill_(True)
def forward(self, z):
        # Run the quantizer in fp32 for numerical stability (Module.to is in-place)
        self.to(torch.float32)
        z = z.float()
z_e = self.in_project(z).float()
# Rearrange for quantization
encodings = rearrange(z_e, "b t d -> (b t) d").float() # (B*T, D'), ensure fp32
# Initialize codebook if needed
if self.kmeans_init and not self.inited.item():
self.init_codebook(encodings)
        # Squared euclidean distance to every code; named `dists` so we do not
        # shadow the `torch.distributed as dist` import.
        dists = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ self.codebook.float().t()
            + self.codebook.float().pow(2).sum(1, keepdim=True).t()
        )  # (B*T, codebook_size)
        indices = (-dists).max(1)[1]  # nearest code per vector
# cosine_similarity = F.cosine_similarity(encodings[None], self.codebook[:, None], dim=-1)
# indices = cosine_similarity.max(dim=0)[1]
indices = rearrange(indices, "(b t) -> b t", b=z.size(0))
z_q = self.decode_code(indices).float()
commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment
if self.training and torch.is_grad_enabled():
embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()
self.ema_update(encodings, embed_onehot)
self.replace_dead_codes(encodings)
        z_q = (z_q - z_e).detach() + z_e  # straight-through estimator: gradients bypass the quantization
z_q = self.out_project(z_q).float()
return (
z_q,
commit_loss,
torch.tensor(0.0, device=z.device, dtype=torch.float32),
indices,
z_e,
)
    def decode_code(self, embed_id):  # embed_id: (B, T)
        return F.embedding(embed_id, self.codebook).float()  # (B, T, D'), fp32
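# Illustrative sketch (assumes the (B, T, D) input layout used throughout this file):
#   vq = VectorQuantize(input_dim=256, codebook_size=512, codebook_dim=8)
#   z = torch.randn(2, 50, 256)
#   z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
#   # z_q: (2, 50, 256); indices: (2, 50); z_e: (2, 50, 8)
#   # vq.decode_code(indices) recovers the (2, 50, 8) codebook vectors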
# class VectorQuantize(nn.Module):
# """
# Implementation of VQ similar to Karpathy's repo:
# https://github.com/karpathy/deep-vector-quantization
# Additionally uses following tricks from Improved VQGAN
# (https://arxiv.org/pdf/2110.04627.pdf):
# 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
# for improved codebook usage
# 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
# improves training stability
# """
# def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
# super().__init__()
# self.codebook_size = codebook_size
# self.codebook_dim = codebook_dim
# self.in_proj = nn.Linear(input_dim, codebook_dim)
# self.out_proj = nn.Linear(codebook_dim, input_dim)
# self.codebook = nn.Embedding(codebook_size, codebook_dim)
# def forward(self, z: torch.Tensor):
# """
# Args:
# z (torch.Tensor): shape (b, t, d)
# Returns:
# z_q (torch.Tensor): shape (b, t, d)
# commitment_loss (torch.Tensor): shape (1)
# codebook_loss (torch.Tensor): shape (1)
# indices (torch.Tensor): shape (b, t)
# z_e (torch.Tensor): shape (b, t, d)
# """
# # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
# z_e = self.in_proj(z)
# z_q, indices = self.decode_latents(z_e)
# commitment_loss = F.mse_loss(z_e, z_q.detach()) * 0.25
# codebook_loss = F.mse_loss(z_q, z_e.detach())
# z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
# z_q = self.out_proj(z_q)
# return z_q, commitment_loss, codebook_loss, indices, z_e
# def embed_code(self, embed_id):
# return F.embedding(embed_id, self.codebook.weight)
# def decode_code(self, embed_id):
# return self.embed_code(embed_id)
# def decode_latents(self, latents: torch.Tensor):
# codebook = self.codebook.weight
# encodings = rearrange(latents, "b t d -> (b t) d")
# cosine_similarity = F.cosine_similarity(encodings[None], codebook[:, None], dim=-1)
# indices = cosine_similarity.max(dim=0)[1]
# indices = rearrange(indices, "(b t) -> b t", b=latents.size(0))
# # encodings = F.normalize(encodings)
# # codebook = F.normalize(codebook)
# # dist = (
# # encodings.pow(2).sum(1, keepdim=True)
# # - 2 * encodings @ codebook.t()
# # + codebook.pow(2).sum(1, keepdim=True).t()
# # )
# # indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
# z_q = self.decode_code(indices)
# return z_q, indices
class ResidualVectorQuantize(nn.Module):
def __init__(
self,
dim: int = 256,
n_codebooks: int = 4,
codebook_size: int = 512,
codebook_dim: Union[int, list] = 8,
quantizer_dropout: float = 0.25,
commitment: float = 0.25,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead: int = 2,
kmeans_init: bool = True,
kmeans_iters: int = 10,
rotation_trick: bool = False,
):
super().__init__()
if isinstance(codebook_dim, int):
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
self.n_codebooks = n_codebooks
self.codebook_dim = codebook_dim
self.codebook_size = codebook_size
self.quantizers = nn.ModuleList(
[
VectorQuantize(
input_dim=dim,
codebook_size=codebook_size,
codebook_dim=codebook_dim[i],
commitment=commitment,
decay=decay,
epsilon=epsilon,
threshold_ema_dead=threshold_ema_dead,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
rotation_trick=rotation_trick,
)
for i in range(n_codebooks)
]
)
self.quantizer_dropout = quantizer_dropout
def forward(self, z, n_quantizers: int = None):
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"vq/commitment_loss" : Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
"vq/codebook_loss" : Tensor[1]
Codebook loss to update the codebook
"""
z_q, residual = 0, z
commitment_loss, codebook_loss = 0, 0
codebook_indices, latents = [], []
if n_quantizers is None:
n_quantizers = self.n_codebooks
if self.training:
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
# Create mask to apply quantizer dropout
mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i
# Sum losses
commitment_loss += (commitment_loss_i * mask).mean()
codebook_loss += (codebook_loss_i * mask).mean()
codebook_indices.append(indices_i)
latents.append(z_e_i)
        codes = torch.stack(codebook_indices, dim=-1)  # (B, T, N)
        # Concatenate along the feature dim; each z_e_i is (B, T, D'_i)
        latents = torch.cat(latents, dim=-1)  # (B, T, sum(D'))
return z_q, codes, latents, commitment_loss, codebook_loss
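    # Illustrative sketch (not part of the module): residually quantize a
    # (B, T, D) latent with the default 8-d codebooks.
    #   rvq = ResidualVectorQuantize(dim=256, n_codebooks=4, codebook_size=512)
    #   z = torch.randn(2, 50, 256)
    #   z_q, codes, latents, commit_loss, codebook_loss = rvq(z)
    #   # z_q: (2, 50, 256); codes: (2, 50, 4); latents: (2, 50, 32)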
def from_codes(self, codes: torch.Tensor):
"""Given the quantized codes, reconstruct the continuous representation
Parameters
----------
codes : Tensor[B x N x T]
Quantized discrete representation of input
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
"""
z_q = 0.0
z_p = []
n_codebooks = codes.shape[-1]
for i in range(n_codebooks):
z_p_i = self.quantizers[i].decode_code(codes[..., i])
z_p.append(z_p_i)
z_q_i = self.quantizers[i].out_project(z_p_i)
z_q = z_q + z_q_i
return z_q, torch.cat(z_p, dim=-1), codes
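    # Illustrative sketch: round-trip through the discrete codes. In eval mode
    # (no quantizer dropout) z_q_hat matches the z_q returned by forward().
    #   z_q_hat, z_p, codes = rvq.from_codes(codes)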
    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x T x sum(D')]
            Continuous representation of the input after projection

        Returns
        -------
        z_q : Tensor[B x T x D]
            Quantized representation in the full projected space
        z_p : Tensor[B x T x sum(D')]
            Quantized representation in the latent space
        codes : Tensor[B x T x N]
            Codebook indices for each codebook
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
        # Number of codebooks covered by the provided latents
        n_codebooks = int(np.where(dims <= latents.shape[-1])[0].max())
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_e_i = latents[..., j:k]  # (B, T, D'_i)
            # Nearest-neighbor lookup against this quantizer's codebook
            flat = rearrange(z_e_i, "b t d -> (b t) d").float()
            codebook = self.quantizers[i].codebook.float()
            d = (
                flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ codebook.t()
                + codebook.pow(2).sum(1, keepdim=True).t()
            )
            codes_i = rearrange((-d).max(1)[1], "(b t) -> b t", b=latents.size(0))
            z_p_i = self.quantizers[i].decode_code(codes_i)  # (B, T, D'_i)
            z_p.append(z_p_i)
            codes.append(codes_i)
            z_q = z_q + self.quantizers[i].out_project(z_p_i)
        return z_q, torch.cat(z_p, dim=-1), torch.stack(codes, dim=-1)
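    # Illustrative sketch: re-quantize pre-projection latents directly, e.g. the
    # `latents` tensor returned by forward():
    #   z_q, z_p, codes = rvq.from_latents(latents)  # latents: (B, T, sum(D'))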
class IndependentVectorQuantize(nn.Module):
def __init__(self, num_codebooks: int = 1, **kwargs):
super().__init__()
self.vector_quantizers = nn.ModuleList([torchVQ(**kwargs) for _ in range(num_codebooks)])
self.num_codebooks = num_codebooks
self.codebook_size = self.vector_quantizers[0].codebook_size
@property
def ema_update(self):
return [vq.ema_update for vq in self.vector_quantizers]
@property
def codebook(self):
return torch.stack([vq.codebook for vq in self.vector_quantizers], dim=0)
@codebook.setter
def codebook(self, codes: List[torch.Tensor]):
assert len(codes) == self.num_codebooks, "Number of codebooks must match"
for i, code in enumerate(codes):
self.vector_quantizers[i].codebook.copy_(code)
def get_codes_from_indices(self, indices: torch.Tensor):
codes = list()
for i in range(self.num_codebooks):
codes.append(self.vector_quantizers[i].get_codes_from_indices(indices[..., i : i + 1]))
return torch.cat(codes, dim=-2)
def get_output_from_indices(self, indices: torch.Tensor):
outputs = list()
for i in range(self.num_codebooks):
outputs.append(self.vector_quantizers[i].get_output_from_indices(indices[..., i : i + 1]))
return torch.cat(outputs, dim=-2)
def update_in_place_optimizer(self):
for i in range(self.num_codebooks):
self.vector_quantizers[i].update_in_place_optimizer()
def forward(self, x: torch.Tensor, *args, **kwargs):
assert x.shape[1] == self.num_codebooks
quantized, indices, commit_losses = list(), list(), 0
for i in range(self.num_codebooks):
quantized_i, indices_i, commit_loss_i = self.vector_quantizers[i](x[:, i : i + 1])
quantized.append(quantized_i)
indices.append(indices_i)
commit_losses += commit_loss_i
quantized = torch.cat(quantized, dim=-2)
indices = torch.cat(indices, dim=-1)
return quantized, indices, commit_losses / self.num_codebooks
if __name__ == "__main__":
vq = IndependentVectorQuantize(
num_codebooks=16,
dim=256,
codebook_size=2048,
decay=0.8, # the exponential moving average decay, lower means the dictionary will change faster
commitment_weight=1.0, # the weight on the commitment loss
)
    x = torch.randn(1, 16, 256)  # (B, num_codebooks, dim): one vector per codebook
    quantized, indices, commit_loss = vq(x)  # (1, 16, 256), (1, 16), scalar
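
    # A parallel sketch for the residual quantizer defined above, following the
    # (B, T, D) layout assumed throughout this file:
    rvq = ResidualVectorQuantize(dim=256, n_codebooks=4, codebook_size=512)
    z = torch.randn(2, 50, 256)
    z_q, codes, latents, commit_loss, codebook_loss = rvq(z)
    print(z_q.shape, codes.shape, latents.shape)  # (2, 50, 256), (2, 50, 4), (2, 50, 32)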