| | import os |
| | import sys |
| | import torch |
| | import cv2 |
| | import numpy as np |
| | from PIL import Image |
| | from pathlib import Path |
| |
|
| | |
| | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) |
| |
|
| | from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper |
| |
|
def load_video(video_path, num_frames=13):
    """Read ``num_frames`` evenly spaced RGB frames from the video at ``video_path``.

    Args:
        video_path: path to a video file readable by OpenCV.
        num_frames: number of frames to return (default 13, the VAE's input length).

    Returns:
        A ``(num_frames, H, W, 3)`` uint8 RGB array, or ``None`` when the video
        cannot be read or yields no frames. If the decoder stops early, the last
        decoded frame is repeated to pad up to ``num_frames``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Bug fix: the original early return leaked the VideoCapture handle;
            # the finally-block now guarantees release on every exit path.
            return None

        # Evenly spaced target indices. A set gives O(1) membership tests
        # (the original tested `in` against a numpy array, O(n) per frame).
        # Duplicates from linspace (total_frames < num_frames) are harmless:
        # the loop appends at most once per decoded frame either way.
        wanted = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())

        frames = []
        curr_idx = 0
        while len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if curr_idx in wanted:
                # OpenCV decodes BGR; convert to RGB for PIL/tensor consumers.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            curr_idx += 1
    finally:
        cap.release()

    if not frames:
        return None

    # Pad with the last decoded frame if the stream ended before we collected enough.
    while len(frames) < num_frames:
        frames.append(frames[-1])

    return np.stack(frames)
| |
|
def save_comparison(gt_frames, recon_frames, save_path):
    """Write a 3-row comparison image: ground truth / reconstruction / |difference|.

    Args:
        gt_frames: ``(T, H, W, C)`` uint8 ground-truth frames.
        recon_frames: uint8 reconstructed frames of the same shape.
        save_path: destination path for the PNG.
    """
    num_frames, height, width, channels = gt_frames.shape

    # Per-pixel absolute error, clipped back into displayable uint8 range.
    abs_err = np.abs(gt_frames.astype(np.float32) - recon_frames.astype(np.float32))
    err_frames = abs_err.clip(0, 255).astype(np.uint8)

    # Stack the three sequences as horizontal film strips, one strip per row.
    canvas = np.zeros((height * 3, width * num_frames, channels), dtype=np.uint8)
    for row_idx, row_frames in enumerate((gt_frames, recon_frames, err_frames)):
        top = row_idx * height
        for t in range(num_frames):
            canvas[top:top + height, t * width:(t + 1) * width] = row_frames[t]

    Image.fromarray(canvas).save(save_path)
| |
|
def main(samples_per_dataset=4):
    """Round-trip a few videos per dataset through the Wan VAE and save
    ground-truth / reconstruction / diff comparison strips to disk.

    Args:
        samples_per_dataset: number of videos to visualize per dataset.
            Defaults to 4, matching the previously hard-coded limit.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}")
        return

    print("Loading VAE model...")
    vae = WanVAEWrapper(pretrained_path=ckpt_path).to(device)

    # Dataset name -> root directory containing metadata_lite.pt and the videos.
    datasets = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/"
    }

    output_dir = Path("/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_wan_vae")
    output_dir.mkdir(parents=True, exist_ok=True)

    T_vae = 13  # frames fed to the VAE per video
    T_vis = 8   # frames shown in each comparison strip

    for ds_name, ds_root in datasets.items():
        print(f"\n--- Processing {ds_name} ---")
        meta_path = Path(ds_root) / "metadata_lite.pt"
        if not meta_path.exists():
            print(f"Metadata not found for {ds_name}")
            continue

        # NOTE(review): weights_only=False unpickles arbitrary objects, which can
        # execute code -- acceptable only because this metadata file is trusted/local.
        metadata = torch.load(meta_path, map_location='cpu', weights_only=False)

        count = 0
        for entry in metadata:
            if count >= samples_per_dataset:
                break

            video_rel_path = entry['video_path']
            # Metadata may store absolute or dataset-relative paths; handle both.
            if video_rel_path.startswith('/'):
                video_path = video_rel_path
            else:
                video_path = str(Path(ds_root) / video_rel_path)

            if not os.path.exists(video_path):
                continue

            print(f"Example {count}: {video_path}")

            raw_frames = load_video(video_path, num_frames=T_vae)
            if raw_frames is None:
                continue

            # uint8 (T, H, W, C) frames -> float tensor in [-1, 1], shape (1, T, C, H, W).
            video_tensor = torch.from_numpy(raw_frames).permute(0, 3, 1, 2).float() / 255.0
            video_tensor = (video_tensor * 2.0 - 1.0).unsqueeze(0).to(device)

            with torch.no_grad():
                latent = vae.encode(video_tensor)
                print(f" Input shape (B, T, C, H, W): {video_tensor.shape}")
                print(f" Latent shape (B, T', C', H', W'): {latent.shape}")
                recon = vae.decode_to_pixel(latent)

            # Back to uint8 RGB frames, inverting the [-1, 1] normalization.
            recon_np = ((recon[0].cpu().permute(0, 2, 3, 1).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
            gt_np = raw_frames

            # Robustness: the VAE decode may not return exactly T_vae frames
            # (temporal compression rounding) -- only index frames both arrays have.
            usable_T = min(len(gt_np), len(recon_np))
            vis_indices = np.linspace(0, usable_T - 1, T_vis, dtype=int)
            vis_gt = gt_np[vis_indices]
            vis_recon = recon_np[vis_indices]

            save_path = output_dir / f"{ds_name}_sample_{count}.png"
            save_comparison(vis_gt, vis_recon, save_path)
            print(f" Result saved to {save_path}")
            count += 1

    print("\nAll tests completed.")
| |
|
# Script entry point: run the VAE round-trip test only when executed directly.
if __name__ == "__main__":
    main()
| |
|