# Source: Meissonic / train / test_cosmos_vqvae.py
# Uploaded by BryanW using huggingface_hub (commit 3d1c0e1, verified)
#!/usr/bin/env python3
"""
Test script for Cosmos VQ-VAE performance.
This script:
1. Loads a video from the training dataset
2. Encodes it using CosmosVideoTokenizer
3. Decodes it back
4. Computes metrics (PSNR, SSIM, MSE)
5. Creates a side-by-side comparison video
6. Saves the results
"""
import argparse
import os
import sys
sys.path.append(os.getcwd())
import torch
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from src.pipeline_video import CosmosVideoTokenizer
from train.dataset_utils import OpenVid1MDataset, TinyOpenVid1MDataset
from transformers import T5Tokenizer
def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio (dB) between two tensors.

    Both inputs are expected to lie in [0, max_val]. Identical inputs
    yield ``inf``.
    """
    squared_error = (img1 - img2) ** 2
    mse = squared_error.mean()
    if mse == 0:
        # Perfect reconstruction: PSNR is unbounded.
        return float('inf')
    ratio = max_val / torch.sqrt(mse)
    return (20 * torch.log10(ratio)).item()
def calculate_mse(img1, img2):
    """Return the mean squared error between two tensors as a Python float."""
    diff = img1 - img2
    return (diff * diff).mean().item()
def calculate_ssim(img1, img2, window_size=11):
    """Return a global (single-window) SSIM approximation between two tensors.

    Note: ``window_size`` is accepted for API compatibility but is not used;
    statistics are computed over the whole tensor rather than per local window.
    """
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    mean1, mean2 = img1.mean(), img2.mean()
    var1, var2 = img1.var(), img2.var()
    covariance = ((img1 - mean1) * (img2 - mean2)).mean()
    numerator = (2 * mean1 * mean2 + c1) * (2 * covariance + c2)
    denominator = (mean1 ** 2 + mean2 ** 2 + c1) * (var1 + var2 + c2)
    return (numerator / denominator).item()
def video_to_numpy(video_tensor):
    """
    Convert a video tensor [C, F, H, W] in [0, 1] to a uint8 numpy array
    [F, H, W, C] in [0, 255] (RGB). Non-tensor input is passed through
    ``np.array`` unchanged.
    """
    if not isinstance(video_tensor, torch.Tensor):
        return np.array(video_tensor)
    # [C, F, H, W] -> [F, H, W, C] in a single permute.
    frames = video_tensor.permute(1, 2, 3, 0).cpu().numpy()
    frames = np.clip(frames, 0.0, 1.0)
    return (frames * 255).astype(np.uint8)
def _load_label_font(size):
    """Best-effort TrueType font lookup (Linux DejaVu, then macOS Helvetica),
    falling back to PIL's built-in bitmap font."""
    from PIL import ImageFont
    for font_path in (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",  # Linux
        "/System/Library/Fonts/Helvetica.ttc",                   # macOS
    ):
        try:
            return ImageFont.truetype(font_path, size)
        except OSError:
            # Font file missing/unreadable on this platform; try the next one.
            continue
    return ImageFont.load_default()


def _draw_outlined_text(draw, text, font, fill, position=(20, 20)):
    """Draw `text` with a 1px black outline so it stays readable on any background."""
    x, y = position
    for dx, dy in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
        draw.text((x + dx, y + dy), text, font=font, fill=(0, 0, 0))
    draw.text((x, y), text, font=font, fill=fill)


