# Cosmos-Tokenizer-Surg / python / simple_sample_vae.py
# Uploaded with huggingface_hub by javirk1 (commit 86039d9, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from collections import OrderedDict
from einops import rearrange, pack, unpack
_PERSISTENT = True
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
"""
Quantizers
"""
class InvQuantizerJit(nn.Module):
"""Use for decoder_jit to trace quantizer in discrete tokenizer"""
def __init__(self, quantizer):
super().__init__()
self.quantizer = quantizer
def forward(self, indices: torch.Tensor):
codes = self.quantizer.indices_to_codes(indices)
return codes.to(self.quantizer.dtype)
class ChannelSplitFSQ(nn.Module):
"""Quantizer that splits the input into K channels and quantizes each channel independently.
From: https://research.nvidia.com/labs/dir/mamba-tokenizer/
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1.
Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
vector_quantize_pytorch/finite_scalar_quantization.py
[Copyright (c) 2020 Phil Wang]
"""
def __init__(
self,
levels: list[int],
dim: Optional[int] = None,
K: int = 4,
num_codebooks=1,
keep_num_codebooks_dim: Optional[bool] = None,
scale: Optional[float] = None,
**ignore_kwargs,
):
super().__init__()
self.dtype = ignore_kwargs.get("dtype", torch.bfloat16)
self.persistent = ignore_kwargs.get("persistent_quantizer", _PERSISTENT)
_levels = torch.tensor(levels, dtype=torch.int32)
self.register_buffer("_levels", _levels, persistent=self.persistent)
_basis = torch.cumprod(
torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32
)
self.register_buffer("_basis", _basis, persistent=self.persistent)
self.scale = scale
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
self.num_codebooks = num_codebooks
self.K = K
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks * K > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
effective_codebook_dim = self.codebook_dim * num_codebooks * K
self.effective_codebook_dim = effective_codebook_dim
self.dim = default(dim, len(levels) * num_codebooks * K)
has_projections = self.dim != effective_codebook_dim
self.project_in = (
nn.Linear(self.dim, effective_codebook_dim)
if has_projections
else nn.Identity()
)
self.project_out = (
nn.Linear(effective_codebook_dim, self.dim)
if has_projections
else nn.Identity()
)
self.has_projections = has_projections
self.codebook_size = self._levels.prod().item()
def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 + eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
# shift = (offset / half_l).atanh()
shift = offset / half_l
shift = 0.5 * torch.log(1 + shift) - 0.5 * torch.log(1 - shift)
return (z + shift).tanh() * half_l - offset
def quantize(self, z: torch.Tensor) -> torch.Tensor:
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width
def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width
def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
half_width = self._levels // 2
return (zhat - half_width) / half_width
def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat = self._scale_and_shift(zhat).float()
return (zhat * self._basis).sum(dim=-1).to(torch.int32)
def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor:
"""Inverse of `codes_to_indices`.
indices: (b h w k)
"""
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
# Rearrange first:
b, h, w, k = indices.shape
indices = rearrange(indices, "b h w k -> (b k) (h w) 1")
codes_non_centered = (indices // self._basis) % self._levels
codes = self._scale_and_shift_inverse(codes_non_centered) # (b k) n d
codes = rearrange(codes, "(b k) n d -> b n (d k)", k=self.K)
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = rearrange(codes, "b (h w) c -> b c h w", h=h, w=w)
return codes.to(self.dtype)
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension, which is also log2(codebook size)
c - number of codebook dim
k - number of channels to split into
"""
is_img_or_video = z.ndim >= 4
# standardize image or video into (batch, seq, dimension)
if is_img_or_video:
z = rearrange(z, "b d ... -> b ... d")
z, ps = pack_one(z, "b * d")
assert (
z.shape[-1] == self.dim
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
z = self.project_in(z) # b n (c d k)
z = rearrange(z, "b n (c d k) -> b n c d k", c=self.num_codebooks, k=self.K)
z = rearrange(z, "b n c d k -> (b k) n c d")
codes = self.quantize(z) # (b k) n c d
indices = self.codes_to_indices(codes) # (b k) n c
indices = rearrange(indices, "(b k) n c -> b n (c k)", k=self.K)
codes = rearrange(codes, "(b k) n c d -> b n (c d k)", k=self.K)
out = self.project_out(codes) # b n c, with c = initial dimension
# reconstitute image or video dimensions
if is_img_or_video:
out = unpack_one(out, ps, "b * d")
out = rearrange(out, "b ... d -> b d ...")
indices = unpack_one(indices, ps, "b * c")
dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True))
else:
dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(
1
)
if not self.keep_num_codebooks_dim:
indices = rearrange(indices, "... 1 -> ...")
# indices - discrete codes for each position
# out - continuous reconstruction
# loss - zeros (unused)
# return (indices, out.to(self.dtype), dummy_loss)
return out.to(self.dtype), indices, dummy_loss
"""
VAE
"""
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1) if images else (1,)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
# def forward(self, x):
# return (
# F.normalize(x, dim=(1 if self.channel_first else -1))
# * self.scale
# * self.gamma
# + self.bias
# )
def forward(self, x):
dim = 1 if self.channel_first else -1
rms = x.pow(2).mean(dim=dim, keepdim=True).add(1e-6).rsqrt()
return x * rms * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
# Fix bfloat16 support for nearest neighbor interpolation.
# return super().forward(x.float()).type_as(x)
return super().forward(x)
class ResidualBlock2d(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.residual = nn.Sequential(
RMS_norm(in_dim),
nn.SiLU(),
nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
RMS_norm(out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1),
)
self.shortcut = (
nn.Conv2d(in_dim, out_dim, kernel_size=1)
if in_dim != out_dim
else nn.Identity()
)
def forward(self, x):
return self.residual(x) + self.shortcut(x)
class AttentionBlock2d(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
nn.init.zeros_(self.proj.weight)
# def forward(self, x):
# identity = x
# b, c, h, w = x.size()
# x = self.norm(x)
# q, k, v = (
# self.to_qkv(x)
# .reshape(b, 1, c * 3, -1)
# .permute(0, 1, 3, 2)
# .contiguous()
# .chunk(3, dim=-1)
# )
# x = F.scaled_dot_product_attention(q, k, v)
# x = x.squeeze(1).permute(0, 2, 1).reshape(b, c, h, w)
# x = self.proj(x)
# return x + identity
def forward(self, x):
identity = x
b, c, h, w = x.size()
n_heads = 1 # or c // 64
head_dim = c // n_heads
x = self.norm(x)
qkv = self.to_qkv(x).reshape(b, 3, n_heads, head_dim, h * w)
q, k, v = qkv.unbind(1) # Each: (b, n_heads, head_dim, h*w)
q, k, v = q.transpose(-1, -2), k.transpose(-1, -2), v.transpose(-1, -2)
x = F.scaled_dot_product_attention(q, k, v) # Flash attention
x = x.transpose(-1, -2).reshape(b, c, h, w)
return self.proj(x) + identity
class FlashAttentionBlock2d(nn.Module):
"""Attention block using flash-attn's kernel directly."""
def __init__(self, dim, n_heads=8):
super().__init__()
assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"
self.dim = dim
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
nn.init.zeros_(self.proj.weight)
def forward(self, x):
from flash_attn import flash_attn_func
identity = x
b, c, h, w = x.size()
x = self.norm(x)
qkv = self.to_qkv(x) # (b, 3*c, h, w)
# flash_attn_func expects (b, seqlen, nheads, headdim)
qkv = qkv.reshape(b, 3, self.n_heads, self.head_dim, h * w)
qkv = qkv.permute(0, 4, 1, 2, 3) # (b, h*w, 3, n_heads, head_dim)
q, k, v = qkv.unbind(2) # each (b, h*w, n_heads, head_dim)
x = flash_attn_func(q, k, v) # (b, h*w, n_heads, head_dim)
x = x.reshape(b, h * w, c).permute(0, 2, 1).reshape(b, c, h, w)
return self.proj(x) + identity
# Custom conv with asymmetric padding
class AsymmetricConv2d(nn.Conv2d):
def forward(self, x):
x = F.pad(x, (0, 1, 0, 1)) # Fused with conv by torch.compile
return super().forward(x)
class Resample2d(nn.Module):
def __init__(self, dim, mode):
assert mode in ("none", "upsample2d", "downsample2d")
super().__init__()
self.mode = mode
if mode == "upsample2d":
self.resample = nn.Sequential(
Upsample(scale_factor=2.0, mode="nearest"),
nn.Conv2d(dim, dim, kernel_size=3, padding=1),
)
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, kernel_size=3, stride=2),
)
else:
self.resample = nn.Identity()
def forward(self, x):
return self.resample(x)
class Encoder2d(nn.Module):
def __init__(
self,
dim=64,
z_dim=4,
dim_mult=[1, 2, 4],
num_res_blocks=2,
dropout=0.0,
attn_scales=[],
patch_size=1,
in_channels=3,
attn_class=AttentionBlock2d,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.patch_size = patch_size
self.in_channels = in_channels
self.patcher = lambda x: patchify(x, patch_size=patch_size)
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
initial_dim = self.in_channels * self.patch_size * self.patch_size
# init block
self.conv1 = nn.Conv2d(initial_dim, dims[0], kernel_size=3, padding=1)
# downsample blocks
downsamples = []
in_dim = dims[0]
for i, out_dim in enumerate(dims[1:]):
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
if scale in self.attn_scales:
downsamples.append(attn_class(out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
downsamples.append(Resample2d(out_dim, mode="downsample2d"))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle and head
self.middle = ResidualBlock2d(out_dim, out_dim, dropout)
self.head = nn.Sequential(
RMS_norm(out_dim),
nn.SiLU(),
nn.Conv2d(out_dim, z_dim * 2, kernel_size=3, padding=1),
)
def forward(self, x):
x = self.patcher(x)
x = self.conv1(x)
x = self.downsamples(x)
x = self.middle(x)
mu, log_var = self.head(x).chunk(2, dim=1)
return mu, log_var
class Decoder2d(nn.Module):
def __init__(
self,
dim=64,
z_dim=4,
dim_mult=[1, 2, 4],
num_res_blocks=2,
dropout=0.0,
attn_scales=[],
out_channels=3,
attn_class=AttentionBlock2d,
patch_size=1,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.out_channels = out_channels
self.patch_size = patch_size
self.unpatcher = lambda x: unpatchify(x, patch_size=patch_size)
# dimensions (mirror of encoder)
base = dim * dim_mult[-1]
dims = [base] + [dim * u for u in dim_mult[::-1]]
scale = 1.0 / (2 ** (len(dim_mult) - 2)) if len(dim_mult) >= 2 else 1.0
output_channels = self.out_channels * self.patch_size * self.patch_size
# init block
self.conv1 = nn.Conv2d(z_dim, dims[0], kernel_size=3, padding=1)
# middle
self.middle = ResidualBlock2d(dims[0], dims[0], dropout)
# upsample blocks
upsamples = []
in_dim = dims[0]
for i, out_dim in enumerate(dims[1:]):
for _ in range(num_res_blocks):
upsamples.append(ResidualBlock2d(in_dim, out_dim, dropout))
if scale in self.attn_scales:
upsamples.append(attn_class(out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
upsamples.append(Resample2d(out_dim, mode="upsample2d"))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# head
self.head = nn.Sequential(
RMS_norm(out_dim),
nn.SiLU(),
nn.Conv2d(out_dim, output_channels, kernel_size=3, padding=1),
)
def forward(self, x):
x = self.conv1(x)
x = self.middle(x)
x = self.upsamples(x)
x = self.head(x)
x = self.unpatcher(x)
return x
class DiscreteImageVAE(nn.Module):
def __init__(
self,
dim=64,
z_dim=4,
dim_mult=[1, 2, 4],
num_res_blocks=2,
dropout=0.0,
attn_scales=[],
in_channels=3,
out_channels=3,
embedding_dim=128,
scale=None,
attn_class=AttentionBlock2d,
patch_size=1,
*args,
**kwargs,
):
"""
Args:
embedding_dim: embedding dimension
scale: scale for the quantizer
"""
super().__init__()
self.z_dim = z_dim
self.encoder = Encoder2d(
dim=dim,
z_dim=z_dim,
dim_mult=dim_mult,
num_res_blocks=num_res_blocks,
dropout=dropout,
attn_scales=attn_scales,
in_channels=in_channels,
attn_class=attn_class,
patch_size=patch_size,
)
self.decoder = Decoder2d(
dim=dim,
z_dim=z_dim,
dim_mult=dim_mult,
num_res_blocks=num_res_blocks,
dropout=dropout,
attn_scales=attn_scales,
out_channels=out_channels,
attn_class=attn_class,
patch_size=patch_size,
)
self.embedding_dim = embedding_dim
kwargs["dim"] = embedding_dim
self.quantizer = ChannelSplitFSQ(**kwargs)
if scale is None:
mean = torch.zeros(self.z_dim, dtype=torch.float)
std = torch.ones(self.z_dim, dtype=torch.float)
self.scale = (mean, std)
else:
self.scale = scale
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
if isinstance(self.scale[0], torch.Tensor):
self.scale = (
self.scale[0].to(*args, **kwargs),
self.scale[1].to(*args, **kwargs),
)
return self
def encode(self, x):
"""
x: A batch of images each with shape [B, C, H, W] in [-1, 1].
Returns:
- (quant_codes, indices, dummy loss) tuples, where:
- quant_codes: continuous
- indices: discrete
- dummy loss: dummy loss for training
- (h, log_var): continuous latent and log_var with shape [embedding_dim, H/scale, W/scale]
"""
h, log_var = self.encoder(x)
# Normalize h mean-var
if isinstance(self.scale[0], torch.Tensor):
h_norm = (h - self.scale[0].view(1, self.z_dim, 1, 1)) / self.scale[1].view(
1, self.z_dim, 1, 1
)
else:
h_norm = (h - self.scale[0]) / self.scale[1]
quant_codes, indices, dummy_loss = self.quantizer(h_norm)
return {
"quant_codes": quant_codes,
"indices": indices,
"dummy_loss": dummy_loss,
"h": h,
"log_var_nouse": log_var,
}
def decode(self, z):
if isinstance(self.scale[0], torch.Tensor):
z = z * self.scale[1].view(1, self.z_dim, 1, 1) + self.scale[0].view(
1, self.z_dim, 1, 1
)
else:
z = z * self.scale[1] + self.scale[0]
return self.decoder(z)
def decode_code(self, code_b):
quant_b = self.quantizer.indices_to_codes(code_b)
return self.decoder(quant_b)
def encoder_jit(self):
class EncoderJitModule(nn.Module):
def __init__(self, encoder: nn.Module):
super().__init__()
self.encoder = encoder
def forward(self, x):
h, _ = self.encoder(x)
return h
return EncoderJitModule(self.encoder)
def quantizer_jit(self):
class QuantizerJitModule(nn.Module):
def __init__(self, quantizer: nn.Module):
super().__init__()
self.quantizer = quantizer
def forward(self, x):
quant_codes, indices, dummy_loss = self.quantizer(x)
return quant_codes, indices, dummy_loss
return QuantizerJitModule(self.quantizer)
def decoder_jit(self):
return nn.Sequential(
OrderedDict(
[
("inv_quant", InvQuantizerJit(self.quantizer)),
("decoder", self.decoder),
]
)
)
if __name__ == "__main__":
import argparse
import os
from PIL import Image
import numpy as np
def load_image(path, size=(1920, 1080)):
if not os.path.exists(path):
print(
f"Image not found at {path}, generating random noise. Warning: The tokenizer might to work properly."
)
return torch.randn(1, 3, size[1], size[0]).to(
"cuda" if torch.cuda.is_available() else "cpu"
)
img = Image.open(path).convert("RGB")
img = np.array(img.resize(size, Image.BICUBIC))[None]
device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.from_numpy(img).to(device).to(torch.float32)
img = (img / 127.5) - 1.0
img = rearrange(img, "b h w c -> b c h w")
return img
def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
"""Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255].
Args:
input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1].
Returns:
A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
"""
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
if range_min == -1:
input_tensor = (input_tensor.float() + 1.0) / 2.0
ndim = input_tensor.ndim
output_image = input_tensor.clamp(0, 1).cpu().numpy()
output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)
parser = argparse.ArgumentParser(description="Run DiscreteImageVAE inference")
parser.add_argument(
"--checkpoint", type=str, default=None, help="Path to model checkpoint"
)
parser.add_argument(
"--image", type=str, default="assets/00128.png", help="Path to input image"
)
parser.add_argument(
"--output",
type=str,
default="decoded_image_test.png",
help="Path to save output image",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run on",
)
args = parser.parse_args()
cs_discrete8_wan_patch2 = {
"dim": 64,
"z_dim": 16,
"dim_mult": [1, 2, 4],
"patch_size": 2,
"num_res_blocks": 3,
"attn_scales": [],
"dropout": 0.0,
"cls": DiscreteImageVAE,
"z_channels": 256,
"z_factor": 1,
"embedding_dim": 16,
"levels": [8, 8, 8, 5, 5, 5],
"dtype": torch.float,
"model_type": "wan_2_1",
"quantizer_cls": ChannelSplitFSQ,
"num_codebooks": 1,
"K": 2,
}
device = args.device
print(f"Running on {device}")
vae = DiscreteImageVAE(**cs_discrete8_wan_patch2).to(device)
if args.checkpoint and os.path.exists(args.checkpoint):
print(f"Loading checkpoint from {args.checkpoint}")
state_dict = torch.load(args.checkpoint, map_location=device)
vae.load_state_dict(state_dict)
else:
print("No checkpoint provided or found. Running with random initialization.")
vae.eval()
imgs = load_image(args.image)
if imgs.device.type != device:
imgs = imgs.to(device)
# Example using JIT modules
with torch.no_grad():
encoded_sample = vae.encoder_jit()(imgs)
indices = vae.quantizer_jit()(encoded_sample)[1]
decoded_sample = vae.decoder_jit()(indices)
print(f"Encoded shape: {encoded_sample.shape}")
print(f"Indices shape: {indices.shape}")
print(f"Decoded shape: {decoded_sample.shape}")
# Example using regular modules
with torch.no_grad():
encoded_sample_regular = vae.encode(imgs)
indices = encoded_sample_regular["indices"]
decoded_sample_regular = vae.decode_code(indices)
print(f"Quant codes shape: {encoded_sample_regular['quant_codes'].shape}")
print(f"Indices shape: {indices.shape}")
print(f"Decoded regular shape: {decoded_sample_regular.shape}")
assert torch.allclose(
decoded_sample, decoded_sample_regular, atol=1e-5
), "JIT and regular outputs mismatch"
# Save the decoded image
decoded_img = tensor2numpy(decoded_sample)
decoded_img = Image.fromarray(decoded_img[0])
decoded_img.save(args.output)
print(f"Saved decoded image to {args.output}")