"""Apple Silicon (MPS) compatibility patch for Wan2GP.

Import this module early at startup, before any mmgp import. When the MPS
backend is available it patches torch.cuda functions to redirect to MPS,
turns torch.compile into a no-op (PyTorch on macOS is built without CUDA),
and adds stub attributes for CUDA-only code paths.
"""
| import os |
| import sys |
| import types |
|
|
# Environment defaults applied at import time. Values the user has already
# exported take precedence, because only missing keys are filled in.
# NOTE(review): PyTorch's documented switch for disabling compilation is
# TORCH_COMPILE_DISABLE. TORCH_COMPILE / TORCHINDUCTOR may only be read by
# application code. torch.compile is also stubbed out in apply_mps_patch();
# confirm who consumes these two names.
for _env_key, _env_value in (
    ('TORCH_COMPILE', '0'),
    ('TORCHINDUCTOR', '0'),
    ('PYTORCH_ENABLE_MPS_FALLBACK', '1'),
):
    os.environ.setdefault(_env_key, _env_value)
del _env_key, _env_value
|
|
def apply_mps_patch():
    """Patch ``torch`` so CUDA-oriented code runs on Apple's MPS backend.

    Every change is global and is never undone:

    * ``torch.cuda``:
      - availability and initialisation queries are stubbed;
      - device selection, capability, property and memory queries return
        fixed values;
      - streams, events, CUDA graphs, RNG seeding and ``torch.cuda.amp``
        are replaced with MPS-backed calls or inert stubs;
      - ``is_available()`` reports ``False``.
    * ``torch.manual_seed`` additionally seeds the MPS generator.
    * ``torch.compile`` returns its target unchanged.
    * ``torch.autocast`` / ``torch.amp.autocast`` map ``"cuda"`` (or an
      omitted device type) to ``"mps"``.
    * ``"cuda"`` device arguments are rewritten to ``"mps"`` in:
      ``Tensor.to`` / ``Module.to``, ``Tensor.cuda`` / ``Module.cuda``,
      ``torch.load`` ``map_location``, ``torch.Generator``, and common
      tensor factory functions.
    * ``Tensor.pin_memory`` becomes a no-op.

    Returns:
        bool: ``True`` once every patch has been installed.
    """
    import torch as _torch

    # ------------------------------------------------------------------
    # Hardware description used by the capability / property stubs.
    # ------------------------------------------------------------------
    chip_name = _get_chip_name()
    system_ram_gb = _get_system_memory_gb()
    # System RAM stands in for device memory in every memory stub below.
    total_memory_bytes = int(system_ram_gb * 1024 ** 3)

    # Chip names containing M1/M2 report capability (7, 0) and no bf16.
    # Every other name, including the "Unknown Apple Silicon" fallback,
    # reports (11, 0) with bf16.
    if 'M1' in chip_name or 'M2' in chip_name:
        dev_cap = (7, 0)
        bfloat16_supported = False
    else:
        dev_cap = (11, 0)
        bfloat16_supported = True

    print(f"[MPS Patch] Detected: {chip_name}, {system_ram_gb:.0f}GB RAM")
    print(f"[MPS Patch] Device capability: {dev_cap}, BF16: {bfloat16_supported}")

    # ------------------------------------------------------------------
    # Stand-ins for CUDA runtime objects.
    # ------------------------------------------------------------------
    # Shared object returned by current_stream() / default_stream().
    # Its synchronize() really waits for queued MPS work.
    _dummy_stream = types.SimpleNamespace(
        synchronize=_torch.mps.synchronize,
        wait_stream=lambda *a, **kw: None,
        query=lambda: True,
        priority=0,
    )

    class _CudaDeviceProperties:
        """Minimal object returned by ``torch.cuda.get_device_properties``."""

        multi_processor_count = 0
        warp_size = 32

        def __init__(self):
            self.total_memory = total_memory_bytes
            self.name = chip_name
            self.major = dev_cap[0]
            self.minor = dev_cap[1]
            # The most recently built instance is also published on
            # torch.cuda (presumably read elsewhere; confirm).
            _torch.cuda._device_props_cache = self

    class _DummyEvent:
        """No-op ``torch.cuda.Event``: records nothing and is always complete."""

        def __init__(self, *a, **kw):
            pass

        def record(self, *a, **kw):
            pass

        def elapsed_time(self, *a, **kw):
            return 0.0

        def synchronize(self, *a, **kw):
            pass

        def query(self):
            return True

    class _DummyDeviceContext:
        """No-op ``torch.cuda.device(...)`` context manager."""

        def __init__(self, device=None):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

    class _DummyStreamContext:
        """No-op context manager returned by ``torch.cuda.stream(s)``."""

        def __init__(self, s):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *a):
            pass

    class _DummyGraph:
        """No-op ``torch.cuda.CUDAGraph`` and ``torch.cuda.graph(...)`` object.

        Implements the context-manager protocol, so ``with torch.cuda.graph(g):``
        simply runs its body eagerly instead of raising ``AttributeError``.
        """

        def __init__(self, *a, **kw):
            pass

        def replay(self, *a, **kw):
            return None

        def capture_begin(self, *a, **kw):
            return None

        def capture_end(self, *a, **kw):
            return None

        def __enter__(self):
            return self

        def __exit__(self, *a):
            return None

    # ------------------------------------------------------------------
    # torch.cuda.amp replacements.
    # ------------------------------------------------------------------
    class _MpsAutocast:
        """``torch.cuda.amp.autocast`` look-alike that autocasts on MPS.

        ``device_type`` is accepted for signature compatibility and ignored.
        """

        def __init__(self, enabled=True, dtype=None, device_type='mps', cache_enabled=None):
            self._autocast = _torch.autocast(
                'mps', enabled=enabled, dtype=dtype, cache_enabled=cache_enabled,
            )

        def __enter__(self):
            return self._autocast.__enter__()

        def __exit__(self, *a):
            return self._autocast.__exit__(*a)

        def __call__(self, func):
            # Decorator form (``@torch.cuda.amp.autocast()``), delegated to
            # torch.autocast's own decorator support.
            return self._autocast(func)

    class _autocast_mode_mod:
        autocast = _MpsAutocast

    class _amp_common:
        @staticmethod
        def amp_definitely_not_available():
            return True

    class _PatchedAMP:
        """Installed as ``torch.cuda.amp`` if the real module cannot be imported.

        Its ``GradScaler`` is installed in either case.
        """

        autocast = _MpsAutocast
        autocast_mode = _autocast_mode_mod
        common = _amp_common

        class GradScaler:
            """Acts like a torch GradScaler constructed with ``enabled=False``.

            It applies no loss scaling and performs no inf/NaN checks, but
            ``step`` still runs the optimizer, so AMP training loops keep
            updating their parameters.
            """

            def __init__(self, *a, **kw):
                pass

            def scale(self, outputs):
                return outputs

            def step(self, optimizer=None, *args, **kwargs):
                # A disabled scaler forwards to optimizer.step() and returns
                # its result; skipping this call would silently drop updates.
                if optimizer is None:
                    return None
                return optimizer.step(*args, **kwargs)

            def update(self, *a, **kw):
                pass

            def unscale_(self, *a, **kw):
                pass

            def get_scale(self):
                return 1.0

            def is_enabled(self):
                return False

            def state_dict(self):
                return {}

            def load_state_dict(self, *a):
                pass

    # ------------------------------------------------------------------
    # torch.cuda: redirect to MPS where a counterpart exists, else stub.
    # ------------------------------------------------------------------
    _cuda = _torch.cuda

    _cuda.is_available = lambda: False
    _cuda._is_compiled = lambda: False
    _cuda.empty_cache = _torch.mps.empty_cache
    _cuda.synchronize = _torch.mps.synchronize
    _cuda.get_device_capability = lambda device=None: dev_cap
    _cuda.manual_seed_all = lambda seed: None
    _cuda.manual_seed = lambda device_or_seed, seed=None: None
    _cuda.current_stream = lambda device=None: _dummy_stream
    _cuda.get_device_properties = lambda device=None: _CudaDeviceProperties()
    _cuda._CudaDeviceProperties = _CudaDeviceProperties
    _cuda.default_stream = lambda device=None: _dummy_stream
    _cuda.set_device = lambda device: None
    _cuda.current_device = lambda: 0
    _cuda.device_count = lambda: 1
    _cuda.ipc_collect = lambda: None
    _cuda.device = _DummyDeviceContext

    class _PatchedStream:
        """No-op ``torch.cuda.Stream``."""

        priority = 0

        def __init__(self, *a, **kw):
            pass

        def synchronize(self, *a, **kw):
            pass

        def wait_stream(self, *a, **kw):
            pass

        def query(self):
            return True

    _cuda.Stream = _PatchedStream
    _cuda.stream = lambda s: _DummyStreamContext(s)
    _cuda.Event = _DummyEvent
    _cuda.is_bf16_supported = lambda device=None: bfloat16_supported
    _cuda.bfloat16_supported = lambda device=None: bfloat16_supported
    _cuda.is_current_stream_capturing = lambda: False
    _cuda.graph = lambda *a, **kw: _DummyGraph()
    _cuda.CUDAGraph = _DummyGraph
    _cuda.graph_pool_handle = lambda: None
    # Memory queries: all system RAM is reported as free; nothing allocated.
    _cuda.mem_get_info = lambda device=None: (total_memory_bytes, total_memory_bytes)
    _cuda.memory_allocated = lambda device=None: 0
    _cuda.memory_reserved = lambda device=None: 0
    _cuda.max_memory_allocated = lambda device=None: 0
    _cuda.max_memory_reserved = lambda device=None: 0
    _cuda.reset_peak_memory_stats = lambda device=None: None
    _cuda.memory_stats = lambda device=None: {}
    try:
        import torch.cuda.amp as _cuda_amp
        _cuda_amp.autocast = _MpsAutocast
        _cuda_amp.GradScaler = _PatchedAMP.GradScaler
        _cuda.amp = _cuda_amp
    except Exception:
        _cuda.amp = _PatchedAMP
    _cuda.is_initialized = lambda: True
    _cuda._lazy_init = lambda: None

    # ------------------------------------------------------------------
    # RNG: seed MPS as well, and return the seeded default generator.
    # ------------------------------------------------------------------
    _orig_manual_seed = _torch.manual_seed

    def _mps_manual_seed(seed):
        """Run the original ``torch.manual_seed``, then seed the MPS generator.

        Returns the generator the original call returned, so that
        ``g = torch.manual_seed(s)`` yields a generator actually seeded
        with ``s``.
        """
        seed = int(seed)
        generator = _orig_manual_seed(seed)
        _torch.mps.manual_seed(seed)
        return generator

    _torch.manual_seed = _mps_manual_seed

    # ------------------------------------------------------------------
    # torch.compile: identity, in both call and decorator form.
    # ------------------------------------------------------------------
    def _patched_compile(fn=None, *args, **kwargs):
        """Return ``fn`` unchanged, or an identity decorator when ``fn`` is omitted."""
        if fn is not None:
            return fn

        def decorator(f):
            return f

        return decorator

    _torch.compile = _patched_compile

    # ------------------------------------------------------------------
    # torch.autocast: target MPS instead of CUDA.
    # ------------------------------------------------------------------
    _orig_autocast = _torch.autocast

    def _patched_autocast(device_type=None, *args, **kwargs):
        """Call the original autocast with ``"cuda"`` / ``None`` mapped to ``"mps"``.

        A ``device_type=`` keyword binds to the named parameter and never
        reaches ``**kwargs``, so checking the parameter covers both call styles.
        """
        if device_type is None or device_type == 'cuda':
            device_type = 'mps'
        return _orig_autocast(device_type, *args, **kwargs)

    _torch.autocast = _patched_autocast
    # Patch the torch.amp alias too.
    _torch.amp.autocast = _patched_autocast

    # Best effort: these dynamo config attributes may be missing in some builds.
    try:
        _torch._dynamo.config.suppress_errors = True
        _torch._dynamo.config.cache_size_limit = 128
    except Exception:
        pass

    # ------------------------------------------------------------------
    # Device redirection: "cuda" -> "mps".
    # ------------------------------------------------------------------
    def _patched_tensor_cuda(self, device=None, *args, **kwargs):
        return self.to("mps")

    _torch.Tensor.cuda = _patched_tensor_cuda

    def _patched_module_cuda(self, device=None):
        return self.to("mps")

    _torch.nn.Module.cuda = _patched_module_cuda

    def _replace_cuda_device(val):
        """Map a CUDA device string / ``torch.device`` to MPS; return anything else as is."""
        if isinstance(val, str) and val.startswith("cuda"):
            return "mps"
        if isinstance(val, _torch.device) and val.type == "cuda":
            return _torch.device("mps")
        return val

    def _replace_map_location(map_location):
        """Redirect a ``torch.load`` map_location; for dicts only the values are mapped."""
        if isinstance(map_location, dict):
            return {key: _replace_cuda_device(value) for key, value in map_location.items()}
        return _replace_cuda_device(map_location)

    _orig_tensor_to = _torch.Tensor.to

    def _patched_tensor_to(self, *args, **kwargs):
        new_args = [_replace_cuda_device(a) for a in args]
        if "device" in kwargs:
            kwargs["device"] = _replace_cuda_device(kwargs["device"])
        return _orig_tensor_to(self, *new_args, **kwargs)

    _torch.Tensor.to = _patched_tensor_to

    _orig_module_to = _torch.nn.Module.to

    def _patched_module_to(self, *args, **kwargs):
        new_args = [_replace_cuda_device(a) for a in args]
        if "device" in kwargs:
            kwargs["device"] = _replace_cuda_device(kwargs["device"])
        return _orig_module_to(self, *new_args, **kwargs)

    _torch.nn.Module.to = _patched_module_to

    def _patched_pin_memory(self, *args, **kwargs):
        return self

    _torch.Tensor.pin_memory = _patched_pin_memory

    _orig_load = _torch.load

    def _patched_load(*args, **kwargs):
        # map_location may arrive by keyword or as the second positional argument.
        if "map_location" in kwargs:
            kwargs["map_location"] = _replace_map_location(kwargs["map_location"])
        elif len(args) >= 2:
            args = (args[0], _replace_map_location(args[1]), *args[2:])
        return _orig_load(*args, **kwargs)

    _torch.load = _patched_load

    _Gen = _torch.Generator

    class _PatchedGen(_Gen):
        """``torch.Generator`` whose ``device`` argument is redirected from CUDA to MPS."""

        def __new__(cls, device=None):
            device = _replace_cuda_device(device)
            if device:
                return super().__new__(cls, device=device)
            return super().__new__(cls)

    _torch.Generator = _PatchedGen

    # Tensor factories: rewrite a CUDA ``device=`` keyword and drop ``pin_memory``.
    factory_names = (
        'zeros', 'ones', 'randn', 'rand', 'tensor', 'arange',
        'linspace', 'empty', 'full', 'eye', 'zeros_like', 'ones_like',
        'randn_like', 'rand_like', 'empty_like', 'full_like',
        'as_tensor', 'from_numpy',
    )

    def _redirect_factory(original):
        """Wrap a factory so a CUDA ``device=`` becomes MPS and ``pin_memory=`` is removed."""
        def patched(*args, **kwargs):
            dev = kwargs.get('device')
            new_dev = _replace_cuda_device(dev)
            if new_dev is not dev:
                kwargs['device'] = new_dev
            kwargs.pop('pin_memory', None)
            return original(*args, **kwargs)
        return patched

    for fn_name in factory_names:
        if hasattr(_torch, fn_name):
            setattr(_torch, fn_name, _redirect_factory(getattr(_torch, fn_name)))

    print("[MPS Patch] Applied successfully")
    print(f"[MPS Patch] BF16 supported: {bfloat16_supported}")
    print(f"[MPS Patch] Available system RAM: {system_ram_gb:.0f}GB")

    # Expose ``.weight`` on Parameter itself, returning the Parameter
    # (presumably for callers that treat a bare Parameter like a module; confirm).
    if not hasattr(_torch.nn.Parameter, "weight"):
        _torch.nn.Parameter.weight = property(lambda self: self)

    return True
