"""Smoke test for ``Bidirectional_FullTrajectory.generate``.

Builds a small ``VideoDiT``-based model, feeds it a checkerboard first frame
plus random actions, and checks that the generated video has the expected
shape. Can be run directly as a script.
"""
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Add project root to sys.path (must happen before the ``wm`` imports below).
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")

from wm.dynamics.bi_fulltrajectory import Bidirectional_FullTrajectory
from wm.model.diffusion.flow_matching import FlowMatchScheduler
# The class maps are not referenced directly in this file, but they should be
# importable.
from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP


def create_checkerboard_o0(
    B: int, H: int, W: int, C: int = 3, block_size: int = 8
) -> torch.Tensor:
    """Create a black-white checkerboard pattern for the first frame.

    Args:
        B: Batch size.
        H: Frame height in pixels.
        W: Frame width in pixels.
        C: Number of channels; every channel receives the same pattern.
        block_size: Side length, in pixels, of one checkerboard square.

    Returns:
        A tensor of shape ``(B, H, W, C)`` holding 1.0 on squares whose
        (row-block + column-block) index is even and 0.0 elsewhere. The
        top-left square is therefore white.

    Raises:
        ValueError: If ``block_size`` is not a positive integer.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    o_0 = torch.zeros((B, H, W, C))
    # Block index of every pixel row / column; the parity of their sum picks
    # the square colour. Trailing partial squares (when H or W is not a
    # multiple of block_size) fall out of the integer division naturally.
    row_block = torch.arange(H) // block_size
    col_block = torch.arange(W) // block_size
    white = (row_block[:, None] + col_block[None, :]) % 2 == 0  # (H, W)
    # The 2-D boolean mask indexes dims 1 and 2 (H, W) at once.
    o_0[:, white] = 1.0
    return o_0


def test_generate() -> None:
    """Run a short end-to-end generation and check the output shape.

    Returns early (after printing a notice) when the VAE checkpoint file is
    not present on disk. Any failure during model construction, generation or
    the shape check is reported and then re-raised, so the test runner or the
    script exit status reflects the failure.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Testing on {device}")

    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print("VAE Checkpoint not found, skipping full test.")
        return

    # Configuration
    model_config = {
        'in_channels': 16,
        'patch_size': 2,
        'dim': 128,
        'num_layers': 2,
        'num_heads': 4,
        'action_dim': 16,
        'action_compress_rate': 4,
        'max_frames': 17,
        'vae_name': 'WanVAE',
        'vae_config': [ckpt_path],
        'scheduler': FlowMatchScheduler(),
        'training_timesteps': 1000,
    }

    try:
        model = Bidirectional_FullTrajectory('VideoDiT', model_config).to(device)
        print("Model initialized.")

        B, H, W, C = 2, 64, 64, 3
        T_pixel = 17  # Wan compatible 1 + 4*k

        o_0 = create_checkerboard_o0(B, H, W, C).to(device)
        # One random action vector per pixel frame; the width is taken from
        # the model configuration so the two cannot drift apart.
        a = torch.randn(B, T_pixel, model_config['action_dim']).to(device)

        print(f"Generating {T_pixel} frames...")
        # Use very few steps for speed
        video = model.generate(o_0, a, num_inference_steps=5)

        print(f"Generated video shape: {video.shape}")
        assert video.shape == (B, T_pixel, H, W, C)
        print("Generation test successful!")
    except Exception as e:
        print(f"Generation failed: {e}")
        # Re-raise instead of swallowing: otherwise a failed assertion or a
        # runtime error would still count as a passing test / exit status 0.
        # The propagating exception carries the full traceback.
        raise


if __name__ == "__main__":
    test_generate()