| """Minimal MPS forward pass test - creates WanModel with random weights and tests inference on MPS.""" |
# Standard-library imports, one per statement (PEP 8).
import gc
import os
import sys
import time

# Absolute path of the directory two levels above this file. It is added to
# sys.path, presumably so the `shared` and `models` packages imported further
# down resolve when this script is run directly -- confirm against repo layout.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    # Prepended, so it is searched before the entries already on the path.
    sys.path.insert(0, REPO_ROOT)
|
|
# Banner, printed before the torch import.
banner_rule = "=" * 60
print(banner_rule)
print("Wan2GP MPS Forward Pass Test (Random Weights)")
print(banner_rule)

import torch
from shared.mps.device_patch import apply_mps_patch

# Project-specific MPS patch, applied before the model module is imported.
apply_mps_patch()

# Step 1: report the PyTorch build and whether the MPS backend is usable.
mps_available = torch.backends.mps.is_available()
print(f"\n[1] PyTorch {torch.__version__}, MPS: {mps_available}")
print(f" Default device: {torch.get_default_device()}")

# The rest of this script hard-codes device="mps" and calls torch.mps.*, so it
# cannot succeed without the backend. Stop here (same exit code as the later
# failure paths) instead of first building the whole model and failing then.
if not mps_available:
    print(" FAILED: MPS backend is not available on this machine")
    sys.exit(1)
|
|
| |
# Step 2: import the model class; on any failure print the error and the
# traceback, then exit with status 1.
print("\n[2] Importing WanModel...")
try:
    from models.wan.modules.model import WanModel
except Exception as import_error:
    print(f" FAILED: {import_error}")
    import traceback

    traceback.print_exc()
    sys.exit(1)
else:
    print(" WanModel imported OK")
|
|
| |
# Step 3: build a WanModel from a fixed configuration (weights are whatever the
# constructor initialises them to; nothing is loaded from disk here).
print("\n[3] Creating WanModel with 1.3B config...")
config = dict(
    dim=1536,
    ffn_dim=8960,
    freq_dim=256,
    num_heads=12,
    num_layers=30,
    patch_size=(1, 2, 2),
    window_size=(-1, -1),
    qk_norm=True,
    cross_attn_norm=True,
    eps=1e-6,
    text_len=512,
    vae_stride=(4, 8, 8),
)

# Entries forwarded to the constructor, in call order. "text_len" and
# "vae_stride" stay in `config` but are not passed to WanModel.
constructor_keys = (
    "dim",
    "ffn_dim",
    "freq_dim",
    "num_heads",
    "num_layers",
    "patch_size",
    "window_size",
    "qk_norm",
    "cross_attn_norm",
    "eps",
)

try:
    model = WanModel(**{key: config[key] for key in constructor_keys})
    param_total = sum(weight.numel() for weight in model.parameters())
    print(f" Model created: {param_total / 1e6:.0f}M parameters")
except Exception as build_error:
    print(f" FAILED creating model: {build_error}")
    import traceback

    traceback.print_exc()
    sys.exit(1)
|
|
| |
# Step 4: move the model to the MPS device as bfloat16 and switch it to eval
# mode; on any failure print the error and traceback, then exit with status 1.
print("\n[4] Moving model to MPS...")
try:
    model = model.to("mps", dtype=torch.bfloat16)
    model.eval()
except Exception as move_error:
    print(f" FAILED: {move_error}")
    import traceback

    traceback.print_exc()
    sys.exit(1)
else:
    print(" Model on MPS OK")
|
|
| |
# Step 5: build random inputs for one forward pass (latent, text context,
# timestep), all on MPS in bfloat16.
print("\n[5] Creating dummy inputs...")
batch_size = 1

# Trailing three dimensions of the random latent created below.
T, H, W = 9, 30, 52
patch_size = config["patch_size"]
dim = config["dim"]

# Channel dimension of the latent; presumably matches the model's expected
# input channels -- TODO confirm against WanModel.
in_channels = 16

# Number of patches along each axis and their product. These are derived for
# reference only; nothing in the visible part of this script reads them.
latent_T, latent_H, latent_W = (
    extent // patch for extent, patch in zip((T, H, W), patch_size)
)
seq_len = latent_T * latent_H * latent_W

latent_shape = (batch_size, in_channels, T, H, W)
x_5d = torch.randn(latent_shape, device="mps", dtype=torch.bfloat16)
print(f" 5D Latent shape: {x_5d.shape}")

# The model is invoked with a list of latents, so wrap the single tensor.
x_list_input = [x_5d]

# If the model exposes a `cache` attribute, reset it before the run.
if hasattr(model, "cache"):
    model.cache = None

# Last dimension of the random context; presumably the text-embedding width
# the model expects -- TODO confirm against WanModel.
text_encoder_dim = 4096
text_len = config["text_len"]

context_shape = (batch_size, text_len, text_encoder_dim)
context = torch.randn(context_shape, device="mps", dtype=torch.bfloat16)

# One (text_len, text_encoder_dim) tensor per batch element.
context_list = list(context.unbind(0))
print(f" Context list: {len(context_list)} x {context_list[0].shape}")

# One timestep value (500) per batch element.
timestep_values = [500] * batch_size
t = torch.tensor(timestep_values, device="mps", dtype=torch.bfloat16)
print(f" Timestep: {t}")
|
|
| |
# Step 6: run one forward pass without gradients under bfloat16 autocast, time
# it, and describe the result. Any exception is reported with its traceback
# and ends the script with status 1.
print("\n[6] Running forward pass...")

# torch.mps.synchronize() waits for queued MPS kernels to finish, so earlier
# work is not counted in the measurement.
torch.mps.synchronize()

# perf_counter() is the monotonic clock intended for measuring intervals.
start = time.perf_counter()
try:
    with torch.no_grad(), torch.autocast("mps", dtype=torch.bfloat16):
        output = model(x_list_input, t, context_list)

    # Wait for the forward-pass kernels to complete before stopping the clock.
    torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    print(f" Forward pass OK: output type = {type(output)}")
    # The success-path timing used to be computed and then discarded.
    print(f" Elapsed: {elapsed:.2f}s")

    # A list or a tuple of tensors is described element by element; anything
    # else is expected to expose .shape / .dtype itself.
    if isinstance(output, (list, tuple)):
        for j, o in enumerate(output):
            print(f" Output[{j}] shape = {o.shape}, dtype = {o.dtype}")
            print(f" Output[{j}] range: [{o.min():.3f}, {o.max():.3f}]")
    else:
        print(f" Output shape = {output.shape}, dtype = {output.dtype}")
except Exception as e:
    elapsed = time.perf_counter() - start
    print(f" FAILED after {elapsed:.2f}s: {e}")
    import traceback

    traceback.print_exc()
    sys.exit(1)
|
|
| |
# Step 7: print the two MPS memory counters, converted from bytes by dividing
# by 1024**3, then the closing banner.
print("\n[7] Memory stats...")
bytes_per_unit = 1024**3

allocated = torch.mps.current_allocated_memory() / bytes_per_unit
print(f" MPS memory allocated: {allocated:.2f}GB")

driver_allocated = torch.mps.driver_allocated_memory() / bytes_per_unit
print(f" MPS memory driver allocated: {driver_allocated:.2f}GB")

closing_rule = "=" * 60
print("\n" + closing_rule)
print("MPS Forward Pass Test PASSED!")
print(closing_rule)
|
|