|
|
|
|
|
""" |
|
|
Simple VAE reconstruction test using InfinityStar's own code and video. |
|
|
This directly uses InfinityStar's encode_for_raw_features, quantizer (when present) and decode methods, then saves a side-by-side comparison grid and video.
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import imageio |
|
|
from torchvision import transforms |
|
|
from torchvision.utils import make_grid, save_image |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
|
|
|
|
|
|
|
|
|
from infinity.models.videovae.models.wan_bsq_vae import AutoencoderKLCogVideoX |
|
|
from infinity.utils.video_decoder import EncodedVideoDecord |
|
|
import argparse |
|
|
|
|
|
|
|
|
def _build_vae_config(vqgan_ckpt, schedule_mode, codebook_dim, global_args):
    """Build the argparse.Namespace config that AutoencoderKLCogVideoX is constructed from."""
    return argparse.Namespace(
        vqgan_ckpt=vqgan_ckpt,
        sd_ckpt=None,
        use_frames=None,
        inference_type='video',
        save_prediction=True,
        save_dir='results',
        intermediate_tensor=True,
        save_z=False,
        save_frames=False,
        image_recon4video=False,
        junke_old=False,
        cal_norm=False,
        save_samples=None,
        device='cuda',
        noise_scale=0.0,
        max_steps=1000000.0,
        log_every=1,
        ckpt_every=1000,
        default_root_dir='/tmp',
        compile='no',
        ema='no',
        mfu_logging='no',
        dataloader_init_epoch=-1,
        context_parallel_size=0,
        video_ranks_ratio=-1.0,
        lr=0.0001,
        beta1=0.9,
        beta2=0.95,
        optim_type='Adam',
        disc_optim_type=None,
        max_grad_norm=1.0,
        max_grad_norm_disc=1.0,
        disable_sch=False,
        scheduler='no',
        warmup_steps=0,
        lr_min=0.0,
        warmup_lr_init=0.0,
        patch_size=8,
        temporal_patch_size=4,
        embedding_dim=256,
        codebook_dim=codebook_dim,
        use_vae=True,
        eq_scale_prior=0.0,
        eq_angle_prior=0.0,
        use_stochastic_depth=False,
        drop_rate=0.0,
        schedule_mode=schedule_mode,
        lr_drop=None,
        lr_drop_rate=0.1,
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        use_out_phi=False,
        use_out_phi_res=False,
        use_lecam_reg=False,
        lecam_weight=0.05,
        perceptual_model='vgg16',
        base_ch_disc=64,
        random_flip=False,
        flip_prob=0.5,
        flip_mode='stochastic',
        max_flip_lvl=1,
        not_load_optimizer=False,
        use_lecam_reg_zero=False,
        freeze_encoder=False,
        rm_downsample=False,
        random_flip_1lvl=False,
        flip_lvl_idx=0,
        drop_when_test=False,
        drop_lvl_idx=None,
        drop_lvl_num=0,
        compute_all_commitment=False,
        disable_codebook_usage=False,
        freeze_enc_main=False,
        freeze_dec_main=False,
        random_short_schedule=False,
        short_schedule_prob=0.5,
        use_bernoulli=False,
        use_rot_trick=False,
        disable_flip_prob=0.0,
        dino_disc=False,
        quantizer_type='MultiScaleBSQTP',
        lfq_weight=0.0,
        entropy_loss_weight=0.1,
        visu_every=1000,
        commitment_loss_weight=0.25,
        bsq_version='v1',
        diversity_gamma=1,
        bs1_for1024=False,
        casual_multi_scale=False,
        double_compress_t=False,
        temporal_slicing=False,
        latent_adjust_type=None,
        compute_latent_loss=False,
        latent_loss_weight=0.0,
        use_raw_latentz=False,
        last_scale_repeat_n=0,
        num_lvl_fsq=5,
        use_midscale_sup=False,
        midscale_list=[0.5, 0.75, 1.0],
        use_eq=False,
        eq_prob=0.5,
        disc_version='v1',
        magvit_disc=False,
        disc_type='patchgan',
        sigmoid_in_disc=False,
        activation_in_disc='leaky_relu',
        apply_blur=False,
        apply_noise=False,
        dis_warmup_steps=0,
        dis_lr_multiplier=1.0,
        dis_minlr_multiplier=False,
        disc_channels=64,
        disc_layers=3,
        discriminator_iter_start=0,
        disc_pretrain_iter=0,
        disc_optim_steps=1,
        disc_warmup=0,
        disc_pool='no',
        disc_pool_size=100,
        disc_temporal_compress='yes',
        disc_use_blur='yes',
        disc_stylegan_downsample_base=2,
        fix_model=['no'],
        recon_loss_type='l1',
        image_gan_weight=1.0,
        video_gan_weight=1.0,
        image_disc_weight=0.0,
        video_disc_weight=0.0,
        vf_weight=0.0,
        vf_weight_approx=-1,
        vf_distmat_margin=0.25,
        vf_cos_margin=0.5,
        temporal_alignment=None,
        l1_weight=4.0,
        gan_feat_weight=0.0,
        lpips_model='vgg',
        perceptual_weight=0.0,
        video_perceptual_weight=None,
        video_perceptual_layers=[],
        kl_weight=0.0,
        norm_type='rms',
        disc_loss_type='hinge',
        gan_image4video='yes',
        use_checkpoint=False,
        precision='fp32',
        encoder_dtype='fp32',
        decoder_dtype='fp32',
        upcast_attention='',
        upcast_tf32=False,
        tokenizer='cogvideoxd',
        pretrained=None,
        pretrained_mode='full',
        pretrained_ema='no',
        inflation_pe=False,
        init_vgen='no',
        no_init_idis=False,
        init_idis='keep',
        init_vdis='no',
        enable_nan_detector=False,
        turn_on_profiler=False,
        profiler_scheduler_wait_steps=10,
        debug=False,
        video_logger=False,
        bytenas='sg',
        username='bin.yan',
        seed=1234,
        vq_to_vae=False,
        load_not_strict=False,
        zero=0,
        bucket_cap_mb=40,
        manual_gc_interval=10000,
        data_path=[''],
        data_type=[''],
        dataset_list=['wanxvideo-v1'],
        fps=[-1],
        dataaug='resizecrop',
        multi_resolution=False,
        random_bucket_ratio=0.0,
        sequence_length=81,
        resolution=[(480, 864)],
        resize_bucket=None,
        resize_bucket_use_self='yes',
        scaling_aug='no',
        batch_size=[1],
        num_workers=0,
        image_channels=3,
        in_channels=3,
        out_channels=3,
        down_block_types=['CogVideoXDownBlock3D', 'CogVideoXDownBlock3D', 'CogVideoXDownBlock3D', 'CogVideoXDownBlock3D'],
        down_block_mode='dc',
        up_block_types=['CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D'],
        up_block_mode='dc',
        block_out_channels=[96, 192, 384, 384, 384],
        layers_per_block=2,
        latent_channels=16,
        act_fn='silu',
        norm_eps=1e-06,
        norm_num_groups=32,
        spatial_compression_list=[2, 2, 2],
        temporal_compression_list=[2, 2],
        sample_height=480,
        sample_width=720,
        use_quant_conv=False,
        use_post_quant_conv=False,
        down_layer='3d-dc',
        down_norm=True,
        up_layer='3d-dc',
        up_norm=True,
        pad_mode='constant',
        dropout_z=0.0,
        flux_weight=0,
        cycle_weight=0,
        cycle_feat_weight=0,
        cycle_gan_weight=0,
        cycle_loop=0,
        cycle_norm='no',
        cycle_deterministic='no',
        cycle_kl_weight=0,
        z_drop=0.0,
        intermediate_tensor_dir='/tmp',
        codebook_dim_low=codebook_dim // 4,
        freeze_decoder=False,
        semantic_scale_dim=global_args.semantic_scale_dim,
        detail_scale_dim=global_args.detail_scale_dim,
        use_learnable_dim_proj=global_args.use_learnable_dim_proj,
        detail_scale_min_tokens=global_args.detail_scale_min_tokens,
        use_feat_proj=global_args.use_feat_proj,
        semantic_scales=global_args.semantic_scales,
        use_multi_scale=0,
        quant_not_rely_256=0,
        semantic_num_lvl=2,
        detail_num_lvl=2,
    )


def video_vae_model(vqgan_ckpt, schedule_mode, codebook_dim, global_args=None, test_mode=True):
    """Construct the InfinityStar video VAE and load its checkpoint weights.

    Copied from load_vae_bsq_wan_absorb_patchify.py to avoid import issues.

    Args:
        vqgan_ckpt: Path of the checkpoint read with ``torch.load``. It must
            contain a ``"vae"`` entry (``"ema"`` when the config's ``ema`` is
            ``"yes"``, which it never is in the fixed config used here).
        schedule_mode: Quantizer schedule mode forwarded into the model config.
        codebook_dim: Codebook dimension; ``codebook_dim // 4`` becomes
            ``codebook_dim_low``.
        global_args: Optional object providing the multi-scale settings
            (``semantic_scale_dim``, ``detail_scale_dim``,
            ``use_learnable_dim_proj``, ``detail_scale_min_tokens``,
            ``use_feat_proj``, ``semantic_scales``). Missing attributes are
            set to defaults on this object in place.
        test_mode: If True, switch the model to eval mode and freeze all
            parameters.

    Returns:
        The constructed ``AutoencoderKLCogVideoX`` with weights loaded and
        slicing enabled.
    """
    scale_defaults = {
        "semantic_scale_dim": 16,
        "detail_scale_dim": 64,
        "use_learnable_dim_proj": 0,
        "detail_scale_min_tokens": 80,
        "use_feat_proj": 2,
        "semantic_scales": 8,
    }
    if global_args is None:
        global_args = argparse.Namespace(**scale_defaults)
    else:
        # Fill in (on the caller's object) any multi-scale setting it lacks.
        for name, default in scale_defaults.items():
            if not hasattr(global_args, name):
                setattr(global_args, name, default)

    args = _build_vae_config(vqgan_ckpt, schedule_mode, codebook_dim, global_args)

    vae = AutoencoderKLCogVideoX(args)
    state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
    if args.ema == "yes":
        print("testing ema weights")
        weights = state_dict["ema"]
    else:
        weights = state_dict["vae"]

    # strict=False tolerates key mismatches; report them so a wrong or
    # partially-matching checkpoint does not go unnoticed.
    load_result = vae.load_state_dict(weights, strict=False)
    missing = getattr(load_result, "missing_keys", None) or []
    unexpected = getattr(load_result, "unexpected_keys", None) or []
    if missing or unexpected:
        print(f"Warning: checkpoint/model key mismatch - {len(missing)} missing, "
              f"{len(unexpected)} unexpected keys")

    vae.enable_slicing()
    if test_mode:
        vae.eval()
        vae.requires_grad_(False)
    return vae
|
|
|
|
|
|
|
|
def transform(pil_img, tgt_h, tgt_w):
    """Resize (cover) and center-crop a PIL image to ``tgt_h x tgt_w`` (same as run_infinity.py).

    The image is resized with Lanczos resampling, keeping its aspect ratio,
    so that it fully covers the target box. The excess is then cropped
    symmetrically.

    Args:
        pil_img: Input PIL image (callers in this file pass RGB images).
        tgt_h: Target height in pixels.
        tgt_w: Target width in pixels.

    Returns:
        Float tensor of shape ``[C, tgt_h, tgt_w]`` with values in ``[-1, 1]``.
    """
    import PIL.Image as PImage
    from torchvision.transforms.functional import to_tensor

    width, height = pil_img.size
    aspect = width / height
    # max() guards against float round-off (e.g. 383.99999 -> 383) that would
    # otherwise leave the resized image one pixel short of the crop box.
    if aspect <= tgt_w / tgt_h:
        # Relatively taller than the target: match the width, height overflows.
        resized_width = tgt_w
        resized_height = max(tgt_h, int(tgt_w / aspect))
    else:
        # Relatively wider than the target: match the height, width overflows.
        resized_height = tgt_h
        resized_width = max(tgt_w, int(aspect * tgt_h))
    pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)

    arr = np.array(pil_img)
    crop_y = (arr.shape[0] - tgt_h) // 2
    crop_x = (arr.shape[1] - tgt_w) // 2
    im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])

    # to_tensor yields [0, 1]; map to [-1, 1] as 2x - 1.
    return im.mul(2.0).sub_(1.0)
|
|
|
|
|
|
|
|
|
|
|
class SimpleArgs:
    """Lightweight stand-in for InfinityStar's global args object.

    Carries the VAE selection and dtype settings, plus the multi-scale
    quantizer settings that ``video_vae_model`` reads from ``global_args``.
    """

    def __init__(self):
        settings = (
            ("vae_path", ""),
            ("vae_type", 18),
            ("videovae", 10),
            ("device", 'cuda'),
            ("encoder_dtype", 'float32'),
            ("decoder_dtype", 'float32'),
            # Multi-scale quantizer settings consumed by video_vae_model.
            ("semantic_scale_dim", 16),
            ("detail_scale_dim", 64),
            ("use_learnable_dim_proj", 0),
            ("detail_scale_min_tokens", 80),
            ("use_feat_proj", 2),
            ("semantic_scales", 8),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)
|
|
|
|
|
|
|
|
# Candidate TrueType fonts for frame labels, tried in order.
_LABEL_FONT_CANDIDATES = (
    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",  # common on Linux
    "/System/Library/Fonts/Helvetica.ttc",  # macOS
)


def _load_label_font(size=24):
    """Return the first loadable candidate TrueType font, else PIL's default font."""
    from PIL import ImageFont

    for font_path in _LABEL_FONT_CANDIDATES:
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError):
            # Font missing/unreadable, or PIL built without FreeType: try the next option.
            continue
    return ImageFont.load_default()


