import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Make the project-local ``wm`` package importable when this file is run
# directly; must happen before the ``wm`` imports below.
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")

from wm.dynamics.bi_fulltrajectory import Bidirectional_FullTrajectory  # noqa: E402
from wm.model.diffusion.flow_matching import FlowMatchScheduler  # noqa: E402
from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP  # noqa: E402

def create_checkerboard_o0(B, H, W, C=3, block_size=8):
    """Build a batch of identical checkerboard images for use as a first frame.

    Args:
        B: Batch size.
        H: Image height in pixels.
        W: Image width in pixels.
        C: Number of channels; every channel receives the same pattern.
        block_size: Side length, in pixels, of each checkerboard square.
            Defaults to 8 (the value previously hard-coded here). Squares on
            the bottom/right edge are truncated when ``H`` or ``W`` is not a
            multiple of ``block_size``.

    Returns:
        A ``(B, H, W, C)`` tensor created by ``torch.zeros`` (default float
        dtype, CPU), set to 1.0 on squares whose row-block index plus
        column-block index is even (so the top-left square is 1.0) and 0.0
        elsewhere.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    o_0 = torch.zeros((B, H, W, C))
    # Block index of every pixel row / column. A pixel is "on" when the sum of
    # its row-block and column-block indices is even. This is the same rule as
    # a Python loop over blocks, computed in one vectorised step.
    row_block = torch.arange(H) // block_size
    col_block = torch.arange(W) // block_size
    on = (row_block[:, None] + col_block[None, :]) % 2 == 0  # (H, W) bool mask
    o_0[:, on] = 1.0  # broadcast over batch and channel dimensions
    return o_0


def test_generate():
    """Smoke-test ``Bidirectional_FullTrajectory.generate`` end to end.

    Builds a small ``VideoDiT``-based model around the Wan2.1 VAE checkpoint,
    feeds it a checkerboard first frame plus random actions, and checks that
    the generated video has shape ``(B, T_pixel, H, W, C)``.

    If the VAE checkpoint is absent on this machine, a message is printed and
    the function returns without running anything. Any error during model
    construction or generation, including a failed shape check, is reported
    and then re-raised. This lets pytest or the calling shell see the failure
    instead of a silent pass.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Testing on {device}")

    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print("VAE Checkpoint not found, skipping full test.")
        return

    # Test dimensions. ``action_dim`` is shared by the model config and the
    # action tensor, so it is defined once to keep them in sync.
    B, H, W, C = 2, 64, 64, 3
    T_pixel = 17
    action_dim = 16

    model_config = {
        'in_channels': 16,
        'patch_size': 2,
        'dim': 128,
        'num_layers': 2,
        'num_heads': 4,
        'action_dim': action_dim,
        'action_compress_rate': 4,
        'max_frames': 17,
        'vae_name': 'WanVAE',
        'vae_config': [ckpt_path],
        'scheduler': FlowMatchScheduler(),
        'training_timesteps': 1000,
    }

    # Fixed seed so a failing run can be reproduced.
    torch.manual_seed(0)

    try:
        model = Bidirectional_FullTrajectory('VideoDiT', model_config).to(device)
        model.eval()  # inference: disable training-only behaviour (e.g. dropout)
        print("Model initialized.")

        o_0 = create_checkerboard_o0(B, H, W, C).to(device)
        a = torch.randn(B, T_pixel, action_dim).to(device)

        print(f"Generating {T_pixel} frames...")
        # Pure inference: no autograd graph is needed, which saves memory.
        with torch.no_grad():
            video = model.generate(o_0, a, num_inference_steps=5)

        print(f"Generated video shape: {video.shape}")
        expected_shape = (B, T_pixel, H, W, C)
        assert video.shape == expected_shape, (
            f"expected shape {expected_shape}, got {tuple(video.shape)}"
        )
        print("Generation test successful!")

    except Exception as e:
        print(f"Generation failed: {e}")
        # Re-raise: previously the error (including a failed assert) was only
        # printed, so the test always appeared to pass.
        raise


if __name__ == "__main__":
    # Executed as a script (not imported / collected by pytest): run the
    # generation smoke test once.
    test_generate()