"""Discrete image VAE with a channel-split finite-scalar quantizer (FSQ).

The module provides:

* ``patchify`` / ``unpatchify`` - space-to-channel helpers,
* ``ChannelSplitFSQ`` - an FSQ quantizer that splits the projected features
  into ``K`` groups and quantizes each group with the same ``levels``,
* ``Encoder2d`` / ``Decoder2d`` - convolutional encoder / decoder,
* ``DiscreteImageVAE`` - encoder + quantizer + decoder wrapper,
* a small command line demo under ``if __name__ == "__main__"``.
"""

from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, pack, unpack

# Default for whether the quantizer buffers are stored in the state dict.
_PERSISTENT = True


def patchify(x, patch_size):
    """Fold every ``patch_size x patch_size`` spatial patch into channels.

    Args:
        x: tensor laid out as ``(b, c, h, w)`` or ``(b, c, f, h, w)``.
        patch_size: side length of the square patch; ``1`` is a no-op.

    Returns:
        ``x`` with the last two axes divided by ``patch_size`` and the channel
        axis multiplied by ``patch_size ** 2``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and ``patch_size != 1``).
    """
    if patch_size == 1:
        return x
    if x.dim() == 4:
        return rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size
        )
    if x.dim() == 5:
        return rearrange(
            x,
            "b c f (h q) (w r) -> b (c r q) f h w",
            q=patch_size,
            r=patch_size,
        )
    raise ValueError(f"Invalid input shape: {x.shape}")


def unpatchify(x, patch_size):
    """Inverse of :func:`patchify`: unfold channels back into spatial patches.

    Args:
        x: tensor laid out as ``(b, c * p * p, h, w)`` or
            ``(b, c * p * p, f, h, w)`` with ``p == patch_size``.
        patch_size: side length of the square patch; ``1`` is a no-op.

    Returns:
        ``x`` with the last two axes multiplied by ``patch_size`` and the
        channel axis divided by ``patch_size ** 2``.

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D (and ``patch_size != 1``).
            This mirrors ``patchify``; previously such input was returned
            unchanged, which hid shape bugs.
    """
    if patch_size == 1:
        return x
    if x.dim() == 4:
        return rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size
        )
    if x.dim() == 5:
        return rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    raise ValueError(f"Invalid input shape: {x.shape}")


def exists(v):
    """Return True when ``v`` is not None."""
    return v is not None


def default(*args):
    """Return the first argument that is not None, or None if all are None."""
    for arg in args:
        if exists(arg):
            return arg
    return None


def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    # Forward value is zhat; the detached difference carries no gradient, so
    # the backward pass sees the identity.
    return z + (zhat - z).detach()


def pack_one(t, pattern):
    """Pack a single tensor with einops; returns ``(packed, packed_shapes)``."""
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    """Undo :func:`pack_one` and return the single unpacked tensor."""
    return unpack(t, ps, pattern)[0]


# ---------------------------------------------------------------------------
# Quantizers
# ---------------------------------------------------------------------------


class InvQuantizerJit(nn.Module):
    """Use for decoder_jit to trace quantizer in discrete tokenizer"""

    def __init__(self, quantizer):
        super().__init__()
        self.quantizer = quantizer

    def forward(self, indices: torch.Tensor):
        """Map code indices to (projected) codes in the quantizer's dtype."""
        codes = self.quantizer.indices_to_codes(indices)
        return codes.to(self.quantizer.dtype)


