#!/usr/bin/env python3
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT

"""
Test script for InfinityStar VQ-VAE performance.

This script:
1. Loads a video from the training dataset (same as test_cosmos_vqvae.py)
2. Encodes it using InfinityStar VAE
3. Decodes it back
4. Computes metrics (PSNR, SSIM, MSE) - same as test_cosmos_vqvae.py
5. Creates a side-by-side comparison video
6. Saves the results
"""

import importlib.util
import os
import sys

# Third-party imports are kept in their original relative order on purpose.
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import cv2
from torchvision import transforms
from torchvision.utils import make_grid, save_image

# Make Meissonic importable. The location is hard-coded; adjust it for your
# environment (it could instead be derived relative to this file).
meissonic_path = "/mnt/Meissonic"
if os.path.exists(meissonic_path):
    sys.path.insert(0, meissonic_path)
    # Also add Meissonic's train directory to path
    meissonic_train_path = os.path.join(meissonic_path, "train")
    if os.path.exists(meissonic_train_path):
        sys.path.insert(0, meissonic_train_path)

# Add InfinityStar (this file's directory) to the path.
# NOTE: insert(0, ...) places it *ahead* of the Meissonic entries added above.
# Meissonic's dataset_utils is therefore loaded by explicit file path further
# down, so it cannot be shadowed by a same-named InfinityStar module.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Codebook dimensions accepted by load_visual_tokenizer_safe.
_SUPPORTED_VAE_TYPES = (8, 12, 14, 16, 18, 20, 24, 32, 48, 64, 128)

# Fonts tried, in order, when labelling frames; PIL's default is the fallback.
_LABEL_FONT_PATHS = (
    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    "/System/Library/Fonts/Helvetica.ttc",
)


# Avoid importing arg_util which depends on 'tap' package (has Python 2 syntax
# issues). A minimal stand-in carrying only the fields the VAE loader reads.
class SimpleArgs:
    """Simple replacement for Args class to avoid tap dependency."""

    def __init__(self):
        # Quantizer-related fields: MUST match the checkpoint config
        self.semantic_scale_dim = 16
        self.detail_scale_dim = 64
        self.use_learnable_dim_proj = 0
        self.detail_scale_min_tokens = 80
        # IMPORTANT: for infinitystar_videovae.pth this must be 2,
        # otherwise the quantizer takes a different feature projection path
        # and reconstructions become very blurry.
        self.use_feat_proj = 2
        self.semantic_scales = 8

        # VAE-specific attributes
        self.vae_path = ""
        self.vae_type = 18
        self.videovae = 10