def add_text_to_image(image_tensor, text, position=(10, 30)):
    """
    Add an outlined text label to an image tensor.

    Args:
        image_tensor: Image tensor [C, H, W] with values in [0, 1]; values
            outside that range are clipped. Callers pass 3-channel RGB frames.
        text: Text to draw.
        position: (x, y) pixel position for the text.

    Returns:
        New float tensor [C, H, W] in [0, 1] (on the CPU) with the label drawn
        in white over a 1-pixel black outline. The input is not modified.
    """
    from PIL import ImageDraw

    # Tensor [C, H, W] -> uint8 HWC array for PIL.
    image_np = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
    image_np = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)
    pil_image = Image.fromarray(image_np)

    draw = ImageDraw.Draw(pil_image)
    font = _load_label_font(24)

    x, y = position
    # Draw the text in black at the 8 neighbouring offsets to form an outline,
    # so the white label stays readable on bright backgrounds.
    outline_offsets = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]
    for dx, dy in outline_offsets:
        draw.text((x + dx, y + dy), text, font=font, fill=(0, 0, 0))
    draw.text((x, y), text, font=font, fill=(255, 255, 255))

    return transforms.ToTensor()(pil_image)
|
|
|
|
|
|
|
|
def create_comparison_grid(original, reconstructed, output_path, nrow=4, max_frames=8):
    """
    Save a grid image comparing original and reconstructed frames.

    Up to ``max_frames`` frames, evenly spaced over the shared length of both
    videos, are shown. Each appears as an adjacent (Original, Reconstructed)
    labelled pair.

    Args:
        original: Original video tensor [C, F, H, W] in [0, 1].
        reconstructed: Reconstructed video tensor [C, F, H, W] in [0, 1].
        output_path: Path to save the grid image.
        nrow: Number of frame *pairs* per grid row (2 * nrow images per row).
        max_frames: Maximum number of frames to sample (default 8).

    Raises:
        ValueError: If there is nothing to show (either video has no frames,
            or ``max_frames < 1``).
    """
    num_frames = min(original.shape[1], reconstructed.shape[1])
    num_frames_to_show = min(max_frames, num_frames)
    if num_frames_to_show < 1:
        raise ValueError(
            f"create_comparison_grid has no frames to show "
            f"(original={original.shape[1]}, reconstructed={reconstructed.shape[1]}, "
            f"max_frames={max_frames})"
        )

    # Evenly spaced frame indices covering first to last frame.
    frame_indices = np.linspace(0, num_frames - 1, num_frames_to_show, dtype=int)

    frames_list = []
    for idx in frame_indices:
        frames_list.append(add_text_to_image(original[:, idx, :, :], "Original", position=(10, 10)))
        frames_list.append(add_text_to_image(reconstructed[:, idx, :, :], "Reconstructed", position=(10, 10)))

    frames_tensor = torch.stack(frames_list, dim=0)
    grid = make_grid(frames_tensor, nrow=nrow * 2, padding=2, pad_value=1.0)

    save_image(grid, output_path)
    print(f"Saved comparison grid to: {output_path}")
