43 / Meissonic /InfinityStar /test_infinity_vqvae.py
BryanW's picture
Upload folder using huggingface_hub
3d1c0e1 verified
#!/usr/bin/env python3
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
"""
Test script for InfinityStar VQ-VAE performance.
This script:
1. Loads a video from the training dataset (same as test_cosmos_vqvae.py)
2. Encodes it using InfinityStar VAE
3. Decodes it back
4. Computes metrics (PSNR, SSIM, MSE) - same as test_cosmos_vqvae.py
5. Creates a side-by-side comparison video
6. Saves the results
"""
import os
import sys
import torch
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
from torchvision.utils import make_grid, save_image
# Put the Meissonic repo (and its ``train`` directory, which holds
# dataset_utils.py) on sys.path. The location defaults to the original mount
# point and can be overridden with the MEISSONIC_PATH environment variable.
meissonic_path = os.environ.get("MEISSONIC_PATH", "/mnt/Meissonic")
if os.path.exists(meissonic_path):
    sys.path.insert(0, meissonic_path)
    meissonic_train_path = os.path.join(meissonic_path, "train")
    if os.path.exists(meissonic_train_path):
        sys.path.insert(0, meissonic_train_path)
# This file's directory (the InfinityStar checkout, presumably where the
# ``infinity`` package lives) is inserted LAST, so it lands at sys.path[0] and
# takes precedence over the Meissonic entries above. dataset_utils.py is loaded
# further down by explicit file path, so this ordering does not affect it.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# InfinityStar's real Args class lives in arg_util, which pulls in the 'tap'
# package (reported to have Python 2 syntax issues). This stand-in carries the
# attribute values this script relies on instead.
class SimpleArgs:
    """Minimal replacement for InfinityStar's Args, free of the tap dependency."""

    def __init__(self):
        defaults = {
            # Quantizer configuration -- has to agree with the checkpoint.
            "semantic_scale_dim": 16,
            "detail_scale_dim": 64,
            "use_learnable_dim_proj": 0,
            "detail_scale_min_tokens": 80,
            # infinitystar_videovae.pth needs use_feat_proj == 2; any other value
            # sends the quantizer down a different feature-projection path and
            # the reconstructions turn out very blurry.
            "use_feat_proj": 2,
            "semantic_scales": 8,
            # VAE selection and checkpoint location.
            "vae_path": "",
            "vae_type": 18,
            "videovae": 10,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)
# Import load_visual_tokenizer directly, avoiding arg_util import
import sys
import importlib.util
# Load the visual tokenizer without importing arg_util (see SimpleArgs).
def load_visual_tokenizer_safe(args, device=None):
    """Build the InfinityStar video VAE without importing ``arg_util``.

    Args:
        args: Namespace-like object providing ``vae_type``, ``videovae`` and
            ``vae_path`` (e.g. ``SimpleArgs``). It is also forwarded to the
            model constructor as ``global_args``.
        device: Target device. When falsy, CUDA is used if available,
            otherwise CPU.

    Returns:
        The VAE module, moved to ``device``.

    Raises:
        ValueError: If ``args.vae_type`` is not a supported codebook dimension,
            or ``args.videovae`` is not a supported video VAE variant.
    """
    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # vae_type doubles as the codebook dimension.
    if args.vae_type not in (8, 12, 14, 16, 18, 20, 24, 32, 48, 64, 128):
        raise ValueError(f"vae_type {args.vae_type} not supported")
    schedule_mode = "dynamic"
    codebook_dim = args.vae_type

    print(f'Load VAE from {args.vae_path}')
    # Only the "absorb patchify" variant (videovae == 10) is wired up here.
    if args.videovae != 10:
        raise ValueError(f"videovae {args.videovae} not supported")

    from infinity.models.videovae.models.load_vae_bsq_wan_absorb_patchify import video_vae_model
    vae_local = video_vae_model(
        args.vae_path, schedule_mode, codebook_dim, global_args=args, test_mode=True
    )
    return vae_local.to(device)