def load_visual_tokenizer_safe(args, device=None):
    """Load visual tokenizer without importing arg_util.

    Args:
        args: Object exposing ``vae_path``, ``vae_type`` and ``videovae``
            (e.g. a ``SimpleArgs``); it is also forwarded to the model
            factory as ``global_args``.
        device: Target device. Defaults to CUDA when available, else CPU.

    Returns:
        The VAE model, already moved to ``device``.

    Raises:
        ValueError: If ``args.vae_type`` or ``args.videovae`` is unsupported.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.vae_type not in _SUPPORTED_VAE_TYPES:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    schedule_mode = "dynamic"
    codebook_dim = args.vae_type
    print(f'Load VAE from {args.vae_path}')

    if args.videovae != 10:
        # Only the "absorb patchify" variant is wired up here; report the
        # field that is actually at fault.
        raise ValueError(f"videovae {args.videovae} not supported")

    # Imported lazily so that merely importing this script does not require
    # the infinity package.
    from infinity.models.videovae.models.load_vae_bsq_wan_absorb_patchify import video_vae_model

    vae_local = video_vae_model(
        args.vae_path,
        schedule_mode,
        codebook_dim,
        global_args=args,
        test_mode=True,
    ).to(device)
    return vae_local


# Import dataset utilities from Meissonic by explicit file path, so that
# InfinityStar's own modules (e.g. its train.py) cannot shadow them.
DATASET_AVAILABLE = False
try:
    dataset_utils_path = os.path.join(meissonic_path, "train", "dataset_utils.py")
    if not os.path.exists(dataset_utils_path):
        raise ImportError(f"Could not find dataset_utils.py at {dataset_utils_path}")
    spec = importlib.util.spec_from_file_location("meissonic_dataset_utils", dataset_utils_path)
    dataset_utils = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(dataset_utils)
    OpenVid1MDataset = dataset_utils.OpenVid1MDataset
    from transformers import T5Tokenizer
    DATASET_AVAILABLE = True
    print(f"Loaded dataset utilities from Meissonic: {dataset_utils_path}")
except Exception as e:
    # Deliberately broad: executing a third-party module can fail in arbitrary
    # ways, and the script has a documented fallback (a random dummy video).
    DATASET_AVAILABLE = False
    print(f"Warning: Could not import dataset utilities: {e}")
    print("Will fall back to a random dummy video.")


def _as_float_cpu_tensor(img):
    """Return ``img`` as a detached, floating-point CPU tensor.

    Integer inputs (e.g. uint8 images) are promoted to float32 so that
    differences cannot wrap around and ``torch.mean`` is defined.
    """
    tensor = torch.as_tensor(img).detach().cpu()
    return tensor if tensor.is_floating_point() else tensor.float()


def calculate_psnr(img1, img2, max_val=1.0):
    """Calculate PSNR between two images.

    Args:
        img1: First image (tensor or array-like).
        img2: Second image, same shape as ``img1``.
        max_val: Maximum possible pixel value of the inputs.

    Returns:
        PSNR in dB as a float; ``inf`` when the images are identical.
    """
    img1 = _as_float_cpu_tensor(img1)
    img2 = _as_float_cpu_tensor(img2)
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return (20 * torch.log10(max_val / torch.sqrt(mse))).item()


def calculate_mse(img1, img2):
    """Calculate MSE between two images and return it as a float."""
    img1 = _as_float_cpu_tensor(img1)
    img2 = _as_float_cpu_tensor(img2)
    return torch.mean((img1 - img2) ** 2).item()


def calculate_ssim(img1, img2, window_size=11):
    """Calculate SSIM between two images (simplified version).

    This is a *global* SSIM: means, variances and covariance are taken over
    the whole image instead of over sliding windows. The stabilising
    constants assume a data range of 1.0.

    Args:
        img1: First image (tensor or array-like).
        img2: Second image, same shape as ``img1``.
        window_size: Unused; kept for interface compatibility.

    Returns:
        The global SSIM value as a float.
    """
    img1 = _as_float_cpu_tensor(img1)
    img2 = _as_float_cpu_tensor(img2)

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    mu1 = img1.mean()
    mu2 = img2.mean()
    # Population (biased) variance, matching the biased covariance below.
    # Mixing an unbiased variance with a biased covariance makes identical
    # images score below 1 and yields NaN for single-element inputs.
    sigma1_sq = img1.var(unbiased=False)
    sigma2_sq = img2.var(unbiased=False)
    sigma12 = ((img1 - mu1) * (img2 - mu2)).mean()

    numerator = (2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).item()


def video_to_numpy(video_tensor):
    """
    Convert video tensor [C, F, H, W] in [0, 1] to numpy array [F, H, W, C] in [0, 255] (RGB).

    Non-tensor inputs are returned as ``np.array(video_tensor)`` without any
    conversion; they are assumed to already be [F, H, W, C] uint8 frames.
    """
    if not isinstance(video_tensor, torch.Tensor):
        return np.array(video_tensor)

    # detach()/float() make this safe for tensors that require grad and for
    # dtypes numpy cannot represent (e.g. bfloat16).
    frames = video_tensor.detach().float().permute(1, 2, 3, 0).cpu().numpy()  # [F, H, W, C]
    frames = np.clip(frames, 0, 1)
    return (frames * 255).astype(np.uint8)


def create_side_by_side_video(original, reconstructed, output_path, fps=8):
    """
    Create a side-by-side comparison video.

    Args:
        original: Original video tensor [C, F, H, W] or numpy array
        reconstructed: Reconstructed video tensor [C, F, H, W] or numpy array
        output_path: Path to save the output video
        fps: Frames per second

    Raises:
        ValueError: If there are no frames to write.
        RuntimeError: If the video writer cannot be opened for ``output_path``.
    """
    # Convert to numpy (RGB format: [F, H, W, C])
    orig_np = video_to_numpy(original)
    recon_np = video_to_numpy(reconstructed)

    num_orig, height, width, _ = orig_np.shape
    num_recon, recon_height, recon_width, _ = recon_np.shape

    # Only the frames present in both videos are compared.
    num_frames = min(num_orig, num_recon)
    if num_frames == 0:
        raise ValueError("No frames to save")
    orig_np = orig_np[:num_frames]
    recon_np = recon_np[:num_frames]

    # Bring the reconstruction to the original's spatial size if needed.
    if (height, width) != (recon_height, recon_width):
        recon_np = np.array(
            [cv2.resize(np.ascontiguousarray(frame), (width, height)) for frame in recon_np]
        )

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (2 * width, height))
    if not writer.isOpened():
        # Without this check a missing codec or bad path fails silently and
        # the function would still report success.
        writer.release()
        raise RuntimeError(f"Could not open video writer for: {output_path}")

    try:
        for orig_frame, recon_frame in zip(orig_np, recon_np):
            # putText draws in place, so work on contiguous copies.
            orig_labeled = orig_frame.copy()
            recon_labeled = recon_frame.copy()
            cv2.putText(orig_labeled, "Original", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(recon_labeled, "Reconstructed", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)

            side_by_side = np.concatenate([orig_labeled, recon_labeled], axis=1)
            # Frames are RGB; OpenCV writes BGR.
            writer.write(cv2.cvtColor(side_by_side, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()

    print(f"Saved side-by-side video to: {output_path}")


def _load_label_font(size=24):
    """Return the first loadable font from _LABEL_FONT_PATHS, else PIL's default."""
    for font_path in _LABEL_FONT_PATHS:
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError):
            # OSError: font file missing/unreadable; ImportError: Pillow built
            # without FreeType. Either way, try the next candidate.
            continue
    return ImageFont.load_default()


