# File size: 2,425 Bytes
# f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

# Add project root to sys.path
sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model")

from wm.dynamics.bi_fulltrajectory import Bidirectional_FullTrajectory
from wm.model.diffusion.flow_matching import FlowMatchScheduler

# The class maps are imported from the package as-is; nothing is mocked in this script.
from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP

def create_checkerboard_o0(B, H, W, C=3, block_size=8):
    """Build a black-and-white checkerboard batch to use as the first frame.

    Args:
        B: Batch size.
        H: Image height in pixels.
        W: Image width in pixels.
        C: Number of channels; every channel receives the same pattern.
        block_size: Side length, in pixels, of one checkerboard square.
            Squares at the bottom/right edge are truncated when ``H`` or
            ``W`` is not a multiple of ``block_size``.

    Returns:
        A tensor of shape ``(B, H, W, C)`` created with ``torch.zeros``
        defaults, holding 1.0 inside squares whose block-row index plus
        block-column index is even (so the top-left square is white) and
        0.0 everywhere else.

    Raises:
        ValueError: If ``block_size`` is not a positive integer.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # Index of the square each pixel row / column falls into.
    block_rows = torch.div(torch.arange(H), block_size, rounding_mode="floor")
    block_cols = torch.div(torch.arange(W), block_size, rounding_mode="floor")

    # (H, W) mask that is True on the white squares.
    white = (block_rows[:, None] + block_cols[None, :]) % 2 == 0

    o_0 = torch.zeros((B, H, W, C))
    # The mask indexes the H and W axes; batch and channel axes are filled alike.
    o_0[:, white] = 1.0
    return o_0

def test_generate():
    """Smoke-test ``Bidirectional_FullTrajectory.generate`` with a tiny VideoDiT.

    Builds a small model configuration, feeds a checkerboard first frame and
    random actions, runs a 5-step generation and checks the output shape.

    The test is skipped (prints a message and returns ``None``) when the VAE
    checkpoint file is not present on this machine.

    Raises:
        Exception: Any error raised while constructing the model, generating,
            or checking the output shape is reported and then re-raised, so a
            test runner or a direct script run sees the failure instead of a
            silent pass.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Testing on {device}")

    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print("VAE Checkpoint not found, skipping full test.")
        return

    # Single source for the action width used by both the config and the input.
    action_dim = 16

    # Configuration
    model_config = {
        'in_channels': 16,
        'patch_size': 2,
        'dim': 128,
        'num_layers': 2,
        'num_heads': 4,
        'action_dim': action_dim,
        'action_compress_rate': 4,
        'max_frames': 17,
        'vae_name': 'WanVAE',
        'vae_config': [ckpt_path],
        'scheduler': FlowMatchScheduler(),
        'training_timesteps': 1000,
    }

    try:
        model = Bidirectional_FullTrajectory('VideoDiT', model_config).to(device)
        print("Model initialized.")

        B, H, W, C = 2, 64, 64, 3
        T_pixel = 17  # Wan compatible 1 + 4*k
        o_0 = create_checkerboard_o0(B, H, W, C).to(device)
        a = torch.randn(B, T_pixel, action_dim).to(device)

        print(f"Generating {T_pixel} frames...")
        # Use very few steps for speed
        video = model.generate(o_0, a, num_inference_steps=5)

        print(f"Generated video shape: {video.shape}")
        assert video.shape == (B, T_pixel, H, W, C)
        print("Generation test successful!")

    except Exception as e:
        # Report the failure, then propagate it: swallowing it here would make
        # the test pass even when generation or the shape check fails. The
        # propagated exception carries the full traceback.
        print(f"Generation failed: {e}")
        raise

# Run the smoke test when this file is executed directly as a script.
if __name__ == "__main__":
    test_generate()