# Pull OpenVid1MDataset out of Meissonic's train/dataset_utils.py by explicit
# file path, so that InfinityStar's own train.py can never shadow it. Any
# failure (missing file, missing dependency, error inside the module) is
# reported and the script falls back to a dummy video.
try:
    import importlib.util
    dataset_utils_path = os.path.join(meissonic_path, "train", "dataset_utils.py")
    if not os.path.exists(dataset_utils_path):
        raise ImportError(f"Could not find dataset_utils.py at {dataset_utils_path}")
    spec = importlib.util.spec_from_file_location("meissonic_dataset_utils", dataset_utils_path)
    dataset_utils = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(dataset_utils)
    OpenVid1MDataset = dataset_utils.OpenVid1MDataset
    from transformers import T5Tokenizer
    DATASET_AVAILABLE = True
    print(f"Loaded dataset utilities from Meissonic: {dataset_utils_path}")
except Exception as exc:
    DATASET_AVAILABLE = False
    print(f"Warning: Could not import dataset utilities: {exc}")
    print("Will use direct video loading.")
def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Tensors are moved to CPU; anything else is wrapped with ``torch.tensor``.
    ``max_val`` is the peak signal value. Identical inputs yield ``inf``.
    """
    def as_cpu(x):
        return x.cpu() if isinstance(x, torch.Tensor) else torch.tensor(x)

    diff = as_cpu(img1) - as_cpu(img2)
    mse = diff.pow(2).mean()
    if mse != 0:
        return (20 * torch.log10(max_val / mse.sqrt())).item()
    return float('inf')
def calculate_mse(img1, img2):
    """Return the mean squared error between two images as a Python float.

    Tensors are moved to CPU; anything else is wrapped with ``torch.tensor``.
    """
    first, second = (
        x.cpu() if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (img1, img2)
    )
    squared_error = (first - second).pow(2)
    return squared_error.mean().item()
def calculate_ssim(img1, img2, window_size=11, max_val=1.0):
    """Calculate a simplified, global SSIM between two images.

    Means, variances and the covariance are computed once over all elements
    (there is no sliding Gaussian window), so ``window_size`` is accepted only
    for backward compatibility and is unused.

    Args:
        img1, img2: Images of the same shape, as tensors or array-likes. They
            are moved to CPU and cast to float32 if not already floating point.
        window_size: Unused; kept so existing callers keep working.
        max_val: Dynamic range of the pixel values (1.0 for [0, 1] images,
            255.0 for uint8). Scales the stabilising constants C1 and C2.

    Returns:
        The SSIM value as a Python float (exactly 1.0 for identical images).
    """
    def _prepare(img):
        t = img.cpu() if isinstance(img, torch.Tensor) else torch.tensor(img)
        # Integer tensors cannot be averaged with .mean().
        return t if t.is_floating_point() else t.float()

    img1 = _prepare(img1)
    img2 = _prepare(img2)
    C1 = (0.01 * max_val) ** 2
    C2 = (0.03 * max_val) ** 2
    mu1 = img1.mean()
    mu2 = img2.mean()
    d1 = img1 - mu1
    d2 = img2 - mu2
    # Variances and covariance must share the same (population, 1/N)
    # normalisation. Tensor.var() defaults to the unbiased 1/(N-1) estimator,
    # which would make SSIM(x, x) come out slightly below 1.
    sigma1_sq = (d1 * d1).mean()
    sigma2_sq = (d2 * d2).mean()
    sigma12 = (d1 * d2).mean()
    ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return ssim.item()
def video_to_numpy(video_tensor):
    """Convert a ``[C, F, H, W]`` video in [0, 1] to a ``[F, H, W, C]`` uint8 array.

    Tensors are detached and cast to float32 on CPU before conversion, so
    bfloat16/float16 or grad-tracking inputs work. Values are clipped to
    [0, 1], scaled to [0, 255] and rounded to the nearest integer. The result
    is C-contiguous, which OpenCV drawing functions require.

    Non-tensor inputs are returned as ``np.array(video_tensor)`` without any
    conversion; callers presumably pass them already as ``[F, H, W, C]`` uint8
    (TODO confirm).
    """
    if not isinstance(video_tensor, torch.Tensor):
        return np.array(video_tensor)
    # .numpy() does not support bfloat16, so upcast first.
    frames = video_tensor.detach().to(device="cpu", dtype=torch.float32)
    frames = frames.permute(1, 2, 3, 0).numpy()  # [C, F, H, W] -> [F, H, W, C]
    frames = np.clip(frames, 0.0, 1.0)
    return np.ascontiguousarray(np.rint(frames * 255.0).astype(np.uint8))
def create_side_by_side_video(original, reconstructed, output_path, fps=8):
    """Write an MP4 showing the original (left) and reconstruction (right).

    Args:
        original: Original video tensor ``[C, F, H, W]`` in [0, 1], or a numpy
            array already laid out as ``[F, H, W, C]`` RGB uint8.
        reconstructed: Reconstructed video in the same format as ``original``.
            It is resized to the original's spatial size if they differ.
        output_path: Path of the MP4 file to write.
        fps: Frames per second of the output video.

    Raises:
        ValueError: If there are no frames to write.
        RuntimeError: If OpenCV cannot open a video writer for ``output_path``.
    """
    orig_np = video_to_numpy(original)
    recon_np = video_to_numpy(reconstructed)

    # Only the frames present in both videos are compared.
    num_frames = min(orig_np.shape[0], recon_np.shape[0])
    if num_frames == 0:
        raise ValueError("No frames to save")
    orig_np = orig_np[:num_frames]
    recon_np = recon_np[:num_frames]

    height, width = orig_np.shape[1:3]
    if recon_np.shape[1:3] != (height, width):
        recon_np = np.stack([cv2.resize(frame, (width, height)) for frame in recon_np])

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Each output frame is the two videos concatenated horizontally.
    writer = cv2.VideoWriter(output_path, fourcc, fps, (2 * width, height))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for: {output_path}")
    try:
        for orig, recon in zip(orig_np, recon_np):
            # Copy so labels are not drawn into the source arrays.
            orig_labeled = orig.copy()
            recon_labeled = recon.copy()
            cv2.putText(orig_labeled, "Original", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(recon_labeled, "Reconstructed", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
            side_by_side = np.concatenate([orig_labeled, recon_labeled], axis=1)
            # Frames are RGB; OpenCV writes BGR.
            writer.write(cv2.cvtColor(side_by_side, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
    print(f"Saved side-by-side video to: {output_path}")
def add_text_to_image(image_tensor, text, position=(10, 30)):
    """Return a copy of an image tensor with ``text`` drawn on it.

    Args:
        image_tensor: Image tensor ``[C, H, W]`` in [0, 1]. The text is drawn
            with RGB fill tuples, so C == 3 is presumably expected (confirm for
            other channel counts). Any dtype/device is accepted: the tensor is
            detached and converted to float32 on CPU first.
        text: Text to draw.
        position: ``(x, y)`` pixel position passed to ``ImageDraw.text``.

    Returns:
        A new float tensor ``[C, H, W]`` in [0, 1], with the text rendered in
        white with a 1-pixel black outline.
    """
    from PIL import ImageDraw, ImageFont

    pixels = image_tensor.detach().to(device="cpu", dtype=torch.float32)
    pixels = pixels.permute(1, 2, 0).numpy()  # [C, H, W] -> [H, W, C]
    pixels = np.rint(np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)
    pil_image = Image.fromarray(np.ascontiguousarray(pixels))
    draw = ImageDraw.Draw(pil_image)

    # Try known TrueType fonts (Linux, then macOS). truetype() raises OSError
    # when the file cannot be opened and ImportError without FreeType support.
    font = None
    for font_path in (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/System/Library/Fonts/Helvetica.ttc",
    ):
        try:
            font = ImageFont.truetype(font_path, 24)
            break
        except (OSError, ImportError):
            continue
    if font is None:
        font = ImageFont.load_default()

    x, y = position
    # Outline: draw the text in black at the 8 neighbouring offsets...
    for dx, dy in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
        draw.text((x + dx, y + dy), text, font=font, fill=(0, 0, 0))
    # ...then the white text on top.
    draw.text((x, y), text, font=font, fill=(255, 255, 255))

    return transforms.ToTensor()(pil_image)
def create_comparison_grid(original, reconstructed, output_path, nrow=4, max_frames=8):
    """Save a PNG grid of labelled original/reconstructed frame pairs.

    Up to ``max_frames`` frames, evenly spaced over the shared length of the
    two videos, are shown. Each frame contributes two tiles: original, then
    reconstruction.

    Args:
        original: Original video tensor ``[C, F, H, W]`` in [0, 1].
        reconstructed: Reconstructed video tensor ``[C, F, H, W]`` in [0, 1].
        output_path: Path of the PNG image to write.
        nrow: Number of frame pairs per grid row (each row holds
            ``2 * nrow`` tiles).
        max_frames: Maximum number of frames to sample (default 8).

    Raises:
        ValueError: If there are no frames or ``max_frames`` < 1.
    """
    if max_frames < 1:
        raise ValueError("max_frames must be >= 1")
    num_frames = min(original.shape[1], reconstructed.shape[1])
    if num_frames == 0:
        raise ValueError("No frames to save")

    frame_indices = np.linspace(0, num_frames - 1, min(max_frames, num_frames), dtype=int)
    tiles = []
    for idx in frame_indices:
        idx = int(idx)
        tiles.append(add_text_to_image(original[:, idx], "Original", position=(10, 10)))
        tiles.append(add_text_to_image(reconstructed[:, idx], "Reconstructed", position=(10, 10)))

    grid = make_grid(torch.stack(tiles, dim=0), nrow=nrow * 2, padding=2, pad_value=1.0)
    save_image(grid, output_path)
    print(f"Saved comparison grid to: {output_path}")
def _guess_video_root_dir(csv_path):
    """Return the directory holding the clips referenced by ``csv_path``.

    Looks for a ``video_reorg`` directory next to the CSV, then one level above
    it, and otherwise falls back to the CSV's own directory with a warning.
    """
    csv_dir = os.path.dirname(csv_path)
    candidates = (
        os.path.join(csv_dir, 'video_reorg'),
        os.path.join(os.path.dirname(csv_dir), 'video_reorg'),
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    print(f"Warning: Video directory not found, using CSV directory: {csv_dir}")
    return csv_dir


def _to_channels_first(video, num_frames):
    """Reorder a 4-D clip to ``[C, T, H, W]`` if it looks like ``[T, C, H, W]`` or ``[T, H, W, C]``."""
    if video.dim() != 4:
        return video
    if video.shape[0] == num_frames and video.shape[1] == 3:
        print("Detected [T, C, H, W] format, converting to [C, T, H, W]")
        return video.permute(1, 0, 2, 3)
    if video.shape[-1] == 3:
        print("Detected [T, H, W, C] format, converting to [C, T, H, W]")
        return video.permute(3, 0, 1, 2)
    return video


def _load_source_video(csv_path, video_root_dir, video_index, num_frames, height, width):
    """Load the clip to test as ``(video [C, T, H, W], video_path, caption)``.

    Uses ``OpenVid1MDataset`` when the Meissonic dataset utilities imported
    successfully. Otherwise it returns a random dummy clip in [0, 1].
    Returns ``None`` if ``video_index`` is out of range for the dataset.
    """
    if not DATASET_AVAILABLE:
        print("Warning: Dataset utilities not available. Using dummy video.")
        return torch.rand(3, num_frames, height, width), "dummy", "dummy video"

    print(f"\nLoading dataset from: {csv_path}")
    if video_root_dir is None:
        video_root_dir = _guess_video_root_dir(csv_path)
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")
    dataset = OpenVid1MDataset(
        csv_path=csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=num_frames,
        height=height,
        width=width,
        text_encoder_architecture="umt5-base",
    )
    print(f"Dataset size: {len(dataset)}")
    if not 0 <= video_index < len(dataset):
        print(f"Error: video_index {video_index} out of range for dataset size {len(dataset)}")
        return None

    print(f"Loading video at index {video_index}...")
    # The VAE expects channel-first clips.
    video = _to_channels_first(dataset[video_index]["video"], num_frames)
    row = dataset.data[video_index]
    video_path = row.get('video', 'unknown')
    caption = row.get('caption', 'no caption')
    print(f"Video path: {video_path}")
    print(f"Caption: {caption}")
    return video, video_path, caption


def _reconstruct(vae, video_for_vae):
    """Encode and decode ``video_for_vae`` ``[1, C, T, H, W]`` in [-1, 1].

    Uses InfinityStar's ``encode_for_raw_features`` / ``decode`` interface and
    returns the decoded batch clamped to [-1, 1].
    """
    print("\n" + "=" * 80)
    print("Encoding using vae.encode_for_raw_features (InfinityStar's method)...")
    print("=" * 80)
    with torch.no_grad():
        raw_features, _, _ = vae.encode_for_raw_features(
            video_for_vae,
            scale_schedule=None,
            slice=True,
        )
    print(f"Encoded latent shape: {raw_features.shape}")
    print(f"Encoded latent range: [{raw_features.min().item():.4f}, {raw_features.max().item():.4f}]")

    print("\n" + "=" * 80)
    print("Decoding using vae.decode (InfinityStar's method)...")
    print("=" * 80)
    with torch.no_grad():
        decoded = vae.decode(raw_features, slice=True)
        if isinstance(decoded, tuple):
            decoded = decoded[0]
        # Clamp to the VAE's [-1, 1] range, as InfinityStar's code does.
        decoded = torch.clamp(decoded, min=-1, max=1)
    print(f"Reconstructed shape: {decoded.shape}")
    print(f"Reconstructed range: [{decoded.min():.3f}, {decoded.max():.3f}]")
    return decoded


def _align_for_comparison(original_01, reconstructed_01):
    """Trim two ``[C, F, H, W]`` videos to a common length and match spatial size.

    Returns ``(original, reconstructed, num_frames)``. The reconstruction is
    bilinearly resized to the original's resolution when they differ.
    """
    f_orig = original_01.shape[1]
    f_recon = reconstructed_01.shape[1]
    num_frames = min(f_orig, f_recon)
    if f_orig != f_recon:
        print(f"Frame count mismatch: original={f_orig}, reconstructed={f_recon}, using first {num_frames} frames for comparison")
        print(" (This is normal for VAE with temporal compression)")
    original_01 = original_01[:, :num_frames]
    reconstructed_01 = reconstructed_01[:, :num_frames]

    if original_01.shape[2:] != reconstructed_01.shape[2:]:
        print(f"Resizing reconstructed video from {reconstructed_01.shape[2:]} to {original_01.shape[2:]}")
        # Treat frames as the batch dimension and resize them all at once.
        frames = reconstructed_01.permute(1, 0, 2, 3)  # [F, C, H, W]
        frames = torch.nn.functional.interpolate(
            frames, size=tuple(original_01.shape[2:]), mode='bilinear', align_corners=False
        )
        reconstructed_01 = frames.permute(1, 0, 2, 3).to(original_01.dtype)
    return original_01, reconstructed_01, num_frames


def _per_frame_metrics(original_01, reconstructed_01):
    """Return per-frame ``(psnr, mse, ssim)`` lists for aligned ``[C, F, H, W]`` videos in [0, 1]."""
    # Move to CPU float32 once, rather than once per frame and metric.
    orig = original_01.detach().to(device="cpu", dtype=torch.float32)
    recon = reconstructed_01.detach().to(device="cpu", dtype=torch.float32)
    psnr_values, mse_values, ssim_values = [], [], []
    for t in range(orig.shape[1]):
        orig_frame, recon_frame = orig[:, t], recon[:, t]
        psnr_values.append(calculate_psnr(orig_frame, recon_frame))
        mse_values.append(calculate_mse(orig_frame, recon_frame))
        ssim_values.append(calculate_ssim(orig_frame, recon_frame))
    return psnr_values, mse_values, ssim_values


def _write_metrics_file(path, video_index, video_path, caption, psnr_values, mse_values, ssim_values):
    """Write averaged and per-frame metrics as UTF-8 text (captions may be non-ASCII)."""
    with open(path, 'w', encoding='utf-8') as fh:
        fh.write(f"Video Index: {video_index}\n")
        fh.write(f"Video Path: {video_path}\n")
        fh.write(f"Caption: {caption}\n")
        fh.write("\n=== Metrics ===\n")
        fh.write(f"Average PSNR: {np.mean(psnr_values):.2f} dB\n")
        fh.write(f"Average MSE: {np.mean(mse_values):.6f}\n")
        fh.write(f"Average SSIM: {np.mean(ssim_values):.4f}\n")
        fh.write(f"\nPer-frame PSNR: {psnr_values}\n")
        fh.write(f"Per-frame MSE: {mse_values}\n")
        fh.write(f"Per-frame SSIM: {ssim_values}\n")


def main():
    """Round-trip one video through the InfinityStar VQ-VAE and report quality.

    All settings are the constants at the top of this function; edit them to
    match your setup. Writes a metrics text file, a side-by-side MP4 and a PNG
    comparison grid into ``OUTPUT_DIR``.
    """
    # Checkpoint / model selection (update VAE_PATH for your setup).
    VAE_PATH = "/mnt/Meissonic/InfinityStar/infinitystar_videovae.pth"
    VAE_TYPE = 18  # codebook_dim
    VIDEOVAE = 10  # absorb patchify
    # Dataset (same layout as test_cosmos_vqvae.py; update CSV_PATH).
    CSV_PATH = "/mnt/VideoGen/dataset/OpenVid1M/video_reorg/OpenVid1M_reorganized.csv"
    VIDEO_ROOT_DIR = None  # auto-detected from CSV_PATH when None
    VIDEO_INDEX = 3  # index of the video to test
    # Clip geometry requested from the dataset.
    NUM_FRAMES = 16
    HEIGHT = 480
    WIDTH = 848
    # Output / runtime.
    OUTPUT_DIR = "./infinity_vqvae_test_output"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): only the input clip is cast to DTYPE; the VAE keeps the
    # dtype it was loaded with, so values other than "float32" may hit dtype
    # mismatches inside the model -- confirm before changing.
    DTYPE = "float32"

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device(DEVICE)
    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}.get(DTYPE, torch.float32)
    print(f"Using device: {device}, dtype: {dtype}")

    # ---- Load VAE ----
    print("=" * 80)
    print("Loading VQ-VAE model...")
    print(f" VAE path: {VAE_PATH}")
    print(f" VAE type: {VAE_TYPE}")
    print(f" Video VAE: {VIDEOVAE}")
    print("=" * 80)
    vae_args = SimpleArgs()
    vae_args.vae_path = VAE_PATH
    vae_args.vae_type = VAE_TYPE
    vae_args.videovae = VIDEOVAE
    vae = load_visual_tokenizer_safe(vae_args, device=device).to(device)
    vae.eval()
    vae.requires_grad_(False)  # inference only, as in the official code
    print("VAE loaded successfully!")
    print(f" Device: {device}")
    print(f" Model dtype: {next(vae.parameters()).dtype}")
    print(f" Model in eval mode: {not vae.training}")

    # ---- Load the clip ----
    loaded = _load_source_video(CSV_PATH, VIDEO_ROOT_DIR, VIDEO_INDEX, NUM_FRAMES, HEIGHT, WIDTH)
    if loaded is None:
        return
    original_video, video_path, caption = loaded
    print(f"Original video shape (C, T, H, W): {original_video.shape}")
    print(f"Original video range (from dataset): [{original_video.min():.3f}, {original_video.max():.3f}]")

    # OpenVid1MDataset.process_video normalizes to [0, 1]; the VAE expects [-1, 1].
    video_for_vae = original_video.to(device=device, dtype=dtype).clamp(0.0, 1.0)
    print("Dataset returns [0, 1], converting to [-1, 1] for VAE")
    video_for_vae = video_for_vae * 2.0 - 1.0
    print(f"Video for VAE range: [{video_for_vae.min():.3f}, {video_for_vae.max():.3f}]")
    video_for_vae = video_for_vae.unsqueeze(0)  # [1, C, T, H, W]

    # ---- Encode / decode ----
    reconstructed_video = _reconstruct(vae, video_for_vae).squeeze(0)  # [C, F, H, W]

    # The VAE works in [-1, 1] and the decoder output was clamped to that range,
    # so always map back with (x + 1) / 2. Guessing the range from min() < 0
    # would misread bright reconstructions whose minimum happens to be >= 0.
    print("Reconstructed video is in [-1, 1], converting to [0, 1]")
    reconstructed_video_01 = torch.clamp((reconstructed_video + 1.0) / 2.0, 0, 1)
    print(f"Reconstructed video [0, 1] range: [{reconstructed_video_01.min():.3f}, {reconstructed_video_01.max():.3f}]")

    # Bring the original to [0, 1] for visualization/metrics; the heuristics
    # cover [-1, 1] and [0, 255] inputs.
    original_video_01 = original_video.to(device=device)
    if original_video_01.min() < 0:
        original_video_01 = (original_video_01 + 1.0) / 2.0
    elif original_video_01.max() > 1.0:
        original_video_01 = original_video_01 / 255.0
    original_video_01 = torch.clamp(original_video_01, 0, 1)
    print(f"Original video [0, 1] range: [{original_video_01.min():.3f}, {original_video_01.max():.3f}]")

    original_video_01, reconstructed_video_01, num_frames = _align_for_comparison(
        original_video_01, reconstructed_video_01
    )
    if num_frames == 0:
        print("Error: no frames available for comparison")
        return

    # ---- Metrics ----
    print("\nCalculating metrics...")
    psnr_values, mse_values, ssim_values = _per_frame_metrics(original_video_01, reconstructed_video_01)
    avg_psnr = np.mean(psnr_values)
    avg_mse = np.mean(mse_values)
    avg_ssim = np.mean(ssim_values)
    print("\n=== Metrics ===")
    print(f"PSNR: {avg_psnr:.2f} dB (per frame: {psnr_values})")
    print(f"MSE: {avg_mse:.6f} (per frame: {mse_values})")
    print(f"SSIM: {avg_ssim:.4f} (per frame: {ssim_values})")

    metrics_file = os.path.join(OUTPUT_DIR, f"metrics_video_{VIDEO_INDEX}.txt")
    _write_metrics_file(metrics_file, VIDEO_INDEX, video_path, caption, psnr_values, mse_values, ssim_values)
    print(f"Saved metrics to: {metrics_file}")

    # ---- Visualizations ----
    print("\nCreating side-by-side comparison video...")
    video_output_path = os.path.join(OUTPUT_DIR, f"comparison_video_{VIDEO_INDEX}.mp4")
    create_side_by_side_video(original_video_01, reconstructed_video_01, video_output_path, fps=8)

    print("Creating comparison grid...")
    grid_output_path = os.path.join(OUTPUT_DIR, f"comparison_grid_video_{VIDEO_INDEX}.png")
    create_comparison_grid(original_video_01, reconstructed_video_01, grid_output_path, nrow=4)

    print("\n=== Test Complete ===")
    print(f"Results saved to: {OUTPUT_DIR}")
    print(f" - Metrics: {metrics_file}")
    print(f" - Side-by-side video: {video_output_path}")
    print(f" - Comparison grid: {grid_output_path}")


if __name__ == "__main__":
    main()