class ChannelSplitFSQ(nn.Module):
    """Quantizer that splits the input into K channels and quantizes each channel independently.

    From: https://research.nvidia.com/labs/dir/mamba-tokenizer/

    Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
    Code adapted from Jax version in Appendix A.1.

    Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
    vector_quantize_pytorch/finite_scalar_quantization.py
    [Copyright (c) 2020 Phil Wang]
    """

    def __init__(
        self,
        levels: list[int],
        dim: Optional[int] = None,
        K: int = 4,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        **ignore_kwargs,
    ):
        """Build the quantizer.

        Args:
            levels: number of quantization levels per code dimension.
            dim: feature dimension of the input; defaults to
                ``len(levels) * num_codebooks * K`` (no projections).
            K: number of groups the projected features are split into.
            num_codebooks: number of codebooks per group.
            keep_num_codebooks_dim: keep the trailing index axis in the output
                of ``forward``; defaults to ``num_codebooks * K > 1``.
            scale: stored on the instance; not used by this class.
            **ignore_kwargs: only ``dtype`` (output dtype, default bfloat16)
                and ``persistent_quantizer`` (buffer persistence) are read;
                everything else is ignored.
        """
        super().__init__()
        # Accept any sequence of ints (the list concatenation below would
        # otherwise fail for tuples).
        levels = list(levels)

        self.dtype = ignore_kwargs.get("dtype", torch.bfloat16)
        self.persistent = ignore_kwargs.get("persistent_quantizer", _PERSISTENT)

        _levels = torch.tensor(levels, dtype=torch.int32)
        self.register_buffer("_levels", _levels, persistent=self.persistent)

        # Mixed-radix place values: basis[i] = prod(levels[:i]).
        _basis = torch.cumprod(
            torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32
        )
        self.register_buffer("_basis", _basis, persistent=self.persistent)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks
        self.K = K

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks * K > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        effective_codebook_dim = self.codebook_dim * num_codebooks * K
        self.effective_codebook_dim = effective_codebook_dim

        self.dim = default(dim, len(levels) * num_codebooks * K)

        # Linear projections are only needed when the input width differs from
        # the total number of quantized scalars.
        has_projections = self.dim != effective_codebook_dim
        self.project_in = (
            nn.Linear(self.dim, effective_codebook_dim)
            if has_projections
            else nn.Identity()
        )
        self.project_out = (
            nn.Linear(effective_codebook_dim, self.dim)
            if has_projections
            else nn.Identity()
        )
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels - 1) * (1 + eps) / 2
        # Even level counts are shifted by half a step.
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        # Same value as (offset / half_l).atanh(), written out with logs
        # (presumably for export/tracing compatibility - TODO confirm).
        shift = offset / half_l
        shift = 0.5 * torch.log(1 + shift) - 0.5 * torch.log(1 - shift)
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        """Map normalized codes to non-negative per-dimension digits."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        """Map per-dimension digits back to normalized codes."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: normalized codes whose last axis has ``codebook_dim`` entries.

        Returns:
            int32 tensor of mixed-radix indices with the last axis reduced.
        """
        assert zhat.shape[-1] == self.codebook_dim
        # Round the digits before the integer cast: the normalize/denormalize
        # round trip happens in floating point, and `.to(int32)` truncates, so
        # a digit such as 2.9999 would otherwise become 2.
        digits = self._scale_and_shift(zhat).float().round()
        return (digits * self._basis).sum(dim=-1).to(torch.int32)

    def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor:
        """Inverse of `codes_to_indices`.

        Args:
            indices: ``(b, h, w, k)`` integer tensor. A 3-D ``(b, h, w)``
                tensor (what ``forward`` returns when
                ``keep_num_codebooks_dim`` is False) is also accepted.
            project_out: apply ``project_out`` to the reconstructed codes.

        Returns:
            Codes of shape ``(b, c, h, w)`` cast to ``self.dtype``.

        NOTE(review): the ``(b k) n d -> b n (d k)`` regrouping uses
        ``k=self.K``, which matches ``forward`` only for ``num_codebooks == 1``
        - confirm before using several codebooks.
        """
        if indices.ndim == 3:
            # Restore the trailing index axis squeezed away by `forward`.
            indices = indices.unsqueeze(-1)

        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        # Rearrange first: fold the index axis into the batch so every group
        # is decoded independently.
        _, h, w, _ = indices.shape
        indices = rearrange(indices, "b h w k -> (b k) (h w) 1")

        # Mixed-radix decomposition into per-dimension digits.
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)  # (b k) n d

        codes = rearrange(codes, "(b k) n d -> b n (d k)", k=self.K)

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, "b (h w) c -> b c h w", h=h, w=w)

        return codes.to(self.dtype)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Quantize ``z``.

        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - code dimension (one entry per level in ``levels``)
        c - number of codebooks
        k - number of channels to split into

        Args:
            z: ``(b, dim, ...)`` for image/video input (4-D or more) or
                ``(b, n, dim)`` for sequences.

        Returns:
            Tuple ``(out, indices, dummy_loss)``: the continuous
            reconstruction cast to ``self.dtype``, the discrete codes for each
            position, and an all-zero tensor (unused loss placeholder).
        """
        is_img_or_video = z.ndim >= 4

        # standardize image or video into (batch, seq, dimension)
        if is_img_or_video:
            z = rearrange(z, "b d ... -> b ... d")
            z, ps = pack_one(z, "b * d")

        assert (
            z.shape[-1] == self.dim
        ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"

        z = self.project_in(z)  # b n (c d k)

        # Move the K groups into the batch axis so each is quantized alone.
        z = rearrange(z, "b n (c d k) -> b n c d k", c=self.num_codebooks, k=self.K)
        z = rearrange(z, "b n c d k -> (b k) n c d")

        codes = self.quantize(z)  # (b k) n c d
        indices = self.codes_to_indices(codes)  # (b k) n c
        indices = rearrange(indices, "(b k) n c -> b n (c k)", k=self.K)

        codes = rearrange(codes, "(b k) n c d -> b n (c d k)", k=self.K)
        out = self.project_out(codes)  # b n c, with c = initial dimension

        # reconstitute image or video dimensions
        if is_img_or_video:
            out = unpack_one(out, ps, "b * d")
            out = rearrange(out, "b ... d -> b d ...")
            indices = unpack_one(indices, ps, "b * c")
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True))
        else:
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(
                1
            )

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, "... 1 -> ...")

        # indices - discrete codes for each position
        # out - continuous reconstruction
        # loss - zeros (unused)
        return out.to(self.dtype), indices, dummy_loss


