import os
import sys
import torch
import cv2
import numpy as np
from PIL import Image
from pathlib import Path
# Add project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper
def load_video(video_path, num_frames=13):
    """Load ``num_frames`` RGB frames sampled at equal intervals from a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to return (default 13).

    Returns:
        uint8 ``np.ndarray`` of shape (num_frames, H, W, 3) in RGB order,
        or ``None`` if the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        # Bail out early on unreadable files; try/finally guarantees the
        # capture is released (the original leaked it on the early return).
        if not cap.isOpened():
            return None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None
        # Sample num_frames indices at equal intervals. A set gives O(1)
        # membership tests (the original probed an ndarray, O(num_frames)
        # per decoded frame). linspace may repeat indices when
        # total_frames < num_frames; the padding below compensates.
        indices = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())
        curr_idx = 0
        while len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if curr_idx in indices:
                # OpenCV decodes BGR; downstream consumers expect RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            curr_idx += 1
    finally:
        cap.release()
    if not frames:
        return None
    # Pad by repeating the last frame if the stream ended early.
    while len(frames) < num_frames:
        frames.append(frames[-1])
    return np.stack(frames)  # (T, H, W, C)
def save_comparison(gt_frames, recon_frames, save_path):
# gt_frames, recon_frames: (T, H, W, C)
T, H, W, C = gt_frames.shape
# Calculate absolute difference and multiply by 5
diff = np.abs(gt_frames.astype(np.float32) - recon_frames.astype(np.float32))
diff_vis = (diff * 1.0).clip(0, 255).astype(np.uint8)
# Create a grid: 3 rows (GT, Recon, Diff*5), T columns
combined = np.zeros((H * 3, W * T, C), dtype=np.uint8)
for t in range(T):
combined[:H, t*W:(t+1)*W] = gt_frames[t]
combined[H:H*2, t*W:(t+1)*W] = recon_frames[t]
combined[H*2:, t*W:(t+1)*W] = diff_vis[t]
Image.fromarray(combined).save(save_path)
def main():
    """Round-trip sample videos from each dataset through the Wan VAE and
    save GT / reconstruction / difference comparison grids to disk."""
    # Select the compute device.
    if torch.cuda.is_available():
        dev = "cuda"
    else:
        dev = "cpu"
    print(f"Using device: {dev}")

    # Load the pretrained VAE from its (cluster-specific) checkpoint path.
    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}")
        return
    print("Loading VAE model...")
    vae = WanVAEWrapper(pretrained_path=ckpt_path).to(dev)

    # Dataset name -> root directory; each root should hold metadata_lite.pt.
    datasets = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/"
    }
    out_dir = Path("/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_wan_vae")
    out_dir.mkdir(parents=True, exist_ok=True)

    T_vae = 13  # Wan VAE requires T = 1 + 4k frames; 13 satisfies that.
    T_vis = 8   # number of frames kept in the saved visualization

    for ds_name, ds_root in datasets.items():
        print(f"\n--- Processing {ds_name} ---")
        meta_path = Path(ds_root) / "metadata_lite.pt"
        if not meta_path.exists():
            print(f"Metadata not found for {ds_name}")
            continue
        metadata = torch.load(meta_path, map_location='cpu', weights_only=False)

        saved = 0  # comparisons written for this dataset; stop after 4
        for idx in range(len(metadata)):
            if saved >= 4:
                break
            rel = metadata[idx]['video_path']
            # Metadata may store absolute or dataset-relative paths.
            video_path = rel if rel.startswith('/') else str(Path(ds_root) / rel)
            if not os.path.exists(video_path):
                continue
            print(f"Example {saved}: {video_path}")

            raw = load_video(video_path, num_frames=T_vae)
            if raw is None:
                continue

            # (T, H, W, C) uint8 -> (1, T, C, H, W) float normalized to
            # [-1, 1], the range the Wan VAE expects.
            clip = torch.from_numpy(raw).permute(0, 3, 1, 2).float() / 255.0
            clip = (clip * 2.0 - 1.0).unsqueeze(0).to(dev)

            with torch.no_grad():
                # Encode to latents, then decode back to pixel space.
                z = vae.encode(clip)
                print(f" Input shape (B, T, C, H, W): {clip.shape}")
                print(f" Latent shape (B, T', C', H', W'): {z.shape}")
                decoded = vae.decode_to_pixel(z)  # [B, T, C, H, W] in [-1, 1]

            # Map [-1, 1] back to uint8 RGB frames of shape (T, H, W, C).
            decoded_np = ((decoded[0].cpu().permute(0, 2, 3, 1).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

            # Subsample T_vis frames for the saved comparison image.
            keep = np.linspace(0, T_vae - 1, T_vis, dtype=int)
            save_path = out_dir / f"{ds_name}_sample_{saved}.png"
            save_comparison(raw[keep], decoded_np[keep], save_path)
            print(f" Result saved to {save_path}")
            saved += 1

    print("\nAll tests completed.")
# Script entry point: run the reconstruction test only when executed directly.
if __name__ == "__main__":
    main()