| |
| """ |
| Test script for Cosmos VQ-VAE performance. |
| |
| This script: |
| 1. Loads a video from the training dataset |
| 2. Encodes it using CosmosVideoTokenizer |
| 3. Decodes it back |
| 4. Computes metrics (PSNR, SSIM, MSE) |
| 5. Creates a side-by-side comparison video |
| 6. Saves the results |
| """ |
|
|
| import argparse |
| import os |
| import sys |
| sys.path.append(os.getcwd()) |
|
|
| import torch |
| import numpy as np |
| from PIL import Image |
| import cv2 |
| from torchvision import transforms |
| from torchvision.utils import make_grid, save_image |
|
|
| from src.pipeline_video import CosmosVideoTokenizer |
| from train.dataset_utils import OpenVid1MDataset, TinyOpenVid1MDataset |
| from transformers import T5Tokenizer |
|
|
|
|
def calculate_psnr(img1, img2, max_val=1.0):
    """Compute the peak signal-to-noise ratio (in dB) between two images.

    Returns ``inf`` when the images are identical (zero error).
    """
    err = ((img1 - img2) ** 2).mean()
    if err == 0:
        return float('inf')
    # PSNR = 20 * log10(MAX / RMSE)
    return (20 * torch.log10(max_val / err.sqrt())).item()
|
|
|
|
def calculate_mse(img1, img2):
    """Compute the mean squared error between two images as a Python float."""
    diff = img1 - img2
    return (diff * diff).mean().item()
|
|
|
|
def calculate_ssim(img1, img2, window_size=11):
    """Compute a simplified, global (non-windowed) SSIM between two images.

    NOTE: ``window_size`` is accepted for API compatibility but is unused —
    statistics are taken over the whole image rather than a sliding window.
    """
    # Standard SSIM stabilization constants for data in [0, 1].
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2

    mean1, mean2 = img1.mean(), img2.mean()
    var1, var2 = img1.var(), img2.var()
    cov = ((img1 - mean1) * (img2 - mean2)).mean()

    numerator = (2 * mean1 * mean2 + c1) * (2 * cov + c2)
    denominator = (mean1 ** 2 + mean2 ** 2 + c1) * (var1 + var2 + c2)
    return (numerator / denominator).item()
|
|
|
|
def video_to_numpy(video_tensor):
    """
    Convert video tensor [C, F, H, W] in [0, 1] to numpy array [F, H, W, C] in [0, 255] (RGB).
    """
    if not isinstance(video_tensor, torch.Tensor):
        # Already array-like: pass through unchanged (no clipping/scaling).
        return np.array(video_tensor)
    # [C, F, H, W] -> [F, H, W, C] in one permute.
    frames = video_tensor.permute(1, 2, 3, 0).cpu().numpy()
    frames = np.clip(frames, 0, 1)
    return (frames * 255).astype(np.uint8)
|
|
|
|
def create_side_by_side_video(original, reconstructed, output_path, fps=8):
    """
    Create a side-by-side comparison video.

    Args:
        original: Original video tensor [C, F, H, W] or numpy array
        reconstructed: Reconstructed video tensor [C, F, H, W] or numpy array
        output_path: Path to save the output video
        fps: Frames per second
    """
    from PIL import Image, ImageDraw, ImageFont

    orig_np = video_to_numpy(original)
    recon_np = video_to_numpy(reconstructed)

    F, H, W, C = orig_np.shape
    F_recon, H_recon, W_recon, C_recon = recon_np.shape

    # Truncate both videos to the shorter frame count.
    F_min = min(F, F_recon)
    orig_np = orig_np[:F_min]
    recon_np = recon_np[:F_min]

    # Match the reconstructed spatial resolution to the original's.
    if H != H_recon or W != W_recon:
        print(f"Resizing reconstructed video from ({H_recon}, {W_recon}) to ({H}, {W})")
        recon_np_resized = np.zeros((F_min, H, W, C), dtype=np.uint8)
        for f in range(F_min):
            # cv2.resize takes the target size as (width, height).
            recon_np_resized[f] = cv2.resize(recon_np[f], (W, H), interpolation=cv2.INTER_LINEAR)
        recon_np = recon_np_resized

    # Load the label font ONCE (the previous version reloaded it every frame).
    # ImageFont.truetype raises OSError when the font file cannot be read.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 32)
    except OSError:
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 32)
        except OSError:
            font = ImageFont.load_default()

    # 8-neighbour offsets used to draw a black outline behind the label text.
    outline_offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]

    def _label(frame_np, text, fill):
        # Draw `text` with a black outline at the top-left corner of the frame.
        pil = Image.fromarray(frame_np)
        draw = ImageDraw.Draw(pil)
        x, y = 20, 20
        for dx, dy in outline_offsets:
            draw.text((x + dx, y + dy), text, font=font, fill=(0, 0, 0))
        draw.text((x, y), text, font=font, fill=fill)
        return np.array(pil)

    side_by_side_frames = []
    for f in range(F_min):
        orig_frame = _label(orig_np[f], "Original", (255, 255, 255))
        recon_frame = _label(recon_np[f], "Reconstructed", (255, 255, 0))
        # Concatenate horizontally: original on the left, reconstruction on the right.
        side_by_side_frames.append(np.concatenate([orig_frame, recon_frame], axis=1))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (W * 2, H))

    # Some OpenCV builds lack mp4v; fall back to XVID.
    if not out.isOpened():
        print(f"Warning: Could not open video writer with mp4v codec, trying XVID...")
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, (W * 2, H))

    try:
        for frame in side_by_side_frames:
            # OpenCV expects BGR channel order for writing.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release the writer, even if a write raises.
        out.release()
    print(f"Saved side-by-side video to: {output_path}")
