# Snapshot metadata: commit 7344bef, 5,153 bytes (web-viewer banner and gutter line numbers removed).
"""Minimal MPS forward pass test - creates WanModel with random weights and tests inference on MPS."""
import os, sys, gc, time
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

print("=" * 60)
print("Wan2GP MPS Forward Pass Test (Random Weights)")
print("=" * 60)

# Step 1: Apply MPS patch early
import torch
from shared.mps.device_patch import apply_mps_patch
apply_mps_patch()

print(f"\n[1] PyTorch {torch.__version__}, MPS: {torch.backends.mps.is_available()}")
print(f"    Default device: {torch.get_default_device()}")

# Step 2: Import WanModel now that the MPS patch has been applied
print("\n[2] Importing WanModel...")
try:
    from models.wan.modules.model import WanModel
except Exception as err:
    # Any import-time failure (missing dependency, unsupported op, ...) aborts the test.
    import traceback
    print(f"    FAILED: {err}")
    traceback.print_exc()
    sys.exit(1)
else:
    print("    WanModel imported OK")

# Step 3: Build a randomly initialised WanModel from the 1.3B hyperparameters
print("\n[3] Creating WanModel with 1.3B config...")
config = {
    "dim": 1536,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "num_heads": 12,
    "num_layers": 30,
    "patch_size": (1, 2, 2),
    "window_size": (-1, -1),
    "qk_norm": True,
    "cross_attn_norm": True,
    "eps": 1e-6,
    "text_len": 512,
    "vae_stride": (4, 8, 8),
}

try:
    # Only these entries are passed to the constructor. The remaining config
    # keys ("text_len", "vae_stride") are not WanModel arguments here.
    model_kwargs = {
        key: config[key]
        for key in (
            "dim", "ffn_dim", "freq_dim", "num_heads", "num_layers",
            "patch_size", "window_size", "qk_norm", "cross_attn_norm", "eps",
        )
    }
    model = WanModel(**model_kwargs)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"    Model created: {param_count / 1e6:.0f}M parameters")
except Exception as err:
    import traceback
    print(f"    FAILED creating model: {err}")
    traceback.print_exc()
    sys.exit(1)

# Step 4: Cast the weights to bfloat16 on the MPS device and switch to eval mode
print("\n[4] Moving model to MPS...")
try:
    model = model.to("mps", dtype=torch.bfloat16)
    model.eval()
except Exception as err:
    import traceback
    print(f"    FAILED: {err}")
    traceback.print_exc()
    sys.exit(1)
else:
    print("    Model on MPS OK")

# Step 5: Create dummy inputs
print("\n[5] Creating dummy inputs...")
batch_size = 1
# The model input is a 5D latent (B, C, T, H, W). T, H and W below are LATENT
# sizes (what the patch embedding sees), not pixel sizes. The constraint is
# therefore divisibility by patch_size, not by vae_stride.
T, H, W = 9, 30, 52  # latent frames, height, width
patch_size = config["patch_size"]  # (1, 2, 2)
dim = config["dim"]

# The patch embedding strides by patch_size, so a non-divisible size would
# silently drop the remainder. Reject such sizes up front.
for axis_name, size, patch in zip(("T", "H", "W"), (T, H, W), patch_size):
    if size % patch:
        print(f"    FAILED: latent {axis_name}={size} is not divisible by patch size {patch}")
        sys.exit(1)

# in_channels must equal the model's input-channel count. 16 is the Wan2.1 VAE
# latent channel count. NOTE(review): in_dim is not passed in step 3, so this
# assumes WanModel's default in_dim is 16; confirm against its signature.
in_channels = 16
# After patch embedding: (B, dim, T//1, H//2, W//2), flattened to (B, seq_len, dim)
latent_T = T // patch_size[0]  # 9
latent_H = H // patch_size[1]  # 15
latent_W = W // patch_size[2]  # 26
seq_len = latent_T * latent_H * latent_W  # 3510

x_5d = torch.randn(batch_size, in_channels, T, H, W, device="mps", dtype=torch.bfloat16)
print(f"    5D Latent shape: {x_5d.shape}")
print(f"    Patch grid: {latent_T} x {latent_H} x {latent_W} = {seq_len} tokens")

# WanModel forward expects x as a LIST of 5D tensors
x_list_input = [x_5d]

# Disable the step-skipping cache (TeaCache); not needed for a basic inference test.
# NOTE(review): the original comment named `skips_steps_cache`, but the attribute
# cleared here is `cache`. Confirm which one WanModel actually consults.
if hasattr(model, "cache"):
    model.cache = None

# text_embedding: Linear(4096, dim) - UMT5-XXL text encoder output is 4096-dim
text_encoder_dim = 4096
text_len = config["text_len"]
# context must be a LIST of 2D tensors [text_len, text_encoder_dim], one per batch item
context = torch.randn(batch_size, text_len, text_encoder_dim, device="mps", dtype=torch.bfloat16)
context_list = [context[i] for i in range(batch_size)]
print(f"    Context list: {len(context_list)} x {context_list[0].shape}")

# NOTE(review): bfloat16 has 8 significant bits. 500 is exactly representable,
# but most other timesteps (e.g. 999 -> 1000) are rounded. Consider float32 if
# the model's timestep embedding accepts it.
t = torch.tensor([500] * batch_size, device="mps", dtype=torch.bfloat16)
print(f"    Timestep: {t}")

# Step 6: Forward pass
print("\n[6] Running forward pass...")
# Drain queued MPS work first, so the timer measures only the forward pass.
torch.mps.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock for intervals
try:
    with torch.no_grad():
        with torch.autocast("mps", dtype=torch.bfloat16):
            output = model(x_list_input, t, context_list)
    # Wait for the submitted MPS work to finish before reading the clock.
    torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    print(f"    Forward pass OK in {elapsed:.2f}s: output type = {type(output)}")
    if isinstance(output, list):
        for j, o in enumerate(output):
            print(f"    Output[{j}] shape = {o.shape}, dtype = {o.dtype}")
            print(f"    Output[{j}] range: [{o.min():.3f}, {o.max():.3f}]")
    else:
        print(f"    Output shape = {output.shape}, dtype = {output.dtype}")
except Exception as e:
    elapsed = time.perf_counter() - start
    print(f"    FAILED after {elapsed:.2f}s: {e}")
    import traceback; traceback.print_exc()
    sys.exit(1)

# A forward pass that yields NaN/Inf is a numerical or backend failure and
# must not be reported as a pass.
output_tensors = output if isinstance(output, list) else [output]
non_finite = [j for j, o in enumerate(output_tensors) if not torch.isfinite(o).all()]
if non_finite:
    print(f"    FAILED: output(s) {non_finite} contain NaN/Inf values")
    sys.exit(1)

# Step 7: Report MPS memory usage after the forward pass, then the verdict
print("\n[7] Memory stats...")
allocated_gib = torch.mps.current_allocated_memory() / 1024**3
print(f"    MPS memory allocated: {allocated_gib:.2f}GB")
driver_gib = torch.mps.driver_allocated_memory() / 1024**3
print(f"    MPS memory driver allocated: {driver_gib:.2f}GB")

banner = "=" * 60
print("\n" + banner)
print("MPS Forward Pass Test PASSED!")
print(banner)