def create_side_by_side_video(original, reconstructed, output_path, fps=8):
    """
    Create a side-by-side comparison video (original left, reconstruction right).

    Args:
        original: Original video tensor [C, F, H, W] or numpy array
        reconstructed: Reconstructed video tensor [C, F, H, W] or numpy array
        output_path: Path to save the output video
        fps: Frames per second
    """
    from PIL import Image, ImageDraw

    # Convert to numpy (RGB format: [F, H, W, C])
    orig_np = video_to_numpy(original)
    recon_np = video_to_numpy(reconstructed)

    F, H, W, C = orig_np.shape
    F_recon, H_recon, W_recon, _ = recon_np.shape

    # Truncate both clips to the shorter length so frames stay paired.
    F_min = min(F, F_recon)
    orig_np = orig_np[:F_min]
    recon_np = recon_np[:F_min]

    # Resize the reconstruction if its spatial size differs from the original.
    if H != H_recon or W != W_recon:
        print(f"Resizing reconstructed video from ({H_recon}, {W_recon}) to ({H}, {W})")
        recon_np_resized = np.zeros((F_min, H, W, C), dtype=np.uint8)
        for f in range(F_min):
            # cv2.resize expects (width, height) for the size parameter.
            recon_np_resized[f] = cv2.resize(recon_np[f], (W, H), interpolation=cv2.INTER_LINEAR)
        recon_np = recon_np_resized

    # Fix: load the label font ONCE. The original re-opened the font file for
    # every frame inside the loop and used bare `except:` clauses, which also
    # swallow KeyboardInterrupt/SystemExit.
    font = _load_label_font(32)

    side_by_side_frames = []
    for f in range(F_min):
        # Original frame, labelled in white.
        orig_frame_pil = Image.fromarray(orig_np[f])
        _draw_outlined_text(ImageDraw.Draw(orig_frame_pil), "Original", font, (255, 255, 255))
        orig_frame = np.array(orig_frame_pil)

        # Reconstructed frame, labelled in yellow.
        recon_frame_pil = Image.fromarray(recon_np[f])
        _draw_outlined_text(ImageDraw.Draw(recon_frame_pil), "Reconstructed", font, (255, 255, 0))
        recon_frame = np.array(recon_frame_pil)

        # Concatenate horizontally: original on the left.
        side_by_side_frames.append(np.concatenate([orig_frame, recon_frame], axis=1))

    # Write video using OpenCV (needs BGR format); fall back to XVID if the
    # mp4v codec is unavailable on this build of OpenCV.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (W * 2, H))
    if not out.isOpened():
        print(f"Warning: Could not open video writer with mp4v codec, trying XVID...")
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, (W * 2, H))
    try:
        for frame in side_by_side_frames:
            # Convert RGB to BGR for OpenCV
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Ensure the container is finalized even if a write fails.
        out.release()
    print(f"Saved side-by-side video to: {output_path}")
def add_text_to_image(image_tensor, text, position=(10, 30)):
    """
    Add a white text label with a black outline to an image tensor.

    Args:
        image_tensor: Image tensor [C, H, W] in [0, 1]
        text: Text to add
        position: (x, y) position for text

    Returns:
        Image tensor with text [C, H, W], in [0, 1]
    """
    from PIL import ImageDraw, ImageFont

    # Tensor [C, H, W] in [0, 1] -> uint8 numpy [H, W, C] -> PIL image.
    image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
    image_np = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)
    pil_image = Image.fromarray(image_np)

    draw = ImageDraw.Draw(pil_image)
    # Try common system font locations. Fix: catch OSError instead of the
    # original bare `except:` clauses, which also swallowed
    # KeyboardInterrupt/SystemExit.
    font = None
    for font_path in (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",  # Linux
        "/System/Library/Fonts/Helvetica.ttc",                   # macOS
    ):
        try:
            font = ImageFont.truetype(font_path, 24)
            break
        except OSError:
            continue
    if font is None:
        font = ImageFont.load_default()

    # Draw a 1px black outline first, then the white text on top.
    x, y = position
    for adj in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
        draw.text((x + adj[0], y + adj[1]), text, font=font, fill=(0, 0, 0))
    draw.text((x, y), text, font=font, fill=(255, 255, 255))

    # Convert back to a [C, H, W] float tensor in [0, 1].
    return transforms.ToTensor()(pil_image)
def create_comparison_grid(original, reconstructed, output_path, nrow=4):
    """
    Save a grid image of paired original / reconstructed frames.

    Args:
        original: Original video tensor [C, F, H, W]
        reconstructed: Reconstructed video tensor [C, F, H, W]
        output_path: Path to save the grid image
        nrow: Number of frame pairs per row
    """
    # Sample up to 8 frames evenly across the overlapping frame range.
    num_frames = min(original.shape[1], reconstructed.shape[1])
    num_shown = min(8, num_frames)
    frame_indices = np.linspace(0, num_frames - 1, num_shown, dtype=int)

    labelled_frames = []
    for idx in frame_indices:
        # Interleave labelled original / reconstructed frame pairs.
        labelled_frames.append(
            add_text_to_image(original[:, idx, :, :].clone(), "Original", position=(10, 10))
        )
        labelled_frames.append(
            add_text_to_image(reconstructed[:, idx, :, :].clone(), "Reconstructed", position=(10, 10))
        )

    # nrow * 2 columns because each pair occupies two grid cells.
    stacked = torch.stack(labelled_frames, dim=0)
    grid = make_grid(stacked, nrow=nrow * 2, padding=2, pad_value=1.0)
    save_image(grid, output_path)
    print(f"Saved comparison grid to: {output_path}")