def add_text_to_image(image_tensor, text, position=(10, 30)):
    """
    Add text label to an image tensor.

    Args:
        image_tensor: Image tensor [C, H, W] in [0, 1]
        text: Text to add
        position: (x, y) position for text

    Returns:
        Image tensor with text [C, H, W], on the CPU and in [0, 1].
    """
    # Convert to a PIL image ([H, W, C], uint8).
    image_np = image_tensor.detach().float().permute(1, 2, 0).cpu().numpy()
    image_np = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)
    pil_image = Image.fromarray(image_np)

    draw = ImageDraw.Draw(pil_image)
    font = _load_label_font(24)

    # White text with a one-pixel black outline, drawn at the 8 neighbours.
    x, y = position
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            if dx == 0 and dy == 0:
                continue
            draw.text((x + dx, y + dy), text, font=font, fill=(0, 0, 0))
    draw.text((x, y), text, font=font, fill=(255, 255, 255))

    return transforms.ToTensor()(pil_image)


def create_comparison_grid(original, reconstructed, output_path, nrow=4):
    """
    Create a grid image comparing original and reconstructed frames.

    Args:
        original: Original video tensor [C, F, H, W]
        reconstructed: Reconstructed video tensor [C, F, H, W]
        output_path: Path to save the grid image
        nrow: Number of frames per row

    Raises:
        ValueError: If either video has no frames.
    """
    num_frames = min(original.shape[1], reconstructed.shape[1])
    if num_frames == 0:
        raise ValueError("No frames to save")

    # Show at most 8 evenly spaced frames (same as test_cosmos_vqvae.py).
    num_frames_to_show = min(8, num_frames)
    frame_indices = np.linspace(0, num_frames - 1, num_frames_to_show, dtype=int).tolist()

    frames_list = []
    for idx in frame_indices:
        frames_list.append(
            add_text_to_image(original[:, idx, :, :], "Original", position=(10, 10))
        )
        frames_list.append(
            add_text_to_image(reconstructed[:, idx, :, :], "Reconstructed", position=(10, 10))
        )

    # nrow * 2 because each frame contributes an original and a reconstruction.
    frames_tensor = torch.stack(frames_list, dim=0)
    grid = make_grid(frames_tensor, nrow=nrow * 2, padding=2, pad_value=1.0)
    save_image(grid, output_path)
    print(f"Saved comparison grid to: {output_path}")