|
|
|
|
|
|
|
|
def _check_supported_resolution(quantizer, H, W):
    """Print the latent resolutions the quantizer's schedule supports; raise ValueError if (H, W) is not one."""
    from infinity.models.videovae.utils.dynamic_resolution import predefined_HW_Scales_dynamic

    schedule_mode = quantizer.schedule_mode
    print(f"\nSupported resolutions for schedule_mode='{schedule_mode}':")
    if schedule_mode == "dynamic":
        supported_resolutions = sorted(predefined_HW_Scales_dynamic)
        is_supported = (H, W) in predefined_HW_Scales_dynamic
    elif schedule_mode == "original":
        # Resolution list for the "original" schedule, hard-coded in this script.
        supported_resolutions = [(16, 16), (36, 64), (18, 32), (30, 53), (32, 32), (64, 64)]
        is_supported = (H, W) in supported_resolutions
    else:
        print(f" (Please check quantizer code for mode '{schedule_mode}')")
        supported_resolutions = []
        is_supported = False

    if schedule_mode in ("dynamic", "original"):
        print(f" {len(supported_resolutions)} resolutions:")
        for res in supported_resolutions:
            print(f" - {res}")

    if not is_supported:
        print(f"\n❌ ERROR: Resolution ({H}, {W}) is NOT supported for schedule_mode='{schedule_mode}'")
        print(f" Please use one of the supported resolutions listed above.")
        print(f" Or change the video resolution to match a supported one.")
        print(f"\n To fix this, you can:")
        print(f" 1. Change video resolution to one of: {supported_resolutions[:5]}...")
        print(f" 2. Or manually add ({H}, {W}) to predefined_HW_Scales_dynamic")
        raise ValueError(f"Resolution ({H}, {W}) not supported for schedule_mode='{schedule_mode}'. "
                         f"Supported resolutions: {supported_resolutions}")

    print(f"\n✓ Resolution ({H}, {W}) is supported!")