# ---------------------------------------------------------------------------
# VAE
# ---------------------------------------------------------------------------


class RMS_norm(nn.Module):
    """Root-mean-square normalisation over one axis with a learned gain."""

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        """
        Args:
            dim: size of the normalised axis.
            channel_first: normalise axis 1 (True) or the last axis (False).
            images: with ``channel_first``, shape the parameters as
                ``(dim, 1, 1)`` (True) or ``(dim, 1)`` (False).
            bias: add a learned bias.
        """
        super().__init__()
        broadcastable_dims = (1, 1) if images else (1,)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        # Kept as an attribute; `forward` computes the RMS directly and does
        # not read it.
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + 1e-6) * gamma + bias``."""
        dim = 1 if self.channel_first else -1
        rms = x.pow(2).mean(dim=dim, keepdim=True).add(1e-6).rsqrt()
        return x * rms * self.gamma + self.bias


class Upsample(nn.Upsample):
    """``nn.Upsample`` subclass kept as a hook point for dtype workarounds."""

    def forward(self, x):
        # A float32 round trip (x.float() ... .type_as(x)) used to work around
        # bfloat16 nearest-neighbour interpolation; it is currently disabled
        # and the call simply defers to nn.Upsample.
        return super().forward(x)


class ResidualBlock2d(nn.Module):
    """Two norm-SiLU-conv stages with a (possibly projected) skip connection."""

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.residual = nn.Sequential(
            RMS_norm(in_dim),
            nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
            RMS_norm(out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1),
        )
        # 1x1 conv only when the channel count changes.
        self.shortcut = (
            nn.Conv2d(in_dim, out_dim, kernel_size=1)
            if in_dim != out_dim
            else nn.Identity()
        )

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)


class AttentionBlock2d(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # Zero-initialised output weight.
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(b, c, h, w)`` with a residual."""
        identity = x
        b, c, h, w = x.size()
        n_heads = 1  # or c // 64
        head_dim = c // n_heads

        x = self.norm(x)
        qkv = self.to_qkv(x).reshape(b, 3, n_heads, head_dim, h * w)
        q, k, v = qkv.unbind(1)  # Each: (b, n_heads, head_dim, h*w)
        # scaled_dot_product_attention wants (..., seq, head_dim).
        q, k, v = q.transpose(-1, -2), k.transpose(-1, -2), v.transpose(-1, -2)

        x = F.scaled_dot_product_attention(q, k, v)
        x = x.transpose(-1, -2).reshape(b, c, h, w)
        return self.proj(x) + identity


