# world_model/wm/test/test_generate.py
# Uploaded by t1an using huggingface_hub (commit f17ae24, verified).
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
# Add project root to sys.path
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")
from wm.dynamics.bi_fulltrajectory import Bidirectional_FullTrajectory
from wm.model.diffusion.flow_matching import FlowMatchScheduler
# The DiT/VAE class registries are imported directly rather than mocked (they are not referenced in this test).
from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP
def create_checkerboard_o0(B, H, W, C=3, block_size=8):
    """Build a batch of black/white checkerboard images to use as a first frame.

    Square tiles of side ``block_size`` alternate between 1.0 (white) and 0.0
    (black), starting with a white tile in the top-left corner. Tiles on the
    right/bottom edge are truncated when ``H`` or ``W`` is not a multiple of
    ``block_size``.

    Args:
        B: Batch size; every batch element receives the same pattern.
        H: Image height in pixels.
        W: Image width in pixels.
        C: Number of channels; every channel receives the same pattern.
        block_size: Side length of one tile in pixels (default 8, the value
            that was previously hard-coded).

    Returns:
        A CPU tensor of shape ``(B, H, W, C)`` in torch's default dtype whose
        values are all 0.0 or 1.0.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    # Tile index of every pixel row / column. A pixel is white when the sum of
    # its row-tile and column-tile indices is even; this is the same parity
    # rule the block-by-block loop used, computed for all pixels at once.
    row_tiles = torch.arange(H) // block_size
    col_tiles = torch.arange(W) // block_size
    white = (row_tiles[:, None] + col_tiles[None, :]) % 2 == 0  # (H, W) bool
    pattern = white.to(torch.get_default_dtype())
    # repeat() materializes a fresh (B, H, W, C) tensor with its own storage, so
    # callers may modify one batch element/channel without affecting others.
    return pattern[None, :, :, None].repeat(B, 1, 1, C)
def test_generate():
    """Smoke-test ``Bidirectional_FullTrajectory.generate`` end to end.

    Builds a small two-layer ``VideoDiT`` model backed by the ``WanVAE``
    (Wan2.1 VAE checkpoint). It conditions the model on a checkerboard first
    frame plus random actions, runs a 5-step generation, and checks that the
    returned video has shape ``(B, T_pixel, H, W, C)``.

    If the VAE checkpoint is absent on this machine, the function prints a
    message and returns early. Note that test runners record that path as a
    pass, not a skip.

    Any error raised while building the model or generating, including a
    wrong output shape, propagates to the caller. Previously such errors were
    printed and swallowed, so the test could never fail.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Testing on {device}")
    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print("VAE Checkpoint not found, skipping full test.")
        return
    # Configuration: a deliberately tiny DiT so the test runs quickly.
    model_config = {
        'in_channels': 16,
        'patch_size': 2,
        'dim': 128,
        'num_layers': 2,
        'num_heads': 4,
        'action_dim': 16,
        'action_compress_rate': 4,
        'max_frames': 17,
        'vae_name': 'WanVAE',
        'vae_config': [ckpt_path],
        'scheduler': FlowMatchScheduler(),
        'training_timesteps': 1000,
    }
    # Seed torch's global RNG so the random actions (and any torch-RNG-based
    # sampling noise inside generate) are reproducible across runs.
    torch.manual_seed(0)

    # No try/except here on purpose: failures must reach the test runner.
    model = Bidirectional_FullTrajectory('VideoDiT', model_config).to(device)
    print("Model initialized.")

    B, H, W, C = 2, 64, 64, 3
    T_pixel = 17  # Wan compatible 1 + 4*k
    o_0 = create_checkerboard_o0(B, H, W, C).to(device)
    # The action width must agree with the model's configured action_dim.
    a = torch.randn(B, T_pixel, model_config['action_dim']).to(device)

    print(f"Generating {T_pixel} frames...")
    # Use very few steps for speed
    video = model.generate(o_0, a, num_inference_steps=5)
    print(f"Generated video shape: {video.shape}")

    expected_shape = (B, T_pixel, H, W, C)
    assert tuple(video.shape) == expected_shape, (
        f"expected video shape {expected_shape}, got {tuple(video.shape)}"
    )
    print("Generation test successful!")
if __name__ == "__main__":
    # Run the generation smoke test when this file is executed directly as a script.
    test_generate()