| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
| from collections import OrderedDict |
| from einops import rearrange, pack, unpack |
|
|
# Default for ChannelSplitFSQ's `persistent_quantizer` option: whether its
# `_levels` / `_basis` buffers are registered as persistent (saved in the
# module's state dict).
_PERSISTENT = True
|
|
|
|
def patchify(x, patch_size):
    """Fold non-overlapping ``patch_size x patch_size`` spatial patches into channels.

    Accepts a 4-D ``(b, c, H, W)`` or 5-D ``(b, c, f, H, W)`` tensor whose H and W
    are divisible by ``patch_size``. Inside a patch the new channel axis is ordered
    ``(c, r, q)``, where ``q`` is the row offset and ``r`` the column offset.
    ``patch_size == 1`` returns ``x`` itself.

    Raises:
        ValueError: if ``patch_size != 1`` and ``x`` is neither 4-D nor 5-D.
    """
    if patch_size == 1:
        return x

    patterns_by_rank = {
        4: "b c (h q) (w r) -> b (c r q) h w",
        5: "b c f (h q) (w r) -> b (c r q) f h w",
    }
    pattern = patterns_by_rank.get(x.dim())
    if pattern is None:
        raise ValueError(f"Invalid input shape: {x.shape}")
    return rearrange(x, pattern, q=patch_size, r=patch_size)
|
|
|
|
def unpatchify(x, patch_size):
    """Inverse of ``patchify``: unfold channels back into spatial patches.

    Accepts a 4-D ``(b, c*p*p, h, w)`` or 5-D ``(b, c*p*p, f, h, w)`` tensor with
    ``p = patch_size`` and returns ``(b, c, h*p, w*p)`` / ``(b, c, f, h*p, w*p)``.
    ``patch_size == 1`` returns ``x`` itself.

    Raises:
        ValueError: if ``patch_size != 1`` and ``x`` is neither 4-D nor 5-D
            (previously such input was silently returned unchanged).
    """
    if patch_size == 1:
        return x

    if x.dim() == 4:
        return rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size
        )
    if x.dim() == 5:
        return rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    # Mirror patchify: refuse ranks we do not know how to unfold.
    raise ValueError(f"Invalid input shape: {x.shape}")
|
|
|
|
def exists(v):
    """Return ``True`` for any value other than ``None`` (falsy values included)."""
    return not (v is None)
|
|
|
|
def default(*args):
    """Return the first positional argument that is not ``None``.

    Returns ``None`` when called with no arguments or when every argument is ``None``.
    """
    return next((candidate for candidate in args if candidate is not None), None)
|
|
|
|
def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round ``z`` in the forward pass with a straight-through gradient.

    The value is ``z.round()`` (up to floating-point rounding of the residual);
    because the residual is detached, the gradient w.r.t. ``z`` is the identity.
    """
    rounding_residual = (z.round() - z).detach()
    return z + rounding_residual
|
|
|
|
def pack_one(t, pattern):
    """Pack a single tensor with ``einops.pack``.

    Returns ``(packed_tensor, packed_shapes)``; pass ``packed_shapes`` to
    ``unpack_one`` to restore the original layout.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