def parse_args(argv=None):
    """Parse command-line arguments for the Cosmos VQ-VAE test.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads ``sys.argv[1:]`` — identical to the previous
            behavior. Passing an explicit list makes the parser testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Test Cosmos VQ-VAE performance")
    parser.add_argument(
        "--csv_path",
        type=str,
        required=True,
        help="Path to OpenVid1M CSV file"
    )
    parser.add_argument(
        "--video_root_dir",
        type=str,
        default=None,
        help="Root directory for videos (auto-detected if not provided)"
    )
    parser.add_argument(
        "--video_index",
        type=int,
        default=0,
        help="Index of video to test (default: 0)"
    )
    parser.add_argument(
        "--video_tokenizer_model_id",
        type=str,
        default="Cosmos-1.0-Tokenizer-DV8x16x16",
        help="Cosmos tokenizer model ID"
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=16,
        help="Number of frames"
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Video height"
    )
    parser.add_argument(
        "--width",
        type=int,
        default=848,
        help="Video width"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./cosmos_test_output",
        help="Output directory for results"
    )
    parser.add_argument(
        "--device",
        type=str,
        # Default is resolved at parse time based on CUDA availability.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "bfloat16"],
        help="Data type"
    )
    return parser.parse_args(argv)
def main():
    """Round-trip test of the Cosmos VQ-VAE video tokenizer.

    Loads one video from the OpenVid1M dataset, encodes it to discrete
    codes, decodes it back, computes PSNR/MSE/SSIM against the original,
    and writes a metrics file plus side-by-side video and grid comparisons
    into ``args.output_dir``.
    """
    args = parse_args()
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    # Set device and dtype
    device = torch.device(args.device)
    if args.dtype == "float16":
        dtype = torch.float16
    elif args.dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    print(f"Using device: {device}, dtype: {dtype}")
    # Initialize tokenizer (project-local CosmosVideoTokenizer wrapper).
    print("Initializing CosmosVideoTokenizer...")
    video_tokenizer = CosmosVideoTokenizer(
        model_id=args.video_tokenizer_model_id,
        device=device,
        dtype=dtype
    )
    print(f"Codebook size: {video_tokenizer.codebook_size}")
    print(f"Downsampling factors: t={video_tokenizer.t_downsample}, "
          f"h={video_tokenizer.h_downsample}, w={video_tokenizer.w_downsample}")
    # Load dataset
    print(f"Loading dataset from: {args.csv_path}")
    # Auto-detect video_root_dir if not provided: look for a 'video_reorg'
    # directory next to the CSV, then one level up, else fall back to the
    # CSV's own directory with a warning.
    video_root_dir = args.video_root_dir
    if video_root_dir is None:
        csv_dir = os.path.dirname(args.csv_path)
        if os.path.exists(os.path.join(csv_dir, 'video_reorg')):
            video_root_dir = os.path.join(csv_dir, 'video_reorg')
        elif os.path.exists(os.path.join(os.path.dirname(csv_dir), 'video_reorg')):
            video_root_dir = os.path.join(os.path.dirname(csv_dir), 'video_reorg')
        else:
            video_root_dir = csv_dir
            print(f"Warning: Video directory not found, using CSV directory: {video_root_dir}")
    # Initialize tokenizer for dataset (needed for OpenVid1MDataset);
    # the T5 text tokenizer is only required by the dataset's constructor,
    # not by this reconstruction test itself.
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")
    # Create dataset
    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        text_encoder_architecture="umt5-base",
    )
    print(f"Dataset size: {len(dataset)}")
    # Load video, guarding against an out-of-range index.
    if args.video_index >= len(dataset):
        print(f"Error: video_index {args.video_index} >= dataset size {len(dataset)}")
        return
    print(f"Loading video at index {args.video_index}...")
    sample = dataset[args.video_index]
    original_video = sample["video"]  # [C, F, H, W]
    # Get video info from dataset.
    # NOTE(review): assumes dataset.data[i] is a dict-like row with 'video'
    # and 'caption' keys — confirm against OpenVid1MDataset's implementation.
    row = dataset.data[args.video_index]
    video_path = row.get('video', 'unknown')
    caption = row.get('caption', 'no caption')
    print(f"Video path: {video_path}")
    print(f"Caption: {caption}")
    print(f"Original video shape: {original_video.shape}")
    print(f"Original video range: [{original_video.min():.3f}, {original_video.max():.3f}]")
    # Move to device
    original_video = original_video.to(device=device, dtype=dtype)
    # Encode: tokenizer expects a batch dimension, returns discrete codes.
    print("\nEncoding video...")
    with torch.no_grad():
        codes = video_tokenizer.encode(original_video.unsqueeze(0))  # [1, F', H', W']
    print(f"Encoded codes shape: {codes.shape}")
    print(f"Codes range: [{codes.min().item()}, {codes.max().item()}]")
    print(f"Codebook size: {video_tokenizer.codebook_size}")
    # Decode the codes back into pixel space.
    print("\nDecoding video...")
    with torch.no_grad():
        reconstructed_video = video_tokenizer.decode(codes)  # [1, C, F, H, W]
        reconstructed_video = reconstructed_video.squeeze(0)  # [C, F, H, W]
    print(f"Reconstructed video shape: {reconstructed_video.shape}")
    print(f"Reconstructed video range: [{reconstructed_video.min():.3f}, {reconstructed_video.max():.3f}]")
    # Ensure same number of frames for comparison (the tokenizer's temporal
    # downsampling/upsampling may change the frame count).
    F_orig = original_video.shape[1]
    F_recon = reconstructed_video.shape[1]
    F_min = min(F_orig, F_recon)
    original_video = original_video[:, :F_min, :, :]
    reconstructed_video = reconstructed_video[:, :F_min, :, :]
    # Resize if spatial dimensions don't match (frame-by-frame bilinear).
    if original_video.shape[2:] != reconstructed_video.shape[2:]:
        print(f"Resizing reconstructed video from {reconstructed_video.shape[2:]} to {original_video.shape[2:]}")
        # Use interpolation to resize
        reconstructed_video_resized = torch.zeros_like(original_video)
        for f in range(F_min):
            frame = reconstructed_video[:, f, :, :].unsqueeze(0)  # [1, C, H, W]
            frame_resized = torch.nn.functional.interpolate(
                frame, size=original_video.shape[2:], mode='bilinear', align_corners=False
            )
            reconstructed_video_resized[:, f, :, :] = frame_resized.squeeze(0)
        reconstructed_video = reconstructed_video_resized
    # Calculate metrics
    print("\nCalculating metrics...")
    # Convert to float32 for metric calculation (avoids fp16/bf16 precision loss).
    orig_f32 = original_video.to(torch.float32)
    recon_f32 = reconstructed_video.to(torch.float32)
    # Frame-wise metrics
    psnr_values = []
    mse_values = []
    ssim_values = []
    for f in range(F_min):
        orig_frame = orig_f32[:, f, :, :]  # [C, H, W]
        recon_frame = recon_f32[:, f, :, :]  # [C, H, W]
        psnr = calculate_psnr(orig_frame, recon_frame)
        mse = calculate_mse(orig_frame, recon_frame)
        ssim = calculate_ssim(orig_frame, recon_frame)
        psnr_values.append(psnr)
        mse_values.append(mse)
        ssim_values.append(ssim)
    # Overall metrics: simple mean over frames.
    avg_psnr = np.mean(psnr_values)
    avg_mse = np.mean(mse_values)
    avg_ssim = np.mean(ssim_values)
    print(f"\n=== Metrics ===")
    print(f"PSNR: {avg_psnr:.2f} dB (per frame: {psnr_values})")
    print(f"MSE: {avg_mse:.6f} (per frame: {mse_values})")
    print(f"SSIM: {avg_ssim:.4f} (per frame: {ssim_values})")
    # Save metrics to file
    metrics_file = os.path.join(args.output_dir, f"metrics_video_{args.video_index}.txt")
    with open(metrics_file, 'w') as f:
        f.write(f"Video Index: {args.video_index}\n")
        f.write(f"Video Path: {video_path}\n")
        f.write(f"Caption: {caption}\n")
        f.write(f"\n=== Metrics ===\n")
        f.write(f"Average PSNR: {avg_psnr:.2f} dB\n")
        f.write(f"Average MSE: {avg_mse:.6f}\n")
        f.write(f"Average SSIM: {avg_ssim:.4f}\n")
        f.write(f"\nPer-frame PSNR: {psnr_values}\n")
        f.write(f"Per-frame MSE: {mse_values}\n")
        f.write(f"Per-frame SSIM: {ssim_values}\n")
    print(f"Saved metrics to: {metrics_file}")
    # Create side-by-side video
    print("\nCreating side-by-side comparison video...")
    video_output_path = os.path.join(args.output_dir, f"comparison_video_{args.video_index}.mp4")
    create_side_by_side_video(original_video, reconstructed_video, video_output_path, fps=8)
    # Create comparison grid
    print("Creating comparison grid...")
    grid_output_path = os.path.join(args.output_dir, f"comparison_grid_video_{args.video_index}.png")
    create_comparison_grid(original_video, reconstructed_video, grid_output_path, nrow=4)
    print(f"\n=== Test Complete ===")
    print(f"Results saved to: {args.output_dir}")
    print(f"  - Metrics: {metrics_file}")
    print(f"  - Side-by-side video: {video_output_path}")
    print(f"  - Comparison grid: {grid_output_path}")
# Script entry point: run the VQ-VAE round-trip test when executed directly.
if __name__ == "__main__":
    main()