class FlashAttentionBlock2d(nn.Module):
    """Attention block using flash-attn's kernel directly."""

    def __init__(self, dim, n_heads=8):
        super().__init__()
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        """Apply multi-head attention to ``x`` of shape ``(b, c, h, w)``."""
        # Imported lazily so the module can be used without flash-attn as long
        # as this block is never called.
        from flash_attn import flash_attn_func

        identity = x
        b, c, h, w = x.size()

        x = self.norm(x)
        qkv = self.to_qkv(x)  # (b, 3*c, h, w)

        # flash_attn_func expects (b, seqlen, nheads, headdim)
        qkv = qkv.reshape(b, 3, self.n_heads, self.head_dim, h * w)
        qkv = qkv.permute(0, 4, 1, 2, 3)  # (b, h*w, 3, n_heads, head_dim)
        q, k, v = qkv.unbind(2)  # each (b, h*w, n_heads, head_dim)

        x = flash_attn_func(q, k, v)  # (b, h*w, n_heads, head_dim)
        x = x.reshape(b, h * w, c).permute(0, 2, 1).reshape(b, c, h, w)
        return self.proj(x) + identity


# Custom conv with asymmetric padding
class AsymmetricConv2d(nn.Conv2d):
    """``nn.Conv2d`` that first zero-pads one pixel on the right and bottom."""

    def forward(self, x):
        x = F.pad(x, (0, 1, 0, 1))  # Fused with conv by torch.compile
        return super().forward(x)


class Resample2d(nn.Module):
    """Optional 2x spatial up- or down-sampling followed by a 3x3 conv."""

    def __init__(self, dim, mode):
        """
        Args:
            dim: number of input and output channels.
            mode: one of ``"none"``, ``"upsample2d"``, ``"downsample2d"``.
        """
        assert mode in ("none", "upsample2d", "downsample2d")
        super().__init__()
        self.mode = mode

        if mode == "upsample2d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=2.0, mode="nearest"),
                nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            )
        elif mode == "downsample2d":
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, kernel_size=3, stride=2),
            )
        else:
            self.resample = nn.Identity()

    def forward(self, x):
        return self.resample(x)