def _detect_video_root(csv_path):
    """Locate a 'video_reorg' directory next to, or one level above, the CSV.

    Falls back to the CSV's own directory (with a warning) when none exists.
    """
    csv_dir = os.path.dirname(csv_path)
    for base_dir in (csv_dir, os.path.dirname(csv_dir)):
        candidate = os.path.join(base_dir, 'video_reorg')
        if os.path.exists(candidate):
            return candidate
    print(f"Warning: Video directory not found, using CSV directory: {csv_dir}")
    return csv_dir


def _load_dataset_video(csv_path, video_root_dir, video_index, num_frames, height, width):
    """Load one video from OpenVid1MDataset as a [C, T, H, W] tensor.

    Returns:
        ``(video, video_path, caption)``, or ``None`` when ``video_index`` is
        past the end of the dataset (an error is printed in that case).
    """
    print(f"\nLoading dataset from: {csv_path}")

    if video_root_dir is None:
        video_root_dir = _detect_video_root(csv_path)

    # The dataset requires a tokenizer even though only the video is used here.
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")
    dataset = OpenVid1MDataset(
        csv_path=csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=num_frames,
        height=height,
        width=width,
        text_encoder_architecture="umt5-base",
    )
    print(f"Dataset size: {len(dataset)}")

    if video_index >= len(dataset):
        print(f"Error: video_index {video_index} >= dataset size {len(dataset)}")
        return None

    print(f"Loading video at index {video_index}...")
    video = dataset[video_index]["video"]

    # Ensure video is [C, T, H, W] format (VAE expects this)
    if video.dim() == 4:
        if video.shape[0] == num_frames and video.shape[1] == 3:
            print("Detected [T, C, H, W] format, converting to [C, T, H, W]")
            video = video.permute(1, 0, 2, 3)
        elif video.shape[-1] == 3:
            print("Detected [T, H, W, C] format, converting to [C, T, H, W]")
            video = video.permute(3, 0, 1, 2)

    row = dataset.data[video_index]
    video_path = row.get('video', 'unknown')
    caption = row.get('caption', 'no caption')
    print(f"Video path: {video_path}")
    print(f"Caption: {caption}")
    return video, video_path, caption


def _compute_frame_metrics(original, reconstructed):
    """Return per-frame (psnr, mse, ssim) lists for two [C, F, H, W] videos in [0, 1]."""
    psnr_values = []
    mse_values = []
    ssim_values = []
    for frame_idx in range(original.shape[1]):
        orig_frame = original[:, frame_idx, :, :]
        recon_frame = reconstructed[:, frame_idx, :, :]
        psnr_values.append(calculate_psnr(orig_frame, recon_frame))
        mse_values.append(calculate_mse(orig_frame, recon_frame))
        ssim_values.append(calculate_ssim(orig_frame, recon_frame))
    return psnr_values, mse_values, ssim_values


