"""Qualitative reconstruction test for the Wan VAE.

For a handful of videos per dataset, encode/decode through the VAE and save a
side-by-side PNG grid (ground truth / reconstruction / absolute difference).
"""
import os
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image

# Add project root to sys.path so the local `wm` package resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper


def load_video(video_path, num_frames=13):
    """Load `num_frames` evenly spaced RGB frames from a video file.

    Args:
        video_path: path to a video readable by OpenCV.
        num_frames: number of frames to sample (evenly over the whole clip).

    Returns:
        (T, H, W, C) uint8 RGB array with T == num_frames, or None if the
        video is empty/unreadable. If decoding stops early, the last decoded
        frame is repeated as padding.
    """
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        cap.release()  # fix: original leaked the capture on this early return
        return None

    # Evenly spaced sample positions; a set gives O(1) membership per frame
    # (the original tested membership against a NumPy array, O(num_frames)).
    wanted = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())

    frames = []
    curr_idx = 0
    while len(frames) < num_frames:
        ret, frame = cap.read()
        if not ret:
            break
        if curr_idx in wanted:
            # OpenCV decodes BGR; convert to RGB for PIL/tensor use.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        curr_idx += 1
    cap.release()

    if not frames:
        return None
    # Pad by repeating the last frame if the stream ended short.
    while len(frames) < num_frames:
        frames.append(frames[-1])
    return np.stack(frames)  # (T, H, W, C)


def save_comparison(gt_frames, recon_frames, save_path, diff_scale=1.0):
    """Save a 3-row comparison grid: GT / reconstruction / scaled abs-diff.

    Args:
        gt_frames: (T, H, W, C) uint8 array.
        recon_frames: (T, H, W, C) uint8 array; must match gt_frames' shape.
        save_path: output image path (str or Path).
        diff_scale: multiplier applied to the absolute difference before
            display. Default 1.0 preserves prior behavior — NOTE(review): the
            original comment claimed the diff was "multiplied by 5" but the
            code multiplied by 1.0; pass diff_scale=5.0 for the stated intent.
    """
    T, H, W, C = gt_frames.shape
    diff = np.abs(gt_frames.astype(np.float32) - recon_frames.astype(np.float32))
    diff_vis = (diff * diff_scale).clip(0, 255).astype(np.uint8)

    # Grid layout: 3 rows (GT, Recon, Diff), T columns.
    combined = np.zeros((H * 3, W * T, C), dtype=np.uint8)
    for t in range(T):
        combined[:H, t * W:(t + 1) * W] = gt_frames[t]
        combined[H:H * 2, t * W:(t + 1) * W] = recon_frames[t]
        combined[H * 2:, t * W:(t + 1) * W] = diff_vis[t]
    Image.fromarray(combined).save(save_path)


def main():
    """Run VAE reconstruction on a few samples per dataset and save grids."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Initialize VAE
    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}")
        return

    print("Loading VAE model...")
    # NOTE(review): assumes WanVAEWrapper handles eval mode internally; if not,
    # a vae.eval() call here would be appropriate for inference — confirm.
    vae = WanVAEWrapper(pretrained_path=ckpt_path).to(device)

    datasets = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/",
    }
    output_dir = Path("/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_wan_vae")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Wan VAE requires T = 1 + 4k frames. We'll use 13 frames for reconstruction.
    T_vae = 13
    # Final visualization will sample 8 frames as requested.
    T_vis = 8

    for ds_name, ds_root in datasets.items():
        print(f"\n--- Processing {ds_name} ---")
        meta_path = Path(ds_root) / "metadata_lite.pt"
        if not meta_path.exists():
            print(f"Metadata not found for {ds_name}")
            continue

        # SECURITY: weights_only=False unpickles arbitrary objects — only safe
        # because this metadata file is produced by our own pipeline.
        metadata = torch.load(meta_path, map_location='cpu', weights_only=False)

        # Take 4 examples
        count = 0
        for i in range(len(metadata)):
            if count >= 4:
                break
            video_rel_path = metadata[i]['video_path']
            # Handle both absolute and relative paths
            if video_rel_path.startswith('/'):
                video_path = video_rel_path
            else:
                video_path = str(Path(ds_root) / video_rel_path)
            if not os.path.exists(video_path):
                continue

            print(f"Example {count}: {video_path}")
            raw_frames = load_video(video_path, num_frames=T_vae)
            if raw_frames is None:
                continue

            # Prepare tensor for VAE: (B, T, C, H, W) normalized to [-1, 1].
            video_tensor = torch.from_numpy(raw_frames).permute(0, 3, 1, 2).float() / 255.0
            video_tensor = (video_tensor * 2.0 - 1.0).unsqueeze(0).to(device)

            with torch.no_grad():
                # Encode and Decode
                latent = vae.encode(video_tensor)
                print(f"  Input shape (B, T, C, H, W): {video_tensor.shape}")
                print(f"  Latent shape (B, T', C', H', W'): {latent.shape}")
                recon = vae.decode_to_pixel(latent)  # [B, T, C, H, W] in [-1, 1]

            # Convert reconstruction back to uint8 RGB (T, H, W, C).
            # NOTE(review): assumes the decoded spatial size matches the input
            # (save_comparison requires identical shapes) — confirm for inputs
            # whose H/W are not multiples of the VAE's downsampling factor.
            recon_np = ((recon[0].cpu().permute(0, 2, 3, 1).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
            gt_np = raw_frames  # (T, H, W, C) uint8 RGB

            # Sample 8 frames for visualization as requested
            vis_indices = np.linspace(0, T_vae - 1, T_vis, dtype=int)
            vis_gt = gt_np[vis_indices]
            vis_recon = recon_np[vis_indices]

            save_path = output_dir / f"{ds_name}_sample_{count}.png"
            save_comparison(vis_gt, vis_recon, save_path)
            print(f"  Result saved to {save_path}")
            count += 1

    print("\nAll tests completed.")


if __name__ == "__main__":
    main()