class Encoder2d(nn.Module):
    """Convolutional encoder producing a mean and log-variance map."""

    def __init__(
        self,
        dim=64,
        z_dim=4,
        dim_mult=(1, 2, 4),
        num_res_blocks=2,
        dropout=0.0,
        attn_scales=(),
        patch_size=1,
        in_channels=3,
        attn_class=AttentionBlock2d,
    ):
        """
        Args:
            dim: base channel width.
            z_dim: latent channels (the head outputs ``2 * z_dim``).
            dim_mult: channel multiplier per stage.
            num_res_blocks: residual blocks per stage.
            dropout: dropout probability inside the residual blocks.
            attn_scales: scales (1.0, 0.5, ...) at which an attention block
                follows every residual block.
            patch_size: ``patchify`` factor applied to the input.
            in_channels: channels of the input image.
            attn_class: attention block constructor, called as ``cls(dim)``.
        """
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        # Copy into fresh lists so instances never share (or alias) a default.
        self.dim_mult = list(dim_mult)
        self.num_res_blocks = num_res_blocks
        self.attn_scales = list(attn_scales)
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.patcher = lambda x: patchify(x, patch_size=patch_size)

        # dimensions
        dims = [dim * u for u in [1, *self.dim_mult]]
        scale = 1.0
        initial_dim = self.in_channels * self.patch_size * self.patch_size

        # init block
        self.conv1 = nn.Conv2d(initial_dim, dims[0], kernel_size=3, padding=1)

        # downsample blocks
        downsamples = []
        in_dim = dims[0]
        for i, out_dim in enumerate(dims[1:]):
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
                if scale in self.attn_scales:
                    downsamples.append(attn_class(out_dim))
                in_dim = out_dim

            # No downsampling after the last stage.
            if i != len(self.dim_mult) - 1:
                downsamples.append(Resample2d(out_dim, mode="downsample2d"))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle and head
        final_dim = dims[-1]
        self.middle = ResidualBlock2d(final_dim, final_dim, dropout)
        self.head = nn.Sequential(
            RMS_norm(final_dim),
            nn.SiLU(),
            nn.Conv2d(final_dim, z_dim * 2, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return ``(mu, log_var)``, the two channel halves of the head output."""
        x = self.patcher(x)
        x = self.conv1(x)
        x = self.downsamples(x)
        x = self.middle(x)
        mu, log_var = self.head(x).chunk(2, dim=1)
        return mu, log_var


class Decoder2d(nn.Module):
    """Convolutional decoder mirroring :class:`Encoder2d`."""

    def __init__(
        self,
        dim=64,
        z_dim=4,
        dim_mult=(1, 2, 4),
        num_res_blocks=2,
        dropout=0.0,
        attn_scales=(),
        out_channels=3,
        attn_class=AttentionBlock2d,
        patch_size=1,
    ):
        """
        Args:
            dim: base channel width.
            z_dim: latent channels expected at the input.
            dim_mult: channel multiplier per stage (encoder order; it is
                reversed internally).
            num_res_blocks: residual blocks per stage.
            dropout: dropout probability inside the residual blocks.
            attn_scales: scales at which an attention block follows every
                residual block.
            out_channels: channels of the reconstructed image.
            attn_class: attention block constructor, called as ``cls(dim)``.
            patch_size: ``unpatchify`` factor applied to the output.
        """
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        # Copy into fresh lists so instances never share (or alias) a default.
        self.dim_mult = list(dim_mult)
        self.num_res_blocks = num_res_blocks
        self.attn_scales = list(attn_scales)
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.unpatcher = lambda x: unpatchify(x, patch_size=patch_size)

        # dimensions (mirror of encoder)
        base = dim * self.dim_mult[-1]
        dims = [base] + [dim * u for u in self.dim_mult[::-1]]
        # Start at the coarsest scale reached by the encoder.
        scale = (
            1.0 / (2 ** (len(self.dim_mult) - 2)) if len(self.dim_mult) >= 2 else 1.0
        )
        output_channels = self.out_channels * self.patch_size * self.patch_size

        # init block
        self.conv1 = nn.Conv2d(z_dim, dims[0], kernel_size=3, padding=1)

        # middle
        self.middle = ResidualBlock2d(dims[0], dims[0], dropout)

        # upsample blocks
        upsamples = []
        in_dim = dims[0]
        for i, out_dim in enumerate(dims[1:]):
            for _ in range(num_res_blocks):
                upsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
                if scale in self.attn_scales:
                    upsamples.append(attn_class(out_dim))
                in_dim = out_dim

            # No upsampling after the last stage.
            if i != len(self.dim_mult) - 1:
                upsamples.append(Resample2d(out_dim, mode="upsample2d"))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # head
        final_dim = dims[-1]
        self.head = nn.Sequential(
            RMS_norm(final_dim),
            nn.SiLU(),
            nn.Conv2d(final_dim, output_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Decode a latent map into an image tensor."""
        x = self.conv1(x)
        x = self.middle(x)
        x = self.upsamples(x)
        x = self.head(x)
        x = self.unpatcher(x)
        return x


class DiscreteImageVAE(nn.Module):
    """Encoder, channel-split FSQ quantizer and decoder bundled together."""

    def __init__(
        self,
        dim=64,
        z_dim=4,
        dim_mult=(1, 2, 4),
        num_res_blocks=2,
        dropout=0.0,
        attn_scales=(),
        in_channels=3,
        out_channels=3,
        embedding_dim=128,
        scale=None,
        attn_class=AttentionBlock2d,
        patch_size=1,
        *args,
        **kwargs,
    ):
        """
        Args:
            dim, z_dim, dim_mult, num_res_blocks, dropout, attn_scales,
            attn_class, patch_size: forwarded to the encoder and decoder.
            in_channels: image channels consumed by the encoder.
            out_channels: image channels produced by the decoder.
            embedding_dim: embedding dimension (passed to the quantizer as
                its ``dim``)
            scale: scale for the quantizer: a ``(mean, std)`` pair used to
                normalise the latent before quantization; defaults to zeros
                and ones of length ``z_dim``.
            *args: ignored.
            **kwargs: forwarded to ``ChannelSplitFSQ`` (e.g. ``levels``, ``K``).
        """
        super().__init__()
        self.z_dim = z_dim

        self.encoder = Encoder2d(
            dim=dim,
            z_dim=z_dim,
            dim_mult=dim_mult,
            num_res_blocks=num_res_blocks,
            dropout=dropout,
            attn_scales=attn_scales,
            in_channels=in_channels,
            attn_class=attn_class,
            patch_size=patch_size,
        )
        self.decoder = Decoder2d(
            dim=dim,
            z_dim=z_dim,
            dim_mult=dim_mult,
            num_res_blocks=num_res_blocks,
            dropout=dropout,
            attn_scales=attn_scales,
            out_channels=out_channels,
            attn_class=attn_class,
            patch_size=patch_size,
        )

        self.embedding_dim = embedding_dim
        kwargs["dim"] = embedding_dim
        self.quantizer = ChannelSplitFSQ(**kwargs)

        if scale is None:
            mean = torch.zeros(self.z_dim, dtype=torch.float)
            std = torch.ones(self.z_dim, dtype=torch.float)
            self.scale = (mean, std)
        else:
            self.scale = scale

    def to(self, *args, **kwargs):
        """Move the module and, when they are tensors, the ``scale`` pair."""
        super().to(*args, **kwargs)
        # `scale` is a plain attribute, not a buffer, so nn.Module.to does not
        # move it.
        if isinstance(self.scale[0], torch.Tensor):
            self.scale = (
                self.scale[0].to(*args, **kwargs),
                self.scale[1].to(*args, **kwargs),
            )
        return self

    def encode(self, x):
        """Encode and quantize a batch of images.

        Args:
            x: A batch of images each with shape [B, C, H, W] in [-1, 1].

        Returns:
            A dict with the keys:

            - ``quant_codes``: continuous quantizer output,
            - ``indices``: discrete code indices,
            - ``dummy_loss``: all-zero placeholder loss from the quantizer,
            - ``h``: continuous latent from the encoder (before normalisation),
            - ``log_var_nouse``: the encoder's log-variance output (unused).
        """
        h, log_var = self.encoder(x)

        # Normalize h mean-var
        if isinstance(self.scale[0], torch.Tensor):
            h_norm = (h - self.scale[0].view(1, self.z_dim, 1, 1)) / self.scale[1].view(
                1, self.z_dim, 1, 1
            )
        else:
            h_norm = (h - self.scale[0]) / self.scale[1]

        quant_codes, indices, dummy_loss = self.quantizer(h_norm)
        return {
            "quant_codes": quant_codes,
            "indices": indices,
            "dummy_loss": dummy_loss,
            "h": h,
            "log_var_nouse": log_var,
        }

    def decode(self, z):
        """Undo the ``scale`` normalisation of ``z`` and run the decoder."""
        if isinstance(self.scale[0], torch.Tensor):
            z = z * self.scale[1].view(1, self.z_dim, 1, 1) + self.scale[0].view(
                1, self.z_dim, 1, 1
            )
        else:
            z = z * self.scale[1] + self.scale[0]
        return self.decoder(z)

    def decode_code(self, code_b):
        """Decode discrete indices straight through the decoder.

        NOTE(review): unlike ``decode`` this does not undo the ``scale``
        normalisation applied in ``encode``; the two only agree for the
        default zero-mean / unit-std scale - confirm this is intended.
        """
        quant_b = self.quantizer.indices_to_codes(code_b)
        return self.decoder(quant_b)

    def encoder_jit(self):
        """Return a module wrapping the encoder that outputs only ``h``."""

        class EncoderJitModule(nn.Module):
            def __init__(self, encoder: nn.Module):
                super().__init__()
                self.encoder = encoder

            def forward(self, x):
                h, _ = self.encoder(x)
                return h

        return EncoderJitModule(self.encoder)

    def quantizer_jit(self):
        """Return a module wrapping the quantizer's forward pass."""

        class QuantizerJitModule(nn.Module):
            def __init__(self, quantizer: nn.Module):
                super().__init__()
                self.quantizer = quantizer

            def forward(self, x):
                quant_codes, indices, dummy_loss = self.quantizer(x)
                return quant_codes, indices, dummy_loss

        return QuantizerJitModule(self.quantizer)

    def decoder_jit(self):
        """Return ``indices -> codes -> image`` as a single ``nn.Sequential``."""
        return nn.Sequential(
            OrderedDict(
                [
                    ("inv_quant", InvQuantizerJit(self.quantizer)),
                    ("decoder", self.decoder),
                ]
            )
        )


if __name__ == "__main__":
    import argparse
    import os

    from PIL import Image
    import numpy as np

    def load_image(path, size=(1920, 1080), device=None):
        """Load ``path`` as a ``(1, 3, H, W)`` float tensor in [-1, 1].

        Falls back to random noise of the requested size when the file does
        not exist. ``size`` is ``(width, height)``; ``device`` defaults to
        CUDA when available.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if not os.path.exists(path):
            print(
                f"Image not found at {path}, generating random noise. Warning: The tokenizer might not work properly."
            )
            return torch.randn(1, 3, size[1], size[0]).to(device)

        img = Image.open(path).convert("RGB")
        img = np.array(img.resize(size, Image.BICUBIC))[None]
        img = torch.from_numpy(img).to(device).to(torch.float32)
        img = (img / 127.5) - 1.0  # [0, 255] -> [-1, 1]
        img = rearrange(img, "b h w c -> b c h w")
        return img

    def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
        """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255].

        Args:
            input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1].
            range_min: lower bound of the input range; with ``-1`` the input is
                first mapped from [-1, 1] to [0, 1].
        Returns:
            A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
        """
        _UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
        if range_min == -1:
            input_tensor = (input_tensor.float() + 1.0) / 2.0
        ndim = input_tensor.ndim
        output_image = input_tensor.clamp(0, 1).cpu().numpy()
        # Channels-first -> channels-last.
        output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
        return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)

    parser = argparse.ArgumentParser(description="Run DiscreteImageVAE inference")
    parser.add_argument(
        "--checkpoint", type=str, default=None, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--image", type=str, default="assets/00128.png", help="Path to input image"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="decoded_image_test.png",
        help="Path to save output image",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on",
    )
    args = parser.parse_args()

    # Keys that are not named parameters of DiscreteImageVAE are forwarded to
    # ChannelSplitFSQ, which ignores the ones it does not know.
    cs_discrete8_wan_patch2 = {
        "dim": 64,
        "z_dim": 16,
        "dim_mult": [1, 2, 4],
        "patch_size": 2,
        "num_res_blocks": 3,
        "attn_scales": [],
        "dropout": 0.0,
        "cls": DiscreteImageVAE,
        "z_channels": 256,
        "z_factor": 1,
        "embedding_dim": 16,
        "levels": [8, 8, 8, 5, 5, 5],
        "dtype": torch.float,
        "model_type": "wan_2_1",
        "quantizer_cls": ChannelSplitFSQ,
        "num_codebooks": 1,
        "K": 2,
    }

    device = args.device
    print(f"Running on {device}")

    vae = DiscreteImageVAE(**cs_discrete8_wan_patch2).to(device)

    if args.checkpoint and os.path.exists(args.checkpoint):
        print(f"Loading checkpoint from {args.checkpoint}")
        # The checkpoint path comes from the command line; restrict unpickling
        # to tensors/containers instead of arbitrary objects.
        state_dict = torch.load(
            args.checkpoint, map_location=device, weights_only=True
        )
        vae.load_state_dict(state_dict)
    else:
        print("No checkpoint provided or found. Running with random initialization.")

    vae.eval()

    # `.to` is a no-op when the tensor already lives on `device`.
    imgs = load_image(args.image, device=device).to(device)

    # Example using JIT modules
    with torch.no_grad():
        encoded_sample = vae.encoder_jit()(imgs)
        indices = vae.quantizer_jit()(encoded_sample)[1]
        decoded_sample = vae.decoder_jit()(indices)

    print(f"Encoded shape: {encoded_sample.shape}")
    print(f"Indices shape: {indices.shape}")
    print(f"Decoded shape: {decoded_sample.shape}")

    # Example using regular modules
    with torch.no_grad():
        encoded_sample_regular = vae.encode(imgs)
        indices = encoded_sample_regular["indices"]
        decoded_sample_regular = vae.decode_code(indices)

    print(f"Quant codes shape: {encoded_sample_regular['quant_codes'].shape}")
    print(f"Indices shape: {indices.shape}")
    print(f"Decoded regular shape: {decoded_sample_regular.shape}")

    assert torch.allclose(
        decoded_sample, decoded_sample_regular, atol=1e-5
    ), "JIT and regular outputs mismatch"

    # Save the decoded image
    decoded_img = tensor2numpy(decoded_sample)
    decoded_img = Image.fromarray(decoded_img[0])
    decoded_img.save(args.output)
    print(f"Saved decoded image to {args.output}")