|
|
|
|
def add_text_to_image(image_tensor, text, position=(10, 30)):
    """
    Add text label to an image tensor.

    Args:
        image_tensor: Image tensor [C, H, W] in [0, 1]
        text: Text to add
        position: (x, y) position for text
    Returns:
        Image tensor with text [C, H, W]
    """
    from PIL import ImageDraw, ImageFont

    # Tensor [C, H, W] in [0, 1] -> uint8 HWC array for PIL.
    image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
    image_np = np.clip(image_np, 0, 1)
    image_np = (image_np * 255).astype(np.uint8)
    pil_image = Image.fromarray(image_np)

    draw = ImageDraw.Draw(pil_image)
    # Try common Linux/macOS font paths; fall back to PIL's built-in font.
    # ImageFont.truetype raises OSError when the font file cannot be read
    # (previously caught with bare `except:`, which also hid unrelated errors).
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 24)
    except OSError:
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 24)
        except OSError:
            font = ImageFont.load_default()

    x, y = position
    # Black outline: draw the text at the 8 neighbouring offsets first...
    for adj in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
        draw.text((x + adj[0], y + adj[1]), text, font=font, fill=(0, 0, 0))
    # ...then the white foreground text on top.
    draw.text((x, y), text, font=font, fill=(255, 255, 255))

    # Back to a float tensor in [0, 1].
    image_tensor = transforms.ToTensor()(pil_image)
    return image_tensor
|
|
|
|
def create_comparison_grid(original, reconstructed, output_path, nrow=4):
    """
    Create a grid image comparing original and reconstructed frames.

    Args:
        original: Original video tensor [C, F, H, W]
        reconstructed: Reconstructed video tensor [C, F, H, W]
        output_path: Path to save the grid image
        nrow: Number of frames per row
    """
    # Only compare frames that exist in both videos.
    num_frames = min(original.shape[1], reconstructed.shape[1])

    # Sample up to 8 evenly spaced frame indices.
    num_shown = min(8, num_frames)
    indices = np.linspace(0, num_frames - 1, num_shown, dtype=int)

    labelled_frames = []
    for idx in indices:
        # Interleave labelled original / reconstructed frame pairs.
        labelled_frames.append(
            add_text_to_image(original[:, idx, :, :].clone(), "Original", position=(10, 10))
        )
        labelled_frames.append(
            add_text_to_image(reconstructed[:, idx, :, :].clone(), "Reconstructed", position=(10, 10))
        )

    # Pairs are interleaved, so nrow * 2 keeps `nrow` pairs per grid row.
    grid = make_grid(torch.stack(labelled_frames, dim=0), nrow=nrow * 2, padding=2, pad_value=1.0)
    save_image(grid, output_path)
    print(f"Saved comparison grid to: {output_path}")
