VibeToken / vibetoken /quantizer /vector_quantizer.py
APGASU's picture
scripts
7bef20f verified
"""Vector Quantizer for VibeToken.
Simplified for inference-only use. Training-specific features removed.
Reference:
https://github.com/CompVis/taming-transformers
https://github.com/google-research/magvit
"""
from typing import Mapping, Text, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from torch.amp import autocast
class VectorQuantizer(nn.Module):
    """Vector Quantizer module for discrete tokenization.

    Maps each spatial position of a continuous latent tensor to its nearest
    entry (squared Euclidean distance) in a learned codebook and returns both
    the quantized tensor and the chosen codebook indices.
    """

    def __init__(
        self,
        codebook_size: int = 1024,
        token_size: int = 256,
        commitment_cost: float = 0.25,
        use_l2_norm: bool = False,
    ):
        """Initialize VectorQuantizer.

        Args:
            codebook_size: Number of entries in the codebook.
            token_size: Dimension of each codebook entry.
            commitment_cost: Weight applied to the commitment loss term
                returned by ``forward``.
            use_l2_norm: If True, inputs and codebook entries are
                L2-normalized along the channel dimension before matching.
        """
        super().__init__()
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost
        self.use_l2_norm = use_l2_norm
        self.embedding = nn.Embedding(codebook_size, token_size)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

    @autocast('cuda', enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Quantize input tensor.

        Args:
            z: Input tensor of shape (B, C, H, W), where C must equal
                ``token_size``.

        Returns:
            z_quantized: Quantized tensor of shape (B, C, H, W).
            result_dict: Dictionary with keys ``quantizer_loss``,
                ``commitment_loss``, ``codebook_loss`` (scalar tensors) and
                ``min_encoding_indices`` (integer tensor of shape (B, H, W)).

        Raises:
            ValueError: If ``z`` is not 4-dimensional.
        """
        if z.dim() != 4:
            raise ValueError(
                f"Expected input of shape (B, C, H, W), got {tuple(z.shape)}"
            )

        # Distances are computed in float32 with channels last.
        z = z.float()
        z = z.permute(0, 2, 3, 1).contiguous()
        batch, height, width, channels = z.shape

        # Normalize once; the same tensor feeds the distance computation,
        # the losses and the straight-through estimator.
        if self.use_l2_norm:
            z_norm = nn.functional.normalize(z, dim=-1)
            embedding = nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            z_norm = z
            embedding = self.embedding.weight
        z_flattened = z_norm.reshape(batch * height * width, channels)

        # Squared Euclidean distance: |z|^2 + |e|^2 - 2 * z.e
        d = (torch.sum(z_flattened**2, dim=1, keepdim=True) +
             torch.sum(embedding**2, dim=1) -
             2 * torch.einsum('bd,dn->bn', z_flattened, embedding.T))
        min_encoding_indices = torch.argmin(d, dim=1)
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

        # Compute losses (for compatibility, not used in inference)
        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z_norm) ** 2)
        codebook_loss = torch.mean((z_quantized - z_norm.detach()) ** 2)
        loss = commitment_loss + codebook_loss

        # Straight-through estimator: forward value is the codebook entry,
        # gradient flows to z_norm as if quantization were the identity.
        z_quantized = z_norm + (z_quantized - z_norm).detach()

        # Back to channels-first layout.
        z_quantized = z_quantized.permute(0, 3, 1, 2).contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(batch, height, width),
        )
        return z_quantized, result_dict

    @autocast('cuda', enabled=False)
    def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor:
        """Get codebook entries for given indices.

        Args:
            indices: Either hard token indices of shape (N,), or soft
                indices of shape (N, codebook_size) holding per-entry weights.

        Returns:
            Codebook entries of shape (N, token_size). For soft indices this
            is the weighted sum of codebook rows. L2-normalized when
            ``use_l2_norm`` is set.

        Raises:
            NotImplementedError: If ``indices`` is neither 1-D nor 2-D.
        """
        if indices.dim() == 1:
            # Hard indices: plain embedding lookup (requires integer dtype).
            z_quantized = self.embedding(indices.long())
        elif indices.dim() == 2:
            # Soft indices: keep the fractional weights. Casting them to
            # long would truncate the weights and make einsum fail on mixed
            # integer/float operands.
            weights = indices.to(self.embedding.weight.dtype)
            z_quantized = torch.einsum('bd,dn->bn', weights, self.embedding.weight)
        else:
            raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")

        if self.use_l2_norm:
            z_quantized = nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
class DiagonalGaussianDistribution:
    """Diagonal Gaussian distribution for VAE-style quantization.

    Used when quantize_mode='vae' instead of discrete VQ. The distribution is
    parameterized by a mean and a log-variance obtained by splitting the
    input tensor in half along dimension 1.
    """

    @autocast('cuda', enabled=False)
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        """Initialize Gaussian distribution.

        Args:
            parameters: Tensor of shape (B, 2*C, H, W) containing mean and
                logvar concatenated along dimension 1.
            deterministic: If True, std and var are set to zero so that
                sample() returns the mean (no noise).
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1)
        # Clamp so that exp() below neither underflows to 0 nor overflows.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    @autocast('cuda', enabled=False)
    def sample(self) -> torch.Tensor:
        """Draw a sample ``mean + std * eps`` with ``eps`` standard normal."""
        noise = torch.randn(self.mean.shape, device=self.parameters.device)
        return self.mean.float() + self.std.float() * noise

    @autocast('cuda', enabled=False)
    def mode(self) -> torch.Tensor:
        """Return the mode (mean) of the distribution."""
        return self.mean

    @autocast('cuda', enabled=False)
    def kl(self) -> torch.Tensor:
        """Compute KL divergence from standard Gaussian.

        Returns:
            In deterministic mode, a single-element zero tensor on the same
            device as the distribution parameters. Otherwise the KL term
            summed over dimensions 1 and 2 of the mean.
        """
        if self.deterministic:
            # Build the zero on the parameters' device so it can be combined
            # with other on-device losses without a device mismatch.
            return torch.zeros(1, dtype=self.mean.dtype, device=self.mean.device)
        # NOTE(review): for 4-D (B, C, H, W) inputs this leaves a (B, W)
        # result because only dims 1 and 2 are reduced - confirm that callers
        # expect this rather than a reduction over all non-batch dims.
        return 0.5 * torch.sum(
            torch.pow(self.mean.float(), 2) + self.var.float() - 1.0 - self.logvar.float(),
            dim=[1, 2]
        )