# world_model/wm/test/test_wan_vae.py
# Uploaded via huggingface_hub (user t1an, commit f17ae24).
import os
import sys
import torch
import cv2
import numpy as np
from PIL import Image
from pathlib import Path
# Add project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper
def load_video(video_path, num_frames=13):
    """Load ``num_frames`` RGB frames sampled at equal intervals from a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to sample. Default 13 (the Wan VAE
            requires T = 1 + 4k frames).

    Returns:
        (T, H, W, C) uint8 RGB array with T == num_frames, or None if the
        video cannot be read or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None
        # Equal-interval sample positions. A set gives O(1) membership per
        # decoded frame (the original tested against a NumPy array, O(k) each).
        # Duplicates collapse when the video is shorter than num_frames; the
        # padding loop below restores the requested length.
        indices = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())
        curr_idx = 0
        # Decode sequentially rather than seeking: CAP_PROP_POS_FRAMES seeks
        # are unreliable for some codecs.
        while len(frames) < len(indices):
            ret, frame = cap.read()
            if not ret:
                break
            if curr_idx in indices:
                # OpenCV decodes BGR; convert to RGB for the model/visuals.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            curr_idx += 1
    finally:
        # Always release the capture handle — the original leaked it on the
        # total_frames <= 0 early return.
        cap.release()
    if not frames:
        return None
    # Pad by repeating the last frame if the decoder stopped early.
    while len(frames) < num_frames:
        frames.append(frames[-1])
    return np.stack(frames)  # (T, H, W, C)
def save_comparison(gt_frames, recon_frames, save_path, diff_scale=1.0):
    """Save a 3-row comparison grid image: ground truth, reconstruction, diff.

    The grid has 3 rows (GT on top, reconstruction in the middle, absolute
    difference at the bottom) and T columns, one per frame.

    Args:
        gt_frames: (T, H, W, C) uint8 ground-truth frames.
        recon_frames: (T, H, W, C) uint8 reconstructed frames.
        save_path: Output image path (str or Path), any format PIL infers.
        diff_scale: Multiplier applied to the absolute difference before
            display. Default 1.0 preserves the previous behavior (the old
            comments claimed "x5" but the code never applied it); pass e.g.
            5.0 to amplify small reconstruction errors.
    """
    T, H, W, C = gt_frames.shape
    # Per-pixel absolute error in float to avoid uint8 wraparound, optionally
    # amplified for visibility, then clipped back into displayable range.
    diff = np.abs(gt_frames.astype(np.float32) - recon_frames.astype(np.float32))
    diff_vis = (diff * diff_scale).clip(0, 255).astype(np.uint8)
    # Assemble the 3xT grid.
    combined = np.zeros((H * 3, W * T, C), dtype=np.uint8)
    for t in range(T):
        combined[:H, t*W:(t+1)*W] = gt_frames[t]
        combined[H:H*2, t*W:(t+1)*W] = recon_frames[t]
        combined[H*2:, t*W:(t+1)*W] = diff_vis[t]
    Image.fromarray(combined).save(save_path)
def main():
    """Round-trip sample videos through the Wan VAE and save comparison grids.

    For each configured dataset, loads metadata, picks up to 4 videos,
    encodes/decodes 13 frames through the VAE, and writes a GT/recon/diff
    grid of 8 sampled frames per example. Paths are cluster-specific.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Initialize the VAE from the pretrained checkpoint.
    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}")
        return
    print("Loading VAE model...")
    vae = WanVAEWrapper(pretrained_path=ckpt_path).to(device)

    datasets = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/",
    }
    output_dir = Path("/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_wan_vae")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Wan VAE requires T = 1 + 4k frames: reconstruct 13, visualize 8.
    T_vae = 13
    T_vis = 8

    for ds_name, ds_root in datasets.items():
        print(f"\n--- Processing {ds_name} ---")
        meta_path = Path(ds_root) / "metadata_lite.pt"
        if not meta_path.exists():
            print(f"Metadata not found for {ds_name}")
            continue
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # safe because the metadata file is trusted local data.
        metadata = torch.load(meta_path, map_location='cpu', weights_only=False)

        saved = 0  # examples written for this dataset (cap at 4)
        for entry_idx in range(len(metadata)):
            if saved >= 4:
                break
            rel_path = metadata[entry_idx]['video_path']
            # Metadata may store absolute or dataset-relative paths.
            video_path = rel_path if rel_path.startswith('/') else str(Path(ds_root) / rel_path)
            if not os.path.exists(video_path):
                continue
            print(f"Example {saved}: {video_path}")

            raw_frames = load_video(video_path, num_frames=T_vae)
            if raw_frames is None:
                continue

            # (T, H, W, C) uint8 -> (B, T, C, H, W) float in [-1, 1], the
            # normalization the Wan VAE expects.
            video_tensor = torch.from_numpy(raw_frames).permute(0, 3, 1, 2).float() / 255.0
            video_tensor = (video_tensor * 2.0 - 1.0).unsqueeze(0).to(device)

            with torch.no_grad():
                latent = vae.encode(video_tensor)
                print(f" Input shape (B, T, C, H, W): {video_tensor.shape}")
                print(f" Latent shape (B, T', C', H', W'): {latent.shape}")
                recon = vae.decode_to_pixel(latent)  # [B, T, C, H, W] in [-1, 1]

            # Back to uint8 RGB frames, (T, H, W, C).
            recon_np = ((recon[0].cpu().permute(0, 2, 3, 1).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

            # Subsample frames for the saved comparison grid.
            vis_indices = np.linspace(0, T_vae - 1, T_vis, dtype=int)
            save_path = output_dir / f"{ds_name}_sample_{saved}.png"
            save_comparison(raw_frames[vis_indices], recon_np[vis_indices], save_path)
            print(f" Result saved to {save_path}")
            saved += 1

    print("\nAll tests completed.")
# Script entry point: run the VAE round-trip test when executed directly.
if __name__ == "__main__":
    main()