|
|
|
|
def parse_args():
    """Build and parse command-line arguments for the VQ-VAE test script."""
    parser = argparse.ArgumentParser(description="Test Cosmos VQ-VAE performance")

    parser.add_argument("--csv_path", type=str, required=True,
                        help="Path to OpenVid1M CSV file")
    parser.add_argument("--video_root_dir", type=str, default=None,
                        help="Root directory for videos (auto-detected if not provided)")
    parser.add_argument("--video_index", type=int, default=0,
                        help="Index of video to test (default: 0)")
    parser.add_argument("--video_tokenizer_model_id", type=str,
                        default="Cosmos-1.0-Tokenizer-DV8x16x16",
                        help="Cosmos tokenizer model ID")
    parser.add_argument("--num_frames", type=int, default=16,
                        help="Number of frames")
    parser.add_argument("--height", type=int, default=480,
                        help="Video height")
    parser.add_argument("--width", type=int, default=848,
                        help="Video width")
    parser.add_argument("--output_dir", type=str, default="./cosmos_test_output",
                        help="Output directory for results")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use")
    parser.add_argument("--dtype", type=str, default="float32",
                        choices=["float32", "float16", "bfloat16"],
                        help="Data type")

    return parser.parse_args()
|
|
|
|
def main():
    """Run the VQ-VAE round-trip test.

    Loads one video from the dataset, encodes and decodes it with the Cosmos
    tokenizer, computes per-frame PSNR/MSE/SSIM, and writes a metrics report,
    a side-by-side comparison video, and a comparison grid image.
    """
    args = parse_args()

    # Ensure the output directory exists before any artifacts are written.
    os.makedirs(args.output_dir, exist_ok=True)

    # Resolve device and dtype from their CLI string forms.
    device = torch.device(args.device)
    if args.dtype == "float16":
        dtype = torch.float16
    elif args.dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    print(f"Using device: {device}, dtype: {dtype}")

    # Build the tokenizer (encoder/decoder pair) under test.
    print("Initializing CosmosVideoTokenizer...")
    video_tokenizer = CosmosVideoTokenizer(
        model_id=args.video_tokenizer_model_id,
        device=device,
        dtype=dtype
    )
    print(f"Codebook size: {video_tokenizer.codebook_size}")
    print(f"Downsampling factors: t={video_tokenizer.t_downsample}, "
          f"h={video_tokenizer.h_downsample}, w={video_tokenizer.w_downsample}")

    print(f"Loading dataset from: {args.csv_path}")

    # Auto-detect the video root when not given: look for a 'video_reorg'
    # folder next to the CSV, then one level above, else fall back to the
    # CSV's own directory.
    video_root_dir = args.video_root_dir
    if video_root_dir is None:
        csv_dir = os.path.dirname(args.csv_path)
        if os.path.exists(os.path.join(csv_dir, 'video_reorg')):
            video_root_dir = os.path.join(csv_dir, 'video_reorg')
        elif os.path.exists(os.path.join(os.path.dirname(csv_dir), 'video_reorg')):
            video_root_dir = os.path.join(os.path.dirname(csv_dir), 'video_reorg')
        else:
            video_root_dir = csv_dir
            print(f"Warning: Video directory not found, using CSV directory: {video_root_dir}")

    # Text tokenizer required by the dataset constructor even though captions
    # are only printed here, not used for reconstruction.
    tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")

    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        text_encoder_architecture="umt5-base",
    )

    print(f"Dataset size: {len(dataset)}")

    # Bounds-check the requested sample index before indexing the dataset.
    if args.video_index >= len(dataset):
        print(f"Error: video_index {args.video_index} >= dataset size {len(dataset)}")
        return

    print(f"Loading video at index {args.video_index}...")
    sample = dataset[args.video_index]
    # Expected layout [C, F, H, W] — TODO confirm against OpenVid1MDataset.
    original_video = sample["video"]

    # Pull path/caption metadata for the report.
    row = dataset.data[args.video_index]
    video_path = row.get('video', 'unknown')
    caption = row.get('caption', 'no caption')

    print(f"Video path: {video_path}")
    print(f"Caption: {caption}")
    print(f"Original video shape: {original_video.shape}")
    print(f"Original video range: [{original_video.min():.3f}, {original_video.max():.3f}]")

    original_video = original_video.to(device=device, dtype=dtype)

    # Encode: video -> discrete codebook indices (batch dim added/removed).
    print("\nEncoding video...")
    with torch.no_grad():
        codes = video_tokenizer.encode(original_video.unsqueeze(0))

    print(f"Encoded codes shape: {codes.shape}")
    print(f"Codes range: [{codes.min().item()}, {codes.max().item()}]")
    print(f"Codebook size: {video_tokenizer.codebook_size}")

    # Decode: codes -> reconstructed video.
    print("\nDecoding video...")
    with torch.no_grad():
        reconstructed_video = video_tokenizer.decode(codes)
    reconstructed_video = reconstructed_video.squeeze(0)

    print(f"Reconstructed video shape: {reconstructed_video.shape}")
    print(f"Reconstructed video range: [{reconstructed_video.min():.3f}, {reconstructed_video.max():.3f}]")

    # Align frame counts — temporal downsampling may change them.
    F_orig = original_video.shape[1]
    F_recon = reconstructed_video.shape[1]
    F_min = min(F_orig, F_recon)

    original_video = original_video[:, :F_min, :, :]
    reconstructed_video = reconstructed_video[:, :F_min, :, :]

    # Align spatial size by bilinear-resizing each reconstructed frame.
    if original_video.shape[2:] != reconstructed_video.shape[2:]:
        print(f"Resizing reconstructed video from {reconstructed_video.shape[2:]} to {original_video.shape[2:]}")

        reconstructed_video_resized = torch.zeros_like(original_video)
        for f in range(F_min):
            frame = reconstructed_video[:, f, :, :].unsqueeze(0)
            frame_resized = torch.nn.functional.interpolate(
                frame, size=original_video.shape[2:], mode='bilinear', align_corners=False
            )
            reconstructed_video_resized[:, f, :, :] = frame_resized.squeeze(0)
        reconstructed_video = reconstructed_video_resized

    # Compute per-frame metrics in float32 for numerical stability.
    print("\nCalculating metrics...")

    orig_f32 = original_video.to(torch.float32)
    recon_f32 = reconstructed_video.to(torch.float32)

    psnr_values = []
    mse_values = []
    ssim_values = []

    for f in range(F_min):
        orig_frame = orig_f32[:, f, :, :]
        recon_frame = recon_f32[:, f, :, :]

        psnr = calculate_psnr(orig_frame, recon_frame)
        mse = calculate_mse(orig_frame, recon_frame)
        ssim = calculate_ssim(orig_frame, recon_frame)

        psnr_values.append(psnr)
        mse_values.append(mse)
        ssim_values.append(ssim)

    # Average across frames for the summary.
    avg_psnr = np.mean(psnr_values)
    avg_mse = np.mean(mse_values)
    avg_ssim = np.mean(ssim_values)

    print(f"\n=== Metrics ===")
    print(f"PSNR: {avg_psnr:.2f} dB (per frame: {psnr_values})")
    print(f"MSE: {avg_mse:.6f} (per frame: {mse_values})")
    print(f"SSIM: {avg_ssim:.4f} (per frame: {ssim_values})")

    # Persist the metrics to a plain-text report in the output directory.
    metrics_file = os.path.join(args.output_dir, f"metrics_video_{args.video_index}.txt")
    with open(metrics_file, 'w') as f:
        f.write(f"Video Index: {args.video_index}\n")
        f.write(f"Video Path: {video_path}\n")
        f.write(f"Caption: {caption}\n")
        f.write(f"\n=== Metrics ===\n")
        f.write(f"Average PSNR: {avg_psnr:.2f} dB\n")
        f.write(f"Average MSE: {avg_mse:.6f}\n")
        f.write(f"Average SSIM: {avg_ssim:.4f}\n")
        f.write(f"\nPer-frame PSNR: {psnr_values}\n")
        f.write(f"Per-frame MSE: {mse_values}\n")
        f.write(f"Per-frame SSIM: {ssim_values}\n")

    print(f"Saved metrics to: {metrics_file}")

    # Visual artifacts: side-by-side video and grid image.
    print("\nCreating side-by-side comparison video...")
    video_output_path = os.path.join(args.output_dir, f"comparison_video_{args.video_index}.mp4")
    create_side_by_side_video(original_video, reconstructed_video, video_output_path, fps=8)

    print("Creating comparison grid...")
    grid_output_path = os.path.join(args.output_dir, f"comparison_grid_video_{args.video_index}.png")
    create_comparison_grid(original_video, reconstructed_video, grid_output_path, nrow=4)

    print(f"\n=== Test Complete ===")
    print(f"Results saved to: {args.output_dir}")
    print(f"  - Metrics: {metrics_file}")
    print(f"  - Side-by-side video: {video_output_path}")
    print(f"  - Comparison grid: {grid_output_path}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|
|
|