|
|
def _get_chip_name():
    """Return the chip name reported by ``system_profiler SPDisplaysDataType``.

    Scans the output for the first line that contains ``Chip`` and has a
    non-empty value after a colon, and returns that value stripped of
    whitespace. Lines without a colon or with an empty value are skipped
    rather than aborting the scan.

    Returns:
        str: the detected name, or ``"Unknown Apple Silicon"`` when the tool
        is missing, fails, times out, or reports nothing usable.
    """
    import subprocess

    try:
        out = subprocess.check_output(
            ['system_profiler', 'SPDisplaysDataType'],
            encoding='utf-8',
            stderr=subprocess.DEVNULL,
            # Don't let a hung system_profiler block startup indefinitely.
            timeout=30,
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        return "Unknown Apple Silicon"

    for line in out.splitlines():
        if 'Chip' not in line:
            continue
        _, sep, value = line.partition(':')
        value = value.strip()
        if sep and value:
            return value
    return "Unknown Apple Silicon"
|
|
def _get_system_memory_gb():
    """Return total physical memory in GiB, as reported by ``sysctl -n hw.memsize``.

    Returns:
        float: memory size in GiB. Falls back to ``16.0`` when ``sysctl`` is
        unavailable, exits with an error, times out, or prints something that
        is not an integer byte count.
    """
    import subprocess

    try:
        out = subprocess.check_output(
            ['sysctl', '-n', 'hw.memsize'],
            encoding='utf-8',
            # Like _get_chip_name: keep tool errors off the console.
            stderr=subprocess.DEVNULL,
            timeout=10,
        )
        return int(out.strip()) / (1024 ** 3)
    except (OSError, subprocess.SubprocessError, ValueError):
        return 16.0
|
|
import torch as _torch

# True only on macOS when this PyTorch build exposes a usable MPS backend.
# It gates the SDPA override and the apply_mps_patch() call further down.
_is_mps = (
    sys.platform == 'darwin'
    and hasattr(_torch.backends, 'mps')
    and _torch.backends.mps.is_available()
)
|
|
# When torch._C provides no _cuda_getDefaultStream, install a stub that
# always returns the 3-tuple (0, device_index, 0). The presumed meaning is
# (stream id, device index, device type); confirm against callers.
_C = _torch._C
if not hasattr(_C, '_cuda_getDefaultStream'):
    def _cuda_getDefaultStream_stub(device_index=0):
        """Stand-in for ``torch._C._cuda_getDefaultStream`` on builds without CUDA."""
        return 0, device_index, 0

    setattr(_C, '_cuda_getDefaultStream', _cuda_getDefaultStream_stub)
|
|
| |
| if not hasattr(_torch.mps, 'current_device'): |
| _torch.mps.current_device = lambda: 0 |
| if not hasattr(_torch.mps, 'device_count'): |
| _torch.mps.device_count = lambda: 1 |
| if not hasattr(_torch.mps, 'set_device'): |
| _torch.mps.set_device = lambda device: None |
|
|
# On MPS, route every scaled_dot_product_attention call through the MATH
# backend (presumably to avoid the fused Metal kernel; confirm). A
# hand-written matmul/softmax implementation serves as a last resort.
if _is_mps:
    _orig_sdpa = _torch.nn.functional.scaled_dot_product_attention
    _sdpa_call_count = [0]

    def _manual_sdpa_fallback(query, key, value, attn_mask=None, dropout_p=0.0,
                              is_causal=False, scale=None, enable_gqa=False):
        """Manual attention (matmul + softmax) that avoids the Metal SDPA kernel.

        Takes the same parameters, in the same order, as
        ``torch.nn.functional.scaled_dot_product_attention``. This lets
        ``_patched_sdpa`` forward its ``*args, **kwargs`` unchanged.

        * ``attn_mask``: a boolean mask marks positions allowed to attend
          (``True`` = keep). Any other dtype is added to the scores.
        * ``is_causal``: masks the strict upper triangle of the
          (query_len, key_len) score matrix.
        * ``enable_gqa``: repeats key/value heads (dim -3) to match the
          query heads.
        * ``dropout_p``: applied to the attention weights when > 0.
        """
        q_len = query.size(-2)
        kv_len = key.size(-2)
        if scale is None:
            scale = query.size(-1) ** -0.5

        if enable_gqa:
            groups = query.size(-3) // key.size(-3)
            key = key.repeat_interleave(groups, dim=-3)
            value = value.repeat_interleave(groups, dim=-3)

        attn_weights = _torch.matmul(query, key.transpose(-2, -1)) * scale
        if attn_mask is not None:
            if attn_mask.dtype == _torch.bool:
                attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))
            else:
                attn_weights = attn_weights + attn_mask
        if is_causal:
            # (q_len, kv_len), not (q_len, q_len): the lengths differ when
            # queries attend over a longer key sequence.
            causal_mask = _torch.ones(
                q_len, kv_len, device=query.device, dtype=_torch.bool
            ).triu(diagonal=1)
            attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
        attn_weights = _torch.nn.functional.softmax(attn_weights, dim=-1)
        if dropout_p > 0.0:
            attn_weights = _torch.dropout(attn_weights, dropout_p, train=True)
        return _torch.matmul(attn_weights, value)

    def _patched_sdpa(*args, **kwargs):
        """SDPA entry point installed on MPS.

        Each call:
        1. synchronizes MPS;
        2. empties the MPS cache every 256 calls;
        3. runs the original SDPA with only the MATH backend enabled;
        4. if that raises, retries with ``_manual_sdpa_fallback``.
        """
        _torch.mps.synchronize()

        _sdpa_call_count[0] += 1
        if _sdpa_call_count[0] % 256 == 0:
            _torch.mps.empty_cache()

        try:
            with _torch.nn.attention.sdpa_kernel([_torch.nn.attention.SDPBackend.MATH]):
                return _orig_sdpa(*args, **kwargs)
        except Exception:
            return _manual_sdpa_fallback(*args, **kwargs)

    _torch.nn.functional.scaled_dot_product_attention = _patched_sdpa
|
|
# Install the torch patches on import when running on MPS. This is best
# effort: a failure is reported with a traceback, and importing continues.
if _is_mps:
    try:
        apply_mps_patch()
    except Exception as exc:
        import traceback
        print(f"[MPS Patch] Failed to apply patch: {exc}")
        traceback.print_exc()
|
|