|
|
|
|
def unpack_one(t, ps, pattern):
    """Inverse of ``pack_one``: unpack ``t`` using ``ps`` and return the first tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
|
|
|
| """ |
| Quantizers |
| """ |
|
|
|
|
class InvQuantizerJit(nn.Module):
    """Module wrapper mapping codebook indices to codes via a quantizer.

    Used by ``DiscreteImageVAE.decoder_jit`` so the index -> code lookup can be
    traced together with the decoder in the discrete tokenizer.
    """

    def __init__(self, quantizer):
        super().__init__()
        self.quantizer = quantizer

    def forward(self, indices: torch.Tensor):
        """Decode ``indices`` and cast the result to ``self.quantizer.dtype``."""
        quantizer = self.quantizer
        target_dtype = quantizer.dtype
        decoded = quantizer.indices_to_codes(indices)
        return decoded.to(target_dtype)
|
|
|
|
class ChannelSplitFSQ(nn.Module):
    """Quantizer that splits the input into K channels and quantizes each channel independently.
    From: https://research.nvidia.com/labs/dir/mamba-tokenizer/

    Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
    Code adapted from Jax version in Appendix A.1.
    Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
    vector_quantize_pytorch/finite_scalar_quantization.py
    [Copyright (c) 2020 Phil Wang]

    Args:
        levels: number of quantization levels per code dimension; ``len(levels)``
            is the per-group code dimension ``d``.
        dim: input feature dimension. Defaults to ``len(levels) * num_codebooks * K``;
            if it differs, linear ``project_in`` / ``project_out`` layers are added.
        K: number of groups the (projected) features are split into; each group is
            quantized independently and yields its own index.
        num_codebooks: number of codebooks per group.
        keep_num_codebooks_dim: keep the trailing index axis even when it has size
            1. Defaults to ``num_codebooks * K > 1``.
        scale: stored as ``self.scale``; not used by this class.
        **ignore_kwargs: only ``dtype`` (output dtype, default ``torch.bfloat16``)
            and ``persistent_quantizer`` (whether the ``_levels`` / ``_basis``
            buffers are saved in the state dict) are read; the rest is ignored.
    """

    def __init__(
        self,
        levels: list[int],
        dim: Optional[int] = None,
        K: int = 4,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        **ignore_kwargs,
    ):
        super().__init__()
        self.dtype = ignore_kwargs.get("dtype", torch.bfloat16)
        self.persistent = ignore_kwargs.get("persistent_quantizer", _PERSISTENT)
        _levels = torch.tensor(levels, dtype=torch.int32)
        self.register_buffer("_levels", _levels, persistent=self.persistent)

        # Mixed-radix place values: _basis[i] = prod(levels[:i]).
        _basis = torch.cumprod(
            torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32
        )
        self.register_buffer("_basis", _basis, persistent=self.persistent)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        self.K = K

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks * K > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        effective_codebook_dim = self.codebook_dim * num_codebooks * K
        self.effective_codebook_dim = effective_codebook_dim

        self.dim = default(dim, len(levels) * num_codebooks * K)

        has_projections = self.dim != effective_codebook_dim
        self.project_in = (
            nn.Linear(self.dim, effective_codebook_dim)
            if has_projections
            else nn.Identity()
        )
        self.project_out = (
            nn.Linear(effective_codebook_dim, self.dim)
            if has_projections
            else nn.Identity()
        )
        self.has_projections = has_projections

        # Number of distinct codes per (group, codebook): prod(levels).
        self.codebook_size = self._levels.prod().item()

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z`, an array of shape (..., d).

        Each entry is squashed with tanh so that rounding yields exactly L integer
        values for its level count L (odd L: -(L-1)/2 .. (L-1)/2; even L shifted by
        -0.5: -L/2 .. L/2 - 1). ``z = 0`` maps to 0.
        """
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        # shift = atanh(offset / half_l), so that z = 0 maps to 0 after the offset.
        shift = offset / half_l
        shift = 0.5 * torch.log(1 + shift) - 0.5 * torch.log(1 - shift)
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z, returns quantized zhat, same shape as z.

        Values are the rounded bounded integers divided by ``levels // 2``;
        gradients pass straight through the rounding.
        """
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        """Map normalized codes to non-negative integers 0 .. L-1 (as floats)."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        """Inverse of `_scale_and_shift`: integers 0 .. L-1 back to normalized codes."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` (last dim = ``codebook_dim``) to an int32 codebook index."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat).float()
        # Round before truncating so tiny float error can never drop an index by one.
        return (zhat * self._basis).sum(dim=-1).round().to(torch.int32)

    def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor:
        """Inverse of `codes_to_indices`, for indices laid out as `forward` returns them.

        indices: (b, h, w, c*K) for images, (b, f, h, w, c*K) for videos, or
            (b, n, c*K) for sequences. When ``keep_num_codebooks_dim`` is False the
            trailing size-1 axis is absent, exactly as `forward` emits it.

        Returns codes cast to ``self.dtype``: (b, C, h, w) / (b, C, f, h, w) for
        image/video indices or (b, n, C) for sequence indices, where C is ``dim``
        when ``project_out`` is True and ``effective_codebook_dim`` otherwise.
        """
        if not self.keep_num_codebooks_dim:
            # `forward` squeezed the trailing size-1 index axis; restore it.
            indices = indices.unsqueeze(-1)

        is_img_or_video = indices.ndim >= 4
        if is_img_or_video:
            # Flatten all spatial (and temporal) positions into one sequence axis.
            indices, ps = pack_one(indices, "b * k")

        # Undo forward's "(b k) n c -> b n (c k)" index layout.
        indices = rearrange(indices, "b n (c k) -> (b k) n c", k=self.K)
        # Mixed-radix decomposition of each index into per-dimension levels.
        codes_non_centered = (indices.unsqueeze(-1) // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        # Re-interleave exactly as forward does before project_out.
        codes = rearrange(codes, "(b k) n c d -> b n (c d k)", k=self.K)

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = unpack_one(codes, ps, "b * c")
            codes = rearrange(codes, "b ... c -> b c ...")

        return codes.to(self.dtype)

    def forward(
        self, z: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - code dimension per group (len(levels))
        c - number of codebooks
        k - number of channels (groups) to split into

        z: channel-first (b, dim, ...) when ``z.ndim >= 4``, otherwise (b, n, dim).

        Returns ``(out, indices, dummy_loss)``:
          - out: quantized codes after ``project_out``, same layout as ``z``,
            cast to ``self.dtype``.
          - indices: int32 indices, (b, ..., c*K) for channel-first input or
            (b, n, c*K) otherwise; the trailing axis is squeezed when
            ``keep_num_codebooks_dim`` is False.
          - dummy_loss: zeros; shape (b, 1, 1, 1) for 4-D and 3-D input.
        """
        is_img_or_video = z.ndim >= 4

        if is_img_or_video:
            # Channel-last, then flatten positions: (b, d, ...) -> (b, n, d).
            z = rearrange(z, "b d ... -> b ... d")
            z, ps = pack_one(z, "b * d")

        assert (
            z.shape[-1] == self.dim
        ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"

        z = self.project_in(z)

        # Split features into K independent groups, folded into the batch axis.
        z = rearrange(z, "b n (c d k) -> b n c d k", c=self.num_codebooks, k=self.K)
        z = rearrange(z, "b n c d k -> (b k) n c d")

        codes = self.quantize(z)

        indices = self.codes_to_indices(codes)
        indices = rearrange(indices, "(b k) n c -> b n (c k)", k=self.K)

        codes = rearrange(codes, "(b k) n c d -> b n (c d k)", k=self.K)

        out = self.project_out(codes)

        if is_img_or_video:
            out = unpack_one(out, ps, "b * d")
            out = rearrange(out, "b ... d -> b d ...")
            indices = unpack_one(indices, ps, "b * c")
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True))
        else:
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(
                1
            )

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, "... 1 -> ...")

        return out.to(self.dtype), indices, dummy_loss
|
|
|
|
| """ |
| VAE |
| """ |
|
|
|
|
class RMS_norm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    Args:
        dim: number of channels being normalised.
        channel_first: normalise over axis 1 when True, over the last axis otherwise.
        images: with ``channel_first``, give gamma/bias two trailing singleton axes
            (so they broadcast over two spatial axes) instead of one.
        bias: add a learnable zero-initialised bias; otherwise ``bias`` is 0.0.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing_ones = (1, 1) if images else (1,)
            param_shape = (dim,) + trailing_ones
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        # sqrt(dim); kept as an attribute, not used by forward().
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            self.bias = 0.0

    def forward(self, x):
        reduce_dim = 1 if self.channel_first else -1
        mean_square = x.pow(2).mean(dim=reduce_dim, keepdim=True)
        inv_rms = (mean_square + 1e-6).rsqrt()
        return x * inv_rms * self.gamma + self.bias
|
|
|
|
class Upsample(nn.Upsample):
    """Thin subclass of ``torch.nn.Upsample``; ``forward`` delegates unchanged."""

    def forward(self, x):
        return nn.Upsample.forward(self, x)
|
|
|
|
class ResidualBlock2d(nn.Module):
    """Residual block of two (RMS_norm -> SiLU -> 3x3 conv) stages.

    Dropout sits before the second conv. The skip path is a 1x1 conv when the
    channel count changes, identity otherwise.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        layers = [
            RMS_norm(in_dim),
            nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
            RMS_norm(out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1),
        ]
        self.residual = nn.Sequential(*layers)

        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(self, x):
        branch = self.residual(x)
        return branch + self.shortcut(x)
|
|
|
|
class AttentionBlock2d(nn.Module):
    """Single-head self-attention over all spatial positions of a (B, C, H, W) map.

    Pre-normalised with RMS_norm and added back to the input. The output
    projection's weight (not its bias) is initialised to zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        skip = x
        batch, channels, height, width = x.size()
        heads = 1
        per_head = channels // heads
        tokens = height * width

        qkv = self.to_qkv(self.norm(x)).reshape(batch, 3, heads, per_head, tokens)
        # Each of q/k/v: (B, heads, head_dim, N) -> (B, heads, N, head_dim) for SDPA.
        q, k, v = (t.transpose(-1, -2) for t in qkv.unbind(1))

        attended = F.scaled_dot_product_attention(q, k, v)
        attended = attended.transpose(-1, -2).reshape(batch, channels, height, width)
        return self.proj(attended) + skip
|
|
|
|
class FlashAttentionBlock2d(nn.Module):
    """Multi-head spatial self-attention using flash-attn's ``flash_attn_func`` directly.

    ``flash_attn`` is imported lazily inside ``forward`` so the module can be
    constructed without the package installed.
    """

    def __init__(self, dim, n_heads=8):
        super().__init__()
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        from flash_attn import flash_attn_func

        skip = x
        batch, channels, height, width = x.size()
        tokens = height * width

        qkv = self.to_qkv(self.norm(x))
        # (B, 3*C, H, W) -> (B, N, 3, heads, head_dim); each of q/k/v is then
        # (B, N, heads, head_dim) as passed to flash_attn_func.
        qkv = qkv.reshape(batch, 3, self.n_heads, self.head_dim, tokens).permute(
            0, 4, 1, 2, 3
        )
        q, k, v = qkv.unbind(2)

        attended = flash_attn_func(q, k, v)
        attended = attended.reshape(batch, tokens, channels).permute(0, 2, 1)
        attended = attended.reshape(batch, channels, height, width)
        return self.proj(attended) + skip
|
|
|
|
| |
class AsymmetricConv2d(nn.Conv2d):
    """``nn.Conv2d`` that zero-pads one pixel on the right and bottom before convolving."""

    def forward(self, x):
        padded = F.pad(x, pad=(0, 1, 0, 1))
        return nn.Conv2d.forward(self, padded)
|
|
|
|
class Resample2d(nn.Module):
    """Optional 2x spatial resampling.

    mode:
        "upsample2d"   -> nearest-neighbour 2x upsample, then a padded 3x3 conv.
        "downsample2d" -> zero-pad right/bottom by one pixel, then a stride-2
                          3x3 conv (output size H // 2, W // 2).
        "none"         -> identity.
    """

    def __init__(self, dim, mode):
        assert mode in ("none", "upsample2d", "downsample2d")
        super().__init__()
        self.mode = mode

        if mode == "none":
            layer = nn.Identity()
        elif mode == "upsample2d":
            layer = nn.Sequential(
                Upsample(scale_factor=2.0, mode="nearest"),
                nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            )
        else:
            layer = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, kernel_size=3, stride=2),
            )
        self.resample = layer

    def forward(self, x):
        return self.resample(x)
|
|
|
|
class Encoder2d(nn.Module):
    """Convolutional image encoder returning ``(mu, log_var)``.

    patchify -> 3x3 stem conv -> per-stage residual blocks (optionally followed
    by attention) with a stride-2 downsample between stages -> residual middle
    block -> RMS_norm / SiLU / 3x3 conv head emitting ``2 * z_dim`` channels.

    Args:
        dim: base channel width.
        z_dim: channels of each returned tensor.
        dim_mult: non-empty per-stage channel multipliers. Defaults to [1, 2, 4].
        num_res_blocks: residual blocks per stage.
        dropout: dropout probability inside residual blocks.
        attn_scales: resolution scales (1.0 = patchified input, halved per
            downsample) at which an attention block follows each residual block.
            Defaults to no attention.
        patch_size: spatial patch size folded into channels by ``patchify``.
        in_channels: input image channels.
        attn_class: attention module class, constructed as ``attn_class(channels)``.

    Raises:
        ValueError: if ``dim_mult`` is empty.
    """

    def __init__(
        self,
        dim=64,
        z_dim=4,
        dim_mult=None,
        num_res_blocks=2,
        dropout=0.0,
        attn_scales=None,
        patch_size=1,
        in_channels=3,
        attn_class=AttentionBlock2d,
    ):
        super().__init__()
        # Resolve list defaults per instance: they are stored on `self`, so a
        # literal default list would be shared (aliased) by every encoder.
        if dim_mult is None:
            dim_mult = [1, 2, 4]
        if attn_scales is None:
            attn_scales = []
        if len(dim_mult) == 0:
            raise ValueError("dim_mult must contain at least one multiplier")

        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.patch_size = patch_size
        self.in_channels = in_channels

        # Channel widths: stem output followed by one entry per stage.
        dims = [dim * u for u in [1, *dim_mult]]
        # Resolution relative to the patchified input; halved after each downsample.
        scale = 1.0

        # Stem: patchify folds patch_size**2 pixels of every channel into channels.
        initial_dim = self.in_channels * self.patch_size * self.patch_size
        self.conv1 = nn.Conv2d(initial_dim, dims[0], kernel_size=3, padding=1)

        # Stages: residual blocks (+ attention at requested scales), then a
        # stride-2 downsample between stages (none after the last stage).
        downsamples = []
        in_dim = dims[0]
        last_stage = len(dim_mult) - 1
        for i, out_dim in enumerate(dims[1:]):
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
                if scale in self.attn_scales:
                    downsamples.append(attn_class(out_dim))
                in_dim = out_dim
            if i != last_stage:
                downsamples.append(Resample2d(out_dim, mode="downsample2d"))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # Bottleneck and head; the head stacks mu and log_var along channels.
        self.middle = ResidualBlock2d(out_dim, out_dim, dropout)
        self.head = nn.Sequential(
            RMS_norm(out_dim),
            nn.SiLU(),
            nn.Conv2d(out_dim, z_dim * 2, kernel_size=3, padding=1),
        )

    def patcher(self, x):
        """Fold ``patch_size x patch_size`` pixel patches into channels.

        A method rather than a lambda attribute so the module stays picklable
        (e.g. ``torch.save(model)``).
        """
        return patchify(x, patch_size=self.patch_size)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, log_var)``, each with ``z_dim`` channels."""
        x = self.patcher(x)
        x = self.conv1(x)
        x = self.downsamples(x)
        x = self.middle(x)
        mu, log_var = self.head(x).chunk(2, dim=1)
        return mu, log_var
|
|
|
|
class Decoder2d(nn.Module):
    """Convolutional image decoder mirroring ``Encoder2d``.

    3x3 stem conv from ``z_dim`` channels -> residual middle block -> per-stage
    residual blocks (optionally followed by attention) with a nearest 2x
    upsample between stages -> RMS_norm / SiLU / 3x3 conv head -> unpatchify.

    Args:
        dim: base channel width.
        z_dim: latent channels consumed by the stem conv.
        dim_mult: non-empty per-stage channel multipliers, applied in reverse
            order. Defaults to [1, 2, 4].
        num_res_blocks: residual blocks per stage.
        dropout: dropout probability inside residual blocks.
        attn_scales: scales at which an attention block follows each residual
            block (see the note on the scale schedule below). Defaults to none.
        out_channels: output image channels.
        attn_class: attention module class, constructed as ``attn_class(channels)``.
        patch_size: patch size undone by ``unpatchify``.

    Raises:
        ValueError: if ``dim_mult`` is empty.
    """

    def __init__(
        self,
        dim=64,
        z_dim=4,
        dim_mult=None,
        num_res_blocks=2,
        dropout=0.0,
        attn_scales=None,
        out_channels=3,
        attn_class=AttentionBlock2d,
        patch_size=1,
    ):
        super().__init__()
        # Resolve list defaults per instance: they are stored on `self`, so a
        # literal default list would be shared (aliased) by every decoder.
        if dim_mult is None:
            dim_mult = [1, 2, 4]
        if attn_scales is None:
            attn_scales = []
        if len(dim_mult) == 0:
            raise ValueError("dim_mult must contain at least one multiplier")

        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.out_channels = out_channels
        self.patch_size = patch_size

        base = dim * dim_mult[-1]
        dims = [base] + [dim * u for u in dim_mult[::-1]]
        # NOTE(review): for len(dim_mult) >= 2 this labels each stage with twice
        # the scale the encoder uses for the same resolution (dim_mult=[1, 2, 4]:
        # decoder 0.5, 1.0, 2.0 vs encoder 0.25, 0.5, 1.0). Kept unchanged so
        # existing attn_scales configs / checkpoints build identical modules —
        # confirm intent.
        scale = 1.0 / (2 ** (len(dim_mult) - 2)) if len(dim_mult) >= 2 else 1.0
        # The head emits patch_size**2 values per output channel for unpatchify.
        output_channels = self.out_channels * self.patch_size * self.patch_size

        self.conv1 = nn.Conv2d(z_dim, dims[0], kernel_size=3, padding=1)

        self.middle = ResidualBlock2d(dims[0], dims[0], dropout)

        # Stages: residual blocks (+ attention), then a 2x upsample between
        # stages (none after the last stage).
        upsamples = []
        in_dim = dims[0]
        last_stage = len(dim_mult) - 1
        for i, out_dim in enumerate(dims[1:]):
            for _ in range(num_res_blocks):
                upsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
                if scale in self.attn_scales:
                    upsamples.append(attn_class(out_dim))
                in_dim = out_dim
            if i != last_stage:
                upsamples.append(Resample2d(out_dim, mode="upsample2d"))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        self.head = nn.Sequential(
            RMS_norm(out_dim),
            nn.SiLU(),
            nn.Conv2d(out_dim, output_channels, kernel_size=3, padding=1),
        )

    def unpatcher(self, x):
        """Unfold channels back into ``patch_size x patch_size`` pixel patches.

        A method rather than a lambda attribute so the module stays picklable
        (e.g. ``torch.save(model)``).
        """
        return unpatchify(x, patch_size=self.patch_size)

    def forward(self, x):
        """Decode a ``z_dim``-channel latent into an ``out_channels`` image."""
        x = self.conv1(x)
        x = self.middle(x)
        x = self.upsamples(x)
        x = self.head(x)
        x = self.unpatcher(x)
        return x
|
|
|
|
class DiscreteImageVAE(nn.Module):
    """Discrete image tokenizer: Encoder2d -> latent normalisation -> ChannelSplitFSQ -> Decoder2d."""

    def __init__(
        self,
        dim=64,
        z_dim=4,
        dim_mult=None,
        num_res_blocks=2,
        dropout=0.0,
        attn_scales=None,
        in_channels=3,
        out_channels=3,
        embedding_dim=128,
        scale=None,
        attn_class=AttentionBlock2d,
        patch_size=1,
        *args,
        **kwargs,
    ):
        """
        Args:
            dim, z_dim, num_res_blocks, dropout, attn_class, patch_size: forwarded
                to both Encoder2d and Decoder2d.
            dim_mult: per-stage channel multipliers for encoder and decoder.
                Defaults to [1, 2, 4].
            attn_scales: attention scales for encoder and decoder. Defaults to [].
            in_channels: encoder input channels.
            out_channels: decoder output channels.
            embedding_dim: embedding dimension, passed to the quantizer as its
                ``dim``. ``encode`` feeds the quantizer the ``z_dim``-channel
                latent, so it must equal ``z_dim`` for ``encode`` to work.
            scale: ``(mean, std)`` used to normalise the latent before
                quantization (``(h - mean) / std``) and to de-normalise in
                ``decode``; per-channel tensors with ``z_dim`` elements or
                scalars. Defaults to zeros / ones (identity).
            *args: ignored.
            **kwargs: forwarded to ChannelSplitFSQ (e.g. ``levels``, ``K``,
                ``num_codebooks``, ``dtype``); unknown keys are ignored there.
        """
        super().__init__()
        # Fresh lists per instance: the encoder/decoder store them as attributes.
        if dim_mult is None:
            dim_mult = [1, 2, 4]
        if attn_scales is None:
            attn_scales = []

        self.z_dim = z_dim
        self.encoder = Encoder2d(
            dim=dim,
            z_dim=z_dim,
            dim_mult=dim_mult,
            num_res_blocks=num_res_blocks,
            dropout=dropout,
            attn_scales=attn_scales,
            in_channels=in_channels,
            attn_class=attn_class,
            patch_size=patch_size,
        )
        self.decoder = Decoder2d(
            dim=dim,
            z_dim=z_dim,
            dim_mult=dim_mult,
            num_res_blocks=num_res_blocks,
            dropout=dropout,
            attn_scales=attn_scales,
            out_channels=out_channels,
            attn_class=attn_class,
            patch_size=patch_size,
        )
        self.embedding_dim = embedding_dim

        kwargs["dim"] = embedding_dim
        self.quantizer = ChannelSplitFSQ(**kwargs)
        if scale is None:
            mean = torch.zeros(self.z_dim, dtype=torch.float)
            std = torch.ones(self.z_dim, dtype=torch.float)
            self.scale = (mean, std)
        else:
            self.scale = scale

    def to(self, *args, **kwargs):
        """``nn.Module.to`` that also converts tensor-valued ``self.scale``.

        ``self.scale`` is a plain tuple attribute, not a registered buffer, so it
        is converted here explicitly.
        NOTE(review): only ``.to()`` is overridden; ``.cuda()``, ``.half()`` etc.
        presumably leave ``self.scale`` untouched — prefer ``.to(...)``.
        """
        super().to(*args, **kwargs)
        if isinstance(self.scale[0], torch.Tensor):
            self.scale = (
                self.scale[0].to(*args, **kwargs),
                self.scale[1].to(*args, **kwargs),
            )
        return self

    def encode(self, x):
        """Encode images and quantize the normalised latent.

        Args:
            x: A batch of images each with shape [B, C, H, W] in [-1, 1].
        Returns:
            dict with keys:
              - "quant_codes": continuous quantized codes from the quantizer.
              - "indices": discrete codebook indices.
              - "dummy_loss": all-zero placeholder loss from the quantizer.
              - "h": encoder mean latent (``z_dim`` channels), before normalisation.
              - "log_var_nouse": encoder log-variance; not used here.
        """
        h, log_var = self.encoder(x)

        # Normalise per channel with (mean, std) before quantization.
        if isinstance(self.scale[0], torch.Tensor):
            h_norm = (h - self.scale[0].view(1, self.z_dim, 1, 1)) / self.scale[1].view(
                1, self.z_dim, 1, 1
            )
        else:
            h_norm = (h - self.scale[0]) / self.scale[1]

        quant_codes, indices, dummy_loss = self.quantizer(h_norm)
        return {
            "quant_codes": quant_codes,
            "indices": indices,
            "dummy_loss": dummy_loss,
            "h": h,
            "log_var_nouse": log_var,
        }

    def decode(self, z):
        """De-normalise a (normalised) latent with ``self.scale`` and decode it."""
        if isinstance(self.scale[0], torch.Tensor):
            z = z * self.scale[1].view(1, self.z_dim, 1, 1) + self.scale[0].view(
                1, self.z_dim, 1, 1
            )
        else:
            z = z * self.scale[1] + self.scale[0]
        return self.decoder(z)

    def decode_code(self, code_b):
        """Decode codebook indices into images.

        NOTE(review): unlike ``decode``, this does not undo the ``self.scale``
        normalisation (neither does ``decoder_jit``); identical only for the
        default identity scale — confirm whether that is intended.
        """
        quant_b = self.quantizer.indices_to_codes(code_b)
        return self.decoder(quant_b)

    def encoder_jit(self):
        """Return a module mapping images to the encoder's ``mu`` only (for tracing)."""

        class EncoderJitModule(nn.Module):
            def __init__(self, encoder: nn.Module):
                super().__init__()
                self.encoder = encoder

            def forward(self, x):
                h, _ = self.encoder(x)
                return h

        return EncoderJitModule(self.encoder)

    def quantizer_jit(self):
        """Return a module wrapping the quantizer's forward (for tracing).

        Note: no ``self.scale`` normalisation is applied to its input.
        """

        class QuantizerJitModule(nn.Module):
            def __init__(self, quantizer: nn.Module):
                super().__init__()
                self.quantizer = quantizer

            def forward(self, x):
                quant_codes, indices, dummy_loss = self.quantizer(x)
                return quant_codes, indices, dummy_loss

        return QuantizerJitModule(self.quantizer)

    def decoder_jit(self):
        """Return ``indices -> image`` as ``Sequential(inv_quant, decoder)`` (for tracing)."""
        return nn.Sequential(
            OrderedDict(
                [
                    ("inv_quant", InvQuantizerJit(self.quantizer)),
                    ("decoder", self.decoder),
                ]
            )
        )
|
|
|
|
if __name__ == "__main__":
    import argparse
    import os
    from PIL import Image
    import numpy as np

    def load_image(path, size=(1920, 1080), device=None):
        """Load an RGB image as a (1, 3, H, W) float32 tensor in [-1, 1].

        ``size`` is (width, height), as PIL's ``resize`` expects. If ``path`` does
        not exist, standard-normal noise of the same shape is returned instead.
        ``device`` defaults to CUDA when available, else CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if not os.path.exists(path):
            print(
                f"Image not found at {path}, generating random noise. Warning: The tokenizer might not work properly."
            )
            return torch.randn(1, 3, size[1], size[0]).to(device)

        # Close the underlying file as soon as the pixels are decoded.
        with Image.open(path) as src:
            img = src.convert("RGB")
        img = np.array(img.resize(size, Image.BICUBIC))[None]  # (1, H, W, 3) uint8
        img = torch.from_numpy(img).to(device).to(torch.float32)
        img = (img / 127.5) - 1.0
        img = rearrange(img, "b h w c -> b c h w")
        return img

    def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
        """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255].

        Args:
            input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1].
            range_min: -1 if the input is in [-1, 1]; any other value means the
                input is already in [0, 1].
        Returns:
            A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
        """
        _UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
        if range_min == -1:
            input_tensor = (input_tensor.float() + 1.0) / 2.0
        ndim = input_tensor.ndim
        output_image = input_tensor.clamp(0, 1).cpu().numpy()
        # Move the channel axis last: (B, C, ...) -> (B, ..., C).
        output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
        return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)

    parser = argparse.ArgumentParser(description="Run DiscreteImageVAE inference")
    parser.add_argument(
        "--checkpoint", type=str, default=None, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--image", type=str, default="assets/00128.png", help="Path to input image"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="decoded_image_test.png",
        help="Path to save output image",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on",
    )

    args = parser.parse_args()

    # Keys not consumed by DiscreteImageVAE / Encoder2d / Decoder2d fall through
    # **kwargs to ChannelSplitFSQ, which ignores those it does not read.
    cs_discrete8_wan_patch2 = {
        "dim": 64,
        "z_dim": 16,
        "dim_mult": [1, 2, 4],
        "patch_size": 2,
        "num_res_blocks": 3,
        "attn_scales": [],
        "dropout": 0.0,
        "cls": DiscreteImageVAE,
        "z_channels": 256,
        "z_factor": 1,
        "embedding_dim": 16,
        "levels": [8, 8, 8, 5, 5, 5],
        "dtype": torch.float,
        "model_type": "wan_2_1",
        "quantizer_cls": ChannelSplitFSQ,
        "num_codebooks": 1,
        "K": 2,
    }

    device = args.device
    print(f"Running on {device}")

    vae = DiscreteImageVAE(**cs_discrete8_wan_patch2).to(device)

    if args.checkpoint and os.path.exists(args.checkpoint):
        print(f"Loading checkpoint from {args.checkpoint}")
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints (or pass weights_only=True where PyTorch supports it).
        state_dict = torch.load(args.checkpoint, map_location=device)
        vae.load_state_dict(state_dict)
    else:
        print("No checkpoint provided or found. Running with random initialization.")

    vae.eval()

    imgs = load_image(args.image, device=device)

    # Traceable path: encoder -> quantizer -> (inverse quantizer + decoder).
    # Note: this path skips the `vae.scale` normalisation that `encode` applies.
    with torch.no_grad():
        encoded_sample = vae.encoder_jit()(imgs)
        indices = vae.quantizer_jit()(encoded_sample)[1]
        decoded_sample = vae.decoder_jit()(indices)

    print(f"Encoded shape: {encoded_sample.shape}")
    print(f"Indices shape: {indices.shape}")
    print(f"Decoded shape: {decoded_sample.shape}")

    # Regular path: encode -> decode_code; must match the traceable path.
    with torch.no_grad():
        encoded_sample_regular = vae.encode(imgs)
        indices = encoded_sample_regular["indices"]
        decoded_sample_regular = vae.decode_code(indices)

    print(f"Quant codes shape: {encoded_sample_regular['quant_codes'].shape}")
    print(f"Indices shape: {indices.shape}")
    print(f"Decoded regular shape: {decoded_sample_regular.shape}")

    assert torch.allclose(
        decoded_sample, decoded_sample_regular, atol=1e-5
    ), "JIT and regular outputs mismatch"

    decoded_img = tensor2numpy(decoded_sample)
    decoded_img = Image.fromarray(decoded_img[0])
    decoded_img.save(args.output)
    print(f"Saved decoded image to {args.output}")
|
|