def main():
    """Run the encode/decode round trip on one video and save metrics and visuals."""
    # Direct paths (like test_cosmos_vqvae.py)
    # Modify these paths according to your setup
    VAE_PATH = "/mnt/Meissonic/InfinityStar/infinitystar_videovae.pth"  # Update this path
    VAE_TYPE = 18  # codebook_dim
    VIDEOVAE = 10  # absorb patchify

    # Dataset paths (same as test_cosmos_vqvae.py)
    CSV_PATH = "/mnt/VideoGen/dataset/OpenVid1M/video_reorg/OpenVid1M_reorganized.csv"  # Update this path
    VIDEO_ROOT_DIR = None  # Auto-detect if None
    VIDEO_INDEX = 3  # Index of video to test

    # Video parameters (same as test_cosmos_vqvae.py)
    NUM_FRAMES = 16
    HEIGHT = 480
    WIDTH = 848

    # Output
    OUTPUT_DIR = "./infinity_vqvae_test_output"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = "float32"

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = torch.device(DEVICE)
    # Anything other than the two half-precision names means float32.
    # NOTE(review): the model itself is never cast to `dtype`; non-float32
    # settings presumably need a matching model cast or autocast - confirm.
    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}.get(DTYPE, torch.float32)
    print(f"Using device: {device}, dtype: {dtype}")

    # Load VAE
    print("=" * 80)
    print("Loading VQ-VAE model...")
    print(f"  VAE path: {VAE_PATH}")
    print(f"  VAE type: {VAE_TYPE}")
    print(f"  Video VAE: {VIDEOVAE}")
    print("=" * 80)

    vae_args = SimpleArgs()
    vae_args.vae_path = VAE_PATH
    vae_args.vae_type = VAE_TYPE
    vae_args.videovae = VIDEOVAE

    # The loader already moves the model to `device`.
    vae = load_visual_tokenizer_safe(vae_args, device=device)
    vae.eval()
    # Disable gradient computation for all parameters (same as official code)
    for param in vae.parameters():
        param.requires_grad_(False)

    print("VAE loaded successfully!")
    print(f"  Device: {device}")
    print(f"  Model dtype: {next(vae.parameters()).dtype}")
    print(f"  Model in eval mode: {not vae.training}")

    # Load the test video (same as test_cosmos_vqvae.py), or a dummy one.
    if DATASET_AVAILABLE:
        loaded = _load_dataset_video(
            CSV_PATH, VIDEO_ROOT_DIR, VIDEO_INDEX, NUM_FRAMES, HEIGHT, WIDTH
        )
        if loaded is None:
            return
        original_video, video_path, caption = loaded
    else:
        print("Warning: Dataset utilities not available. Using dummy video.")
        original_video = torch.rand(3, NUM_FRAMES, HEIGHT, WIDTH)
        video_path = "dummy"
        caption = "dummy video"

    raw_min = original_video.min().item()
    raw_max = original_video.max().item()
    print(f"Original video shape (C, T, H, W): {original_video.shape}")
    print(f"Original video range (from dataset): [{raw_min:.3f}, {raw_max:.3f}]")
    if raw_min < 0.0 or raw_max > 1.0:
        print("Warning: dataset video is outside [0, 1]; values will be clamped")

    # OpenVid1MDataset.process_video normalizes to [0, 1].
    # The clamped [0, 1] tensor is both the source of the VAE input and the
    # reference for the metrics, so the two can never disagree.
    original_video_01 = original_video.to(device=device).clamp(0.0, 1.0)

    # VAE expects [-1, 1].
    print("Dataset returns [0, 1], converting to [-1, 1] for VAE")
    video_for_vae = original_video_01.to(dtype=dtype) * 2.0 - 1.0
    print(f"Video for VAE range: [{video_for_vae.min():.3f}, {video_for_vae.max():.3f}]")

    # Convert to [B, C, T, H, W] format
    video_for_vae = video_for_vae.unsqueeze(0)  # [1, C, T, H, W]

    # Encode: Use VAE's official interface (same as test_vae_reconstruction_simple.py)
    print("\n" + "=" * 80)
    print("Encoding using vae.encode_for_raw_features (InfinityStar's method)...")
    print("=" * 80)

    with torch.no_grad():
        raw_features, _, _ = vae.encode_for_raw_features(
            video_for_vae,
            scale_schedule=None,
            slice=True
        )

    print(f"Encoded latent shape: {raw_features.shape}")
    print(f"Encoded latent range: [{raw_features.min().item():.4f}, {raw_features.max().item():.4f}]")

    # Decode: Use VAE's official interface (same as test_vae_reconstruction_simple.py)
    print("\n" + "=" * 80)
    print("Decoding using vae.decode (InfinityStar's method)...")
    print("=" * 80)

    with torch.no_grad():
        reconstructed_video_batch = vae.decode(raw_features, slice=True)
        if isinstance(reconstructed_video_batch, tuple):
            reconstructed_video_batch = reconstructed_video_batch[0]
        # Clamp like in InfinityStar's code (same as working script)
        reconstructed_video_batch = torch.clamp(reconstructed_video_batch, min=-1, max=1)

    print(f"Reconstructed shape: {reconstructed_video_batch.shape}")
    print(f"Reconstructed range: [{reconstructed_video_batch.min():.3f}, {reconstructed_video_batch.max():.3f}]")

    # Convert back to [C, F, H, W] format
    reconstructed_video = reconstructed_video_batch.squeeze(0)

    # The VAE was fed [-1, 1] and its output was clamped to [-1, 1], so the
    # mapping to [0, 1] is unconditional. Guessing the range from min() < 0
    # would mis-scale any reconstruction that happens to be all non-negative
    # (e.g. a uniformly bright clip).
    print("Reconstructed video is in [-1, 1], converting to [0, 1]")
    reconstructed_video_01 = ((reconstructed_video.float() + 1.0) / 2.0).clamp(0, 1)
    print(f"Reconstructed video [0, 1] range: [{reconstructed_video_01.min():.3f}, {reconstructed_video_01.max():.3f}]")
    print(f"Original video [0, 1] range: [{original_video_01.min():.3f}, {original_video_01.max():.3f}]")

    # Ensure same number of frames for comparison
    F_orig = original_video_01.shape[1]
    F_recon = reconstructed_video_01.shape[1]
    F_min = min(F_orig, F_recon)

    if F_orig != F_recon:
        print(f"Frame count mismatch: original={F_orig}, reconstructed={F_recon}, using first {F_min} frames for comparison")
        print("  (This is normal for VAE with temporal compression)")

    original_video_01 = original_video_01[:, :F_min, :, :]
    reconstructed_video_01 = reconstructed_video_01[:, :F_min, :, :]

    # Resize if spatial dimensions don't match
    if original_video_01.shape[2:] != reconstructed_video_01.shape[2:]:
        print(f"Resizing reconstructed video from {reconstructed_video_01.shape[2:]} to {original_video_01.shape[2:]}")
        # Treat frames as the batch dimension so all of them are resized in
        # one call: [C, F, H, W] -> [F, C, H, W] -> resize -> [C, F, H, W].
        reconstructed_video_01 = torch.nn.functional.interpolate(
            reconstructed_video_01.permute(1, 0, 2, 3),
            size=original_video_01.shape[2:],
            mode='bilinear',
            align_corners=False,
        ).permute(1, 0, 2, 3)

    # Calculate metrics (same as test_cosmos_vqvae.py)
    print("\nCalculating metrics...")

    # float32 on the CPU, moved once instead of once per frame.
    orig_f32 = original_video_01.to(torch.float32).cpu()
    recon_f32 = reconstructed_video_01.to(torch.float32).cpu()

    psnr_values, mse_values, ssim_values = _compute_frame_metrics(orig_f32, recon_f32)

    avg_psnr = np.mean(psnr_values)
    avg_mse = np.mean(mse_values)
    avg_ssim = np.mean(ssim_values)

    print("\n=== Metrics ===")
    print(f"PSNR: {avg_psnr:.2f} dB (per frame: {psnr_values})")
    print(f"MSE: {avg_mse:.6f} (per frame: {mse_values})")
    print(f"SSIM: {avg_ssim:.4f} (per frame: {ssim_values})")

    # Save metrics to file. UTF-8 is explicit because captions may contain
    # characters the platform's default encoding cannot represent.
    metrics_file = os.path.join(OUTPUT_DIR, f"metrics_video_{VIDEO_INDEX}.txt")
    report_lines = [
        f"Video Index: {VIDEO_INDEX}\n",
        f"Video Path: {video_path}\n",
        f"Caption: {caption}\n",
        "\n=== Metrics ===\n",
        f"Average PSNR: {avg_psnr:.2f} dB\n",
        f"Average MSE: {avg_mse:.6f}\n",
        f"Average SSIM: {avg_ssim:.4f}\n",
        f"\nPer-frame PSNR: {psnr_values}\n",
        f"Per-frame MSE: {mse_values}\n",
        f"Per-frame SSIM: {ssim_values}\n",
    ]
    with open(metrics_file, 'w', encoding='utf-8') as f:
        f.writelines(report_lines)
    print(f"Saved metrics to: {metrics_file}")

    # Create side-by-side video
    print("\nCreating side-by-side comparison video...")
    video_output_path = os.path.join(OUTPUT_DIR, f"comparison_video_{VIDEO_INDEX}.mp4")
    create_side_by_side_video(original_video_01, reconstructed_video_01, video_output_path, fps=8)

    # Create comparison grid
    print("Creating comparison grid...")
    grid_output_path = os.path.join(OUTPUT_DIR, f"comparison_grid_video_{VIDEO_INDEX}.png")
    create_comparison_grid(original_video_01, reconstructed_video_01, grid_output_path, nrow=4)

    print("\n=== Test Complete ===")
    print(f"Results saved to: {OUTPUT_DIR}")
    print(f"  - Metrics: {metrics_file}")
    print(f"  - Side-by-side video: {video_output_path}")
    print(f"  - Comparison grid: {grid_output_path}")


if __name__ == "__main__":
    main()