def _quantize_latent(quantizer, raw_features):
    """Quantize the continuous latent; return (quantized latent, first-scale indices or None)."""
    result = quantizer(raw_features)
    if not (isinstance(result, (list, tuple)) and len(result) >= 2):
        raise ValueError(f"Unexpected return format from quantizer: {type(result)}, length: {len(result) if isinstance(result, (list, tuple)) else 'N/A'}")

    # Only the quantized latent and the indices are needed here; any further
    # outputs (bit indices, losses, ...) are ignored.
    quantized_latent, all_indices = result[0], result[1]

    # Multi-scale quantizers return one index tensor per scale; report the first.
    if isinstance(all_indices, (list, tuple)):
        discrete_indices = all_indices[0] if len(all_indices) > 0 else None
    else:
        discrete_indices = all_indices

    print(f"✓ Quantization successful!")
    if torch.is_tensor(discrete_indices) and discrete_indices.numel() > 0:
        print(f" Discrete indices shape: {discrete_indices.shape}")
        print(f" Discrete indices dtype: {discrete_indices.dtype}")
        print(f" Discrete indices range: [{discrete_indices.min().item()}, {discrete_indices.max().item()}]")
        unique_count = torch.unique(discrete_indices).numel()
        print(f" Discrete indices unique values: {unique_count} (distinct codes used)")
    else:
        print(" Discrete indices: not returned by quantizer")

    print(f" Quantized latent shape: {quantized_latent.shape}")
    print(f" Quantized latent range: [{quantized_latent.min():.3f}, {quantized_latent.max():.3f}]")
    return quantized_latent, discrete_indices


