# ColabWan/shared/mps/test_mps_forward.py
# Uploaded by 1ripon1 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 7344bef (verified), 5.15 kB.
"""Minimal MPS forward pass test - creates WanModel with random weights and tests inference on MPS."""
import os, sys, gc, time
# Resolve the repository root (two directories above this file) and put it at the
# front of sys.path so the repo-local packages imported below (``shared``,
# ``models``) resolve when this file is run directly as a script.
_here = os.path.dirname(__file__)
REPO_ROOT = os.path.abspath(os.path.join(_here, "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

_rule = "=" * 60
print(_rule)
print("Wan2GP MPS Forward Pass Test (Random Weights)")
print(_rule)
# Step 1: Apply the MPS device patch before WanModel is imported in step 2
# (presumably so the patch is in effect when the model modules load), report the
# PyTorch/MPS setup, and stop early if no MPS backend is available.
import torch
from shared.mps.device_patch import apply_mps_patch
apply_mps_patch()
print(f"\n[1] PyTorch {torch.__version__}, MPS: {torch.backends.mps.is_available()}")
print(f" Default device: {torch.get_default_device()}")
if not torch.backends.mps.is_available():
    # Every later step hard-codes device="mps"; fail here with a clear reason
    # instead of an opaque device error when the model is moved in step 4.
    print(" FAILED: MPS backend is not available on this machine")
    sys.exit(1)
# Step 2: Import WanModel; any failure in the import chain aborts the test with
# the exception message and full traceback.
print("\n[2] Importing WanModel...")
try:
    from models.wan.modules.model import WanModel
except Exception as exc:
    import traceback
    print(f" FAILED: {exc}")
    traceback.print_exc()
    sys.exit(1)
else:
    print(" WanModel imported OK")
# Step 3: Build a randomly-initialised WanModel with the 1.3B hyper-parameters.
print("\n[3] Creating WanModel with 1.3B config...")
config = {
    "dim": 1536,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "num_heads": 12,
    "num_layers": 30,
    "patch_size": (1, 2, 2),
    "window_size": (-1, -1),
    "qk_norm": True,
    "cross_attn_norm": True,
    "eps": 1e-6,
    "text_len": 512,
    "vae_stride": (4, 8, 8),
}
# Only these keys are forwarded to the constructor. "text_len" is read later when
# building the text context; "vae_stride" is not read by this script.
_MODEL_KWARG_KEYS = (
    "dim",
    "ffn_dim",
    "freq_dim",
    "num_heads",
    "num_layers",
    "patch_size",
    "window_size",
    "qk_norm",
    "cross_attn_norm",
    "eps",
)
try:
    model = WanModel(**{key: config[key] for key in _MODEL_KWARG_KEYS})
    n_params = sum(param.numel() for param in model.parameters())
    print(f" Model created: {n_params / 1e6:.0f}M parameters")
except Exception as exc:
    import traceback
    print(f" FAILED creating model: {exc}")
    traceback.print_exc()
    sys.exit(1)
# Step 4: Move the weights to the MPS device as bfloat16 and put the model in
# eval mode (nn.Module.eval() returns the module itself, so the calls chain).
print("\n[4] Moving model to MPS...")
try:
    model = model.to("mps", dtype=torch.bfloat16).eval()
except Exception as exc:
    import traceback
    print(f" FAILED: {exc}")
    traceback.print_exc()
    sys.exit(1)
else:
    print(" Model on MPS OK")
# Step 5: Create random stand-in inputs: a latent video, a text context and a
# timestep, shaped the way the forward call in step 6 passes them.
print("\n[5] Creating dummy inputs...")
batch_size = 1
# T, H, W are *latent* sizes fed straight to the patch embedding (noted by the
# author as a Conv3d with kernel = stride = patch_size). They must be divisible by
# patch_size; they are not pixel sizes and need not be divisible by vae_stride
# (H = 30 is not a multiple of 8).
T, H, W = 9, 30, 52  # latent frames, height, width
patch_size = config["patch_size"]  # (1, 2, 2)
if T % patch_size[0] or H % patch_size[1] or W % patch_size[2]:
    print(f" FAILED: latent size {(T, H, W)} is not divisible by patch_size {patch_size}")
    sys.exit(1)
# Latent channel count; must match the patch embedding's input channels.
# NOTE(review): WanModel is built without an explicit in_dim, so this assumes its
# default input channels are 16 — confirm against models/wan/modules/model.py.
in_channels = 16  # Wan2.1 VAE latent channels
# Token grid after patchify, flattened into seq_len tokens per sample.
grid_T = T // patch_size[0]  # 9
grid_H = H // patch_size[1]  # 15
grid_W = W // patch_size[2]  # 26
seq_len = grid_T * grid_H * grid_W  # 3510
x_5d = torch.randn(batch_size, in_channels, T, H, W, device="mps", dtype=torch.bfloat16)
print(f" 5D Latent shape: {x_5d.shape}")
print(f" Patch tokens per sample: {seq_len} ({grid_T}x{grid_H}x{grid_W})")
# WanModel forward expects x as a LIST of 5D tensors.
x_list_input = [x_5d]
# Turn off the step-skipping (TeaCache) cache if the model exposes one; it is not
# needed for a single forward pass.
# NOTE(review): an earlier comment named `skips_steps_cache`, but only the `cache`
# attribute is cleared here — confirm which one WanModel actually consults.
if hasattr(model, "cache"):
    model.cache = None
# text_embedding is Linear(4096, dim): UMT5-XXL encoder outputs are 4096-dim.
text_encoder_dim = 4096
text_len = config["text_len"]
context = torch.randn(batch_size, text_len, text_encoder_dim, device="mps", dtype=torch.bfloat16)
# The model takes the context as a list with one 2D [text_len, text_encoder_dim]
# tensor per batch item.
context_list = [context[i] for i in range(batch_size)]
print(f" Context list: {len(context_list)} x {context_list[0].shape}")
# One timestep value per batch item.
t = torch.tensor([500] * batch_size, device="mps", dtype=torch.bfloat16)
print(f" Timestep: {t}")
# Step 6: Run and time one forward pass.
# MPS executes kernels asynchronously, so synchronize() is called before starting
# the clock and again after the forward; otherwise the timer would stop before
# the queued GPU work had finished.
print("\n[6] Running forward pass...")
torch.mps.synchronize()
start = time.time()
try:
    with torch.no_grad(), torch.autocast("mps", dtype=torch.bfloat16):
        output = model(x_list_input, t, context_list)
    torch.mps.synchronize()
    elapsed = time.time() - start
    print(f" Forward pass OK in {elapsed:.2f}s: output type = {type(output)}")
    outputs = output if isinstance(output, list) else [output]
    if isinstance(output, list):
        for j, o in enumerate(output):
            print(f" Output[{j}] shape = {o.shape}, dtype = {o.dtype}")
            print(f" Output[{j}] range: [{o.min():.3f}, {o.max():.3f}]")
    else:
        print(f" Output shape = {output.shape}, dtype = {output.dtype}")
    # With random weights the values themselves are meaningless, but NaN/Inf can
    # point to a broken kernel or dtype path on MPS, so surface it rather than
    # letting the test pass silently.
    n_nonfinite = sum(int((~torch.isfinite(o)).sum()) for o in outputs)
    if n_nonfinite:
        print(f" WARNING: output contains {n_nonfinite} non-finite (NaN/Inf) values")
except Exception as e:
    elapsed = time.time() - start
    print(f" FAILED after {elapsed:.2f}s: {e}")
    import traceback; traceback.print_exc()
    sys.exit(1)
# Step 7: Report MPS memory usage (byte counts divided by 1024**3) and declare
# the test passed.
print("\n[7] Memory stats...")
_BYTES_PER_GIB = 1024**3
allocated_gib = torch.mps.current_allocated_memory() / _BYTES_PER_GIB
driver_gib = torch.mps.driver_allocated_memory() / _BYTES_PER_GIB
print(f" MPS memory allocated: {allocated_gib:.2f}GB")
print(f" MPS memory driver allocated: {driver_gib:.2f}GB")
_banner = "=" * 60
print("\n" + _banner)
print("MPS Forward Pass Test PASSED!")
print(_banner)