#!/usr/bin/env python3
"""Test the full generation path for Wan2GP on MPS."""
import os
import sys
import traceback
# Make the repository root (two directories above this script) both the first
# import location and the working directory, so the project imports below and
# any relative paths resolve no matter where the script is launched from.
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
os.chdir(REPO_ROOT)
# Keep an ordered log of every diagnostic step and echo each step as it runs.
class CallTracker:
    """Record diagnostic messages in order and echo each one to stdout."""

    def __init__(self) -> None:
        # Every message passed to add(), in call order.
        self.calls: list[str] = []

    def add(self, msg: str) -> None:
        """Append *msg* to ``calls`` and print it with a ``[TRACKER]`` prefix.

        The print is flushed immediately for two reasons. Failed checks write
        their tracebacks to stderr, and flushing keeps the two streams in order.
        A hard crash inside a native backend would also discard block-buffered
        stdout, hiding which step was running when it happened.
        """
        self.calls.append(msg)
        print(f"[TRACKER] {msg}", flush=True)


tracker = CallTracker()
# Import torch first (before the MPS patch) and record the build/runtime
# facts that the later checks depend on.
tracker.add("Importing torch")
import torch

tracker.add(f"torch version: {torch.__version__}")
# torch.cuda exposes no public `is_compiled` (only the private
# `_is_compiled`), so probing for it always yields 'N/A'.
# torch.backends.cuda.is_built() reports whether this binary was built with
# CUDA support. It is looked up defensively in case the installed torch lacks it.
_cuda_is_built = getattr(torch.backends.cuda, "is_built", None)
tracker.add(
    f"cuda compiled: {_cuda_is_built() if _cuda_is_built is not None else 'N/A'}"
)
tracker.add(f"mps built: {torch.backends.mps.is_built()}")
tracker.add(f"mps available: {torch.backends.mps.is_available()}")
# Install the project's MPS device patch before the model module is imported.
# The patch is deliberately not wrapped in try/except: if it fails, the run
# stops here. Afterwards, check whether torch.compile was swapped for the
# patch's wrapper, identified by its function name '_patched_compile'.
tracker.add("Applying MPS patch")
from shared.mps.device_patch import apply_mps_patch

apply_mps_patch()
tracker.add(
    "torch.compile patched: "
    f"{torch.compile.__name__ == '_patched_compile'}"
)
# Import the model module that the generation path uses. This is wrapped like
# the other checks so an import failure is recorded (with its traceback) and
# the remaining checks still run; none of them reference OviFusionEngine.
tracker.add("Importing wan model module")
try:
    from models.wan.ovi_fusion_engine import OviFusionEngine
    tracker.add("OviFusionEngine imported successfully")
except Exception as e:
    tracker.add(f"OviFusionEngine import FAILED: {type(e).__name__}: {e}")
    traceback.print_exc()
# Check bfloat16 autocast on the MPS device. Allocating a tensor alone does
# not exercise autocast, so also run a matmul and report the dtype it returns.
# If autocast is active it should be torch.bfloat16 (NOTE(review): confirm
# matmul is on the MPS autocast lower-precision list for the installed torch).
# A float32 result means autocast was silently disabled.
tracker.add("Testing torch.autocast('mps')")
try:
    with torch.autocast('mps', enabled=True, dtype=torch.bfloat16):
        x = torch.randn(2, 2, device='mps')
        y = x @ x
        tracker.add(f"torch.autocast('mps') OK (matmul dtype: {y.dtype})")
except Exception as e:
    tracker.add(f"torch.autocast('mps') FAILED: {type(e).__name__}: {e}")
    traceback.print_exc()
# The generate function's autocast line requests device_type='cuda'. Check
# that it does not raise here, and record any warnings it emits so the log
# shows what actually happened rather than assuming a warning was issued.
# Recorded warnings are not printed by Python, so they are echoed via tracker.
import warnings

tracker.add("Testing autocast('cuda') in model code")
try:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
            pass
    if caught:
        tracker.add(
            f"autocast('cuda') succeeded (with {len(caught)} warning(s))"
        )
        for w in caught:
            tracker.add(f"  warning: {w.category.__name__}: {w.message}")
    else:
        tracker.add("autocast('cuda') succeeded (no warnings)")
except Exception as e:
    tracker.add(f"autocast('cuda') FAILED: {type(e).__name__}: {e}")
    traceback.print_exc()
# Compile a trivial function with the inductor backend (the call that
# originally failed) and run it once on a small tensor created on the default
# device.
tracker.add("Testing torch.compile(inductor)")


def test_fn(x):
    """Return ``x + 1``: a minimal target for torch.compile."""
    # NOTE(review): the ``test_`` prefix means pytest would try to collect
    # this as a test (and error on the missing ``x`` fixture) if this script
    # is ever picked up by test discovery.
    return x + 1


try:
    compiled = torch.compile(test_fn, backend='inductor')
    result = compiled(
        torch.tensor([1.0])
    )
    tracker.add(f"torch.compile(inductor) OK: {result}")
except Exception as e:
    tracker.add(
        f"torch.compile(inductor) FAILED: {type(e).__name__}: {e}"
    )
    traceback.print_exc()
# Model code still goes through the legacy torch.cuda.amp.autocast entry
# point. Make sure importing it and entering/exiting it with enabled=False
# does not raise. NOTE(review): recent PyTorch releases deprecate this alias
# in favour of torch.amp.autocast('cuda'); confirm for the installed version.
tracker.add("Testing torch.cuda.amp.autocast")
try:
    from torch.cuda import amp

    with amp.autocast(enabled=False):
        pass
    tracker.add("torch.cuda.amp.autocast OK")
except Exception as e:
    tracker.add(
        f"torch.cuda.amp.autocast FAILED: {type(e).__name__}: {e}"
    )
    traceback.print_exc()
# Summarise the run. Each failed check records exactly one message containing
# " FAILED: ", so counting those gives the number of failed checks.
# NOTE(review): the script still exits 0 when checks fail; consider
# sys.exit(1) here if it is run from CI.
failed_checks = [msg for msg in tracker.calls if " FAILED: " in msg]
tracker.add("All tests completed")
tracker.add(
    f"{len(failed_checks)} check(s) failed" if failed_checks else "No checks failed"
)