def _save_comparison_video(original_01, reconstructed_01, video_path_out, fps=8):
    """Write an mp4 with original (left) and reconstruction (right) side by side.

    Both inputs are [B, C, T, H, W] tensors in [0, 1]; only batch item 0 is used,
    over the frames both videos share.
    """
    num_frames = min(original_01.shape[2], reconstructed_01.shape[2])

    def to_uint8_thwc(video_cthw):
        # [C, T, H, W] in [0, 1] -> [T, H, W, C] uint8 on the CPU, rounded the
        # same way torchvision's save_image does; converted once for the clip.
        return (video_cthw[:, :num_frames].permute(1, 2, 3, 0)
                .mul(255).add(0.5).clamp(0, 255).to(torch.uint8).cpu().numpy())

    orig_frames = to_uint8_thwc(original_01[0])
    recon_frames = to_uint8_thwc(reconstructed_01[0])
    label_x = orig_frames.shape[2] + 10  # left edge of the right-hand panel, plus margin

    video_frames = []
    for orig_frame, recon_frame in zip(orig_frames, recon_frames):
        side_by_side = np.ascontiguousarray(np.hstack([orig_frame, recon_frame]))
        cv2.putText(side_by_side, "Original", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.putText(side_by_side, "Reconstructed", (label_x, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
        # imageio expects RGB frames (unlike cv2.VideoWriter), so no BGR conversion here.
        video_frames.append(side_by_side)

    imageio.mimsave(video_path_out, video_frames, fps=fps)


def main(video_path="data/infinitystar_toy_data/videos/e06b8ca5dbc6.mp4",
         vae_path="/mnt/Meissonic/InfinityStar/infinitystar_videovae.pth",
         output_dir="vae_reconstruction_test",
         num_frames=81,
         tgt_h=384,
         tgt_w=672):
    """Encode, quantize and decode one video with the InfinityStar VAE, then save comparisons.

    Args:
        video_path: Input video (relative default assumes the InfinityStar root as CWD).
        vae_path: VAE checkpoint path.
        output_dir: Directory receiving comparison_grid.png and comparison.mp4.
        num_frames: Number of frames requested from the clip.
        tgt_h: Height frames are resized/cropped to.
        tgt_w: Width frames are resized/cropped to.
    """
    if not os.path.exists(video_path):
        print(f"Video not found: {video_path}")
        print("Please run from InfinityStar root directory")
        return
    if not os.path.exists(vae_path):
        print(f"VAE not found: {vae_path}")
        return

    # ---- Load VAE -----------------------------------------------------------
    print("=" * 80)
    print("Loading VAE using InfinityStar's video_vae_model...")
    print("=" * 80)

    schedule_mode = "dynamic"
    codebook_dim = 18
    print(f"Loading VAE from: {vae_path}")
    print(f" schedule_mode: {schedule_mode}")
    print(f" codebook_dim: {codebook_dim}")
    print(f" videovae: 10 (absorb patchify)")

    args = SimpleArgs()
    args.vae_path = vae_path
    args.vae_type = 18
    args.videovae = 10
    print(f" semantic_scale_dim: {args.semantic_scale_dim}")
    print(f" detail_scale_dim: {args.detail_scale_dim}")
    print(f" use_feat_proj: {args.use_feat_proj}")
    print(f" semantic_scales: {args.semantic_scales}")

    # test_mode=True already puts the model in eval mode with frozen parameters.
    vae = video_vae_model(vae_path, schedule_mode, codebook_dim, global_args=args, test_mode=True)
    vae = vae.float().to('cuda')
    print(f"VAE loaded: {type(vae)}")
    print(f" Device: {next(vae.parameters()).device}")
    print(f" Dtype: {next(vae.parameters()).dtype}")

    # ---- Load video ---------------------------------------------------------
    print("\n" + "=" * 80)
    print("Loading video using InfinityStar's EncodedVideoDecord...")
    print("=" * 80)

    video = EncodedVideoDecord(video_path, os.path.basename(video_path), num_threads=0)
    duration = video._duration
    print(f"Video duration: {duration:.2f} seconds")

    raw_video, _ = video.get_clip(0, 5, num_frames)
    print(f"Loaded {len(raw_video)} frames")
    if len(raw_video) == 0:
        raise RuntimeError(f"No frames decoded from {video_path}")

    frames = [transform(Image.fromarray(frame).convert("RGB"), tgt_h, tgt_w) for frame in raw_video]
    # [T, C, H, W] -> [1, C, T, H, W]
    video_bcthw = torch.stack(frames, 0).permute(1, 0, 2, 3).unsqueeze(0)
    print(f"Video tensor shape: {video_bcthw.shape}")
    print(f"Video tensor range: [{video_bcthw.min():.3f}, {video_bcthw.max():.3f}]")

    # transform() always returns [-1, 1], which is what the VAE expects. No
    # range-guessing here: a bright clip can legitimately lie entirely in [0, 1].
    video_bcthw = video_bcthw.cuda()
    print(f"Video for VAE range: [{video_bcthw.min():.3f}, {video_bcthw.max():.3f}]")

    # ---- Encode -------------------------------------------------------------
    print("\n" + "=" * 80)
    print("Encoding using vae.encode_for_raw_features (InfinityStar's method)...")
    print("=" * 80)
    print("Note: This is a VQ-VAE (Vector Quantized VAE) with quantizer.")
    print(" encode_for_raw_features returns continuous latent (for transformer training).")
    print(" We will use quantizer to get discrete codes (indices).")
    print("=" * 80)

    with torch.no_grad():
        raw_features, _, _ = vae.encode_for_raw_features(video_bcthw, scale_schedule=None, slice=True)
    print(f"Continuous latent shape: {raw_features.shape}")
    print(f"Continuous latent range: [{raw_features.min():.3f}, {raw_features.max():.3f}]")

    # ---- Quantize -----------------------------------------------------------
    quantizer = getattr(vae, 'quantizer', None)
    if quantizer is not None:
        print(f"\nQuantizer detected: {type(quantizer).__name__}")
        print(f"Raw features shape: {raw_features.shape}")
        print(f"Quantizer schedule_mode: {quantizer.schedule_mode}")

        B, C, T, H, W = raw_features.shape
        print(f"Latent resolution: H={H}, W={W}")
        _check_supported_resolution(quantizer, H, W)

        print("Quantizing to get discrete codes (indices)...")
        print(" Note: Fixed tower_split_index bug in quantizer for non-infinity_video_two_pyramid modes.")
        try:
            with torch.no_grad():
                latent_to_decode, _ = _quantize_latent(quantizer, raw_features)
        except Exception as e:
            import traceback
            print(f"\n❌ ERROR: Quantization failed!")
            print(f" Error: {e}")
            print(f" Error type: {type(e).__name__}")
            print(f"\n Full traceback:")
            print(traceback.format_exc())
            raise RuntimeError(f"Quantization failed: {e}. This is required for testing quantizer performance.") from e
        use_quantized = True
    else:
        print(" No quantizer found, using continuous latent (VAE mode, not VQ-VAE).")
        latent_to_decode = raw_features
        use_quantized = False

    # ---- Decode -------------------------------------------------------------
    print("\n" + "=" * 80)
    print("Decoding using vae.decode (InfinityStar's method)...")
    if use_quantized:
        print(" Using quantized latent (VQ-VAE path with discrete codes)")
    else:
        print(" Using continuous latent (VAE path, no quantization)")
    print("=" * 80)

    with torch.no_grad():
        reconstructed = vae.decode(latent_to_decode, slice=True)
        if isinstance(reconstructed, tuple):
            reconstructed = reconstructed[0]
    reconstructed = torch.clamp(reconstructed, min=-1, max=1)
    print(f"Reconstructed shape: {reconstructed.shape}")
    print(f"Reconstructed range: [{reconstructed.min():.3f}, {reconstructed.max():.3f}]")

    # [-1, 1] -> [0, 1] for visualisation.
    original_01 = torch.clamp((video_bcthw + 1.0) / 2.0, 0, 1)
    reconstructed_01 = torch.clamp((reconstructed + 1.0) / 2.0, 0, 1)

    # ---- Save outputs -------------------------------------------------------
    os.makedirs(output_dir, exist_ok=True)

    print("\n" + "=" * 80)
    print("Creating comparison grid (same format as test_cosmos_vqvae.py)...")
    print("=" * 80)
    grid_output_path = os.path.join(output_dir, "comparison_grid.png")
    create_comparison_grid(original_01[0], reconstructed_01[0], grid_output_path, nrow=4)

    print("\nSaving comparison video...")
    video_path_out = os.path.join(output_dir, "comparison.mp4")
    _save_comparison_video(original_01, reconstructed_01, video_path_out, fps=8)
    print(f"Saved video: {video_path_out}")

    print("\n" + "=" * 80)
    print("Test complete!")
    print(f"Results saved to: {output_dir}")
    print(f" - Comparison grid: {grid_output_path}")
    print(f" - Comparison video: {video_path_out}")
    print("=" * 80)


if __name__ == "__main__":
    main()
|
|
|
|
|
|