# File size: 16,193 Bytes
# 7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
"""Apple Silicon MPS compatibility patch for Wan2GP.

Import EARLY at startup — BEFORE any mmgp import. Patches torch.cuda functions
to redirect to MPS, disables torch.compile (CUDA-less PyTorch build), and adds
stub attributes for CUDA-only code paths.
"""
import os
import sys
import types

# CRITICAL: Disable torch.compile / dynamo before torch is imported.
# PyTorch 2.11 on macOS is built with USE_CUDA=OFF. torch.compile traces into
# functions like torch.manual_seed, follows the CUDA call chain, and hits
# C++-level "not linked with cuda" errors that Python patches cannot intercept.
#
# TORCH_COMPILE_DISABLE is the switch torch._dynamo reads when it is imported.
# It is only defaulted on macOS so that importing this module on another
# platform leaves torch.compile untouched. TORCH_COMPILE / TORCHINDUCTOR are
# kept for anything that still inspects them.
# setdefault() is used throughout so an explicit user setting always wins.
os.environ.setdefault('TORCH_COMPILE', '0')
os.environ.setdefault('TORCHINDUCTOR', '0')
if sys.platform == 'darwin':
    os.environ.setdefault('TORCH_COMPILE_DISABLE', '1')
# Let ops without an MPS kernel fall back to the CPU instead of raising.
os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1')

def apply_mps_patch():
    """Monkey-patch torch so CUDA-oriented code runs on the MPS backend.

    The patch replaces a set of ``torch.cuda`` entry points with MPS-backed or
    no-op stand-ins, turns ``torch.compile`` into a pass-through, rewrites
    ``"cuda"`` device arguments to ``"mps"`` in ``Tensor.to`` / ``Module.to`` /
    ``torch.load`` / ``torch.Generator`` / the tensor factory functions, and
    redirects ``torch.autocast('cuda', ...)`` to ``'mps'``.

    Calling the function more than once is safe: after the first successful
    run later calls return immediately, so wrappers are never stacked on top of
    each other.

    Returns:
        True once the patch is in place (including when it was already applied).
    """
    import torch as _torch

    # Re-applying would wrap the already-wrapped Tensor.to, torch.load, factory
    # functions, Generator, ... a second time, so bail out early.
    if getattr(_torch, '_mps_cuda_patch_applied', False):
        return True

    chip_name = _get_chip_name()
    system_ram_gb = _get_system_memory_gb()
    total_memory_bytes = int(system_ram_gb * 1024 ** 3)

    # M1/M2 are reported with a lower fake compute capability and without BF16;
    # every other chip name (including the "unknown" fallback) gets BF16.
    if 'M1' in chip_name or 'M2' in chip_name:
        dev_cap = (7, 0)
        bfloat16_supported = False
    else:
        dev_cap = (11, 0)
        bfloat16_supported = True

    print(f"[MPS Patch] Detected: {chip_name}, {system_ram_gb:.0f}GB RAM")
    print(f"[MPS Patch] Device capability: {dev_cap}, BF16: {bfloat16_supported}")

    # Dummy objects
    _dummy_stream = types.SimpleNamespace(
        synchronize=_torch.mps.synchronize,
        wait_stream=lambda *a, **kw: None,
        query=lambda: True,
        priority=0,
    )

    class _CudaDeviceProperties:
        """Stand-in for the object returned by torch.cuda.get_device_properties."""

        def __init__(self):
            self.total_memory = total_memory_bytes
            self.name = chip_name
            self.major = dev_cap[0]
            self.minor = dev_cap[1]
            _torch.cuda._device_props_cache = self
        multi_processor_count = 0
        warp_size = 32

    class _DummyEvent:
        """No-op replacement for torch.cuda.Event."""

        def __init__(self, *a, **kw): pass
        def record(self, *a, **kw): pass
        def elapsed_time(self, *a, **kw): return 0.0
        def synchronize(self, *a, **kw): pass
        def query(self): return True

    class _DummyDeviceContext:
        """No-op replacement for the torch.cuda.device context manager."""

        def __init__(self, device=None): pass
        def __enter__(self): return self
        def __exit__(self, *args): pass

    class _DummyStreamContext:
        """No-op replacement for the torch.cuda.stream(...) context manager."""

        def __init__(self, s): pass
        def __enter__(self): return self
        def __exit__(self, *a): pass

    class _DummyGraph:
        """No-op replacement for torch.cuda.CUDAGraph and torch.cuda.graph(...).

        The same class backs both stubs, so it also implements the context
        manager protocol: ``with torch.cuda.graph(g): ...`` must not fail.
        """

        def replay(self, *a, **kw): return None
        def capture_begin(self, *a, **kw): return None
        def capture_end(self, *a, **kw): return None
        def __enter__(self): return self
        def __exit__(self, *a): return None

    # AMP stub
    class _MpsAutocast:
        """torch.cuda.amp.autocast replacement that always targets 'mps'."""

        def __init__(self, enabled=True, dtype=None, device_type='mps', cache_enabled=None):
            # _torch.autocast is looked up at call time, so once the patch below
            # is installed this goes through _patched_autocast.
            self._autocast = _torch.autocast('mps', enabled=enabled, dtype=dtype)
        def __enter__(self): return self._autocast.__enter__()
        def __exit__(self, *a): return self._autocast.__exit__(*a)

    class _autocast_mode_mod:
        autocast = _MpsAutocast

    class _amp_common:
        @staticmethod
        def amp_definitely_not_available():
            return True

    class _PatchedAMP:
        """Fallback namespace used when torch.cuda.amp cannot be imported."""

        autocast = _MpsAutocast
        autocast_mode = _autocast_mode_mod
        common = _amp_common
        class GradScaler:
            """GradScaler stub: scaling is disabled, every method is a no-op."""

            def __init__(self, *a, **kw): pass
            def step(self, *a, **kw): return a[0] if a else None
            def update(self, *a, **kw): pass
            def unscale_(self, *a, **kw): pass
            def get_scale(self): return 1.0
            def state_dict(self): return {}
            def load_state_dict(self, *a): pass

    _cuda = _torch.cuda

    # Core function patches
    _cuda.is_available = lambda: False
    _cuda._is_compiled = lambda: False
    _cuda.empty_cache = _torch.mps.empty_cache
    _cuda.synchronize = _torch.mps.synchronize
    _cuda.get_device_capability = lambda device=None: dev_cap
    _cuda.manual_seed_all = lambda seed: None
    _cuda.manual_seed = lambda device_or_seed, seed=None: None
    _cuda.current_stream = lambda device=None: _dummy_stream
    _cuda.get_device_properties = lambda device=None: _CudaDeviceProperties()
    _cuda._CudaDeviceProperties = _CudaDeviceProperties
    _cuda.default_stream = lambda device=None: _dummy_stream
    _cuda.set_device = lambda device: None
    _cuda.current_device = lambda: 0
    _cuda.device_count = lambda: 1
    _cuda.ipc_collect = lambda: None
    _cuda.device = _DummyDeviceContext

    class _PatchedStream:
        """No-op replacement for torch.cuda.Stream."""

        priority = 0
        def __init__(self, *a, **kw): pass
        def synchronize(self, *a, **kw): pass
        def wait_stream(self, *a, **kw): pass
        def query(self): return True

    _cuda.Stream = _PatchedStream
    _cuda.stream = lambda s: _DummyStreamContext(s)
    _cuda.Event = _DummyEvent
    _cuda.is_bf16_supported = lambda device=None: bfloat16_supported
    _cuda.bfloat16_supported = lambda device=None: bfloat16_supported
    _cuda.is_current_stream_capturing = lambda: False
    _cuda.graph = lambda *a, **kw: _DummyGraph()
    _cuda.CUDAGraph = _DummyGraph
    _cuda.graph_pool_handle = lambda: None
    # Reported as (free, total); both are the full system RAM figure.
    _cuda.mem_get_info = lambda device=None: (total_memory_bytes, total_memory_bytes)
    _cuda.memory_allocated = lambda device=None: 0
    _cuda.memory_reserved = lambda device=None: 0
    _cuda.max_memory_allocated = lambda device=None: 0
    _cuda.max_memory_reserved = lambda device=None: 0
    _cuda.reset_peak_memory_stats = lambda device=None: None
    _cuda.memory_stats = lambda device=None: {}
    try:
        import torch.cuda.amp as _cuda_amp
        _cuda_amp.autocast = _MpsAutocast
        _cuda_amp.GradScaler = _PatchedAMP.GradScaler
        _cuda.amp = _cuda_amp
    except Exception:
        # Best effort: fall back to the pure-Python stand-in namespace.
        _cuda.amp = _PatchedAMP
    _cuda.is_initialized = lambda: True
    _cuda._lazy_init = lambda: None

    # CRITICAL: Patch torch.manual_seed so the MPS generator is always seeded.
    # torch.manual_seed calls torch.cuda.manual_seed_all internally; that name is
    # already replaced with a no-op above, so the original can be called safely.
    _orig_manual_seed = _torch.manual_seed
    def _mps_manual_seed(seed):
        seed = int(seed)
        generator = _orig_manual_seed(seed)  # CPU seed
        _torch.mps.manual_seed(seed)  # MPS seed
        # Hand back what the original returned. Callers that do
        # ``g = torch.manual_seed(s)`` and pass ``generator=g`` must receive the
        # seeded generator, not a freshly constructed (unseeded) one.
        return generator
    _torch.manual_seed = _mps_manual_seed

    # CRITICAL: Replace torch.compile with a true no-op.
    # PyTorch 2.11 on macOS is built with USE_CUDA=OFF. Even the 'eager' backend
    # involves dynamo tracing which can trigger C++-level CUDA failures.
    # Simply return the original function unchanged.
    def _patched_compile(fn=None, *args, **kwargs):
        if fn is not None:
            return fn
        # Called as @torch.compile(...) with options only: return a decorator.
        def decorator(f):
            return f
        return decorator
    _torch.compile = _patched_compile

    # CRITICAL: Patch torch.autocast to redirect 'cuda' -> 'mps'
    # Code uses torch.autocast('cuda', ...) or torch.autocast(device_type='cuda', ...)
    # Both spellings bind to the ``device_type`` parameter below, so a single
    # check covers them.
    _orig_autocast = _torch.autocast
    def _patched_autocast(device_type=None, *args, **kwargs):
        # None covers torch.cuda.amp.autocast-style calls without a device type.
        if device_type is None or device_type == 'cuda':
            device_type = 'mps'
        return _orig_autocast(device_type, *args, **kwargs)
    _torch.autocast = _patched_autocast
    # Also patch torch.amp.autocast
    _torch.amp.autocast = _patched_autocast

    # Dynamo settings for any code path that still reaches it: suppress compile
    # errors instead of raising, and allow a larger recompile cache.
    try:
        _torch._dynamo.config.suppress_errors = True
        _torch._dynamo.config.cache_size_limit = 128
    except Exception:
        # Best effort only; torch.compile is already a pass-through above.
        pass

    # Tensor and Module .cuda() redirects
    def _patched_tensor_cuda(self, device=None, *args, **kwargs):
        return self.to("mps")
    _torch.Tensor.cuda = _patched_tensor_cuda

    def _patched_module_cuda(self, device=None):
        return self.to("mps")
    _torch.nn.Module.cuda = _patched_module_cuda

    def _replace_cuda_device(val):
        """Map a 'cuda...' string or cuda torch.device to MPS; pass others through."""
        if isinstance(val, str) and val.startswith("cuda"):
            return "mps"
        if isinstance(val, _torch.device) and val.type == "cuda":
            return _torch.device("mps")
        return val

    def _replace_map_location(map_location):
        """Apply _replace_cuda_device to a map_location value or dict of them."""
        if isinstance(map_location, dict):
            return {key: _replace_cuda_device(value) for key, value in map_location.items()}
        return _replace_cuda_device(map_location)

    # CRITICAL: Patch Tensor.to and Module.to to intercept cuda device strings.
    # This catches code that passes device="cuda" as a string to .to() calls.
    _orig_tensor_to = _torch.Tensor.to
    def _patched_tensor_to(self, *args, **kwargs):
        # Handle positional args: .to("cuda"), .to(device), .to(dtype, device)
        new_args = [_replace_cuda_device(a) for a in args]
        # Handle keyword device arg
        if "device" in kwargs:
            kwargs["device"] = _replace_cuda_device(kwargs["device"])
        return _orig_tensor_to(self, *new_args, **kwargs)
    _torch.Tensor.to = _patched_tensor_to

    _orig_module_to = _torch.nn.Module.to
    def _patched_module_to(self, *args, **kwargs):
        new_args = [_replace_cuda_device(a) for a in args]
        if "device" in kwargs:
            kwargs["device"] = _replace_cuda_device(kwargs["device"])
        return _orig_module_to(self, *new_args, **kwargs)
    _torch.nn.Module.to = _patched_module_to

    # pin_memory becomes a no-op that returns the tensor unchanged.
    def _patched_pin_memory(self, *args, **kwargs):
        return self
    _torch.Tensor.pin_memory = _patched_pin_memory

    _orig_load = _torch.load
    def _patched_load(*args, **kwargs):
        # map_location may arrive as a keyword or as the second positional arg.
        if "map_location" in kwargs:
            kwargs["map_location"] = _replace_map_location(kwargs["map_location"])
        elif len(args) >= 2:
            args = (args[0], _replace_map_location(args[1]), *args[2:])
        return _orig_load(*args, **kwargs)
    _torch.load = _patched_load

    # Generator patch
    _Gen = _torch.Generator
    class _PatchedGen(_Gen):
        """torch.Generator that accepts device='cuda' and builds an MPS generator."""

        def __new__(cls, device=None):
            device = _replace_cuda_device(device)
            if device:
                return super().__new__(cls, device=device)
            return super().__new__(cls)
    _torch.Generator = _PatchedGen

    # Tensor creation patch — redirect cuda->mps, fix pin_memory bug
    for fn_name in ['zeros', 'ones', 'randn', 'rand', 'tensor', 'arange',
                     'linspace', 'empty', 'full', 'eye', 'zeros_like', 'ones_like',
                     'randn_like', 'rand_like', 'empty_like', 'full_like',
                     'as_tensor', 'from_numpy']:
        if hasattr(_torch, fn_name):
            orig = getattr(_torch, fn_name)
            # make_patcher binds ``o`` per function so each wrapper calls its own
            # original rather than whatever ``orig`` is when the loop ends.
            def make_patcher(o):
                def patched(*args, **kwargs):
                    dev = kwargs.get('device')
                    new_dev = _replace_cuda_device(dev)
                    if new_dev is not dev:
                        kwargs['device'] = new_dev
                    # pin_memory is dropped entirely, whatever its value.
                    kwargs.pop('pin_memory', None)
                    return o(*args, **kwargs)
                return patched
            setattr(_torch, fn_name, make_patcher(orig))

    print("[MPS Patch] Applied successfully")
    print(f"[MPS Patch] BF16 supported: {bfloat16_supported}")
    print(f"[MPS Patch] Available system RAM: {system_ram_gb:.0f}GB")

    # Fix: Some Wan model loading paths call .weight on an nn.Parameter,
    # which is a Tensor subclass, not a Module. On MPS this fails because
    # nn.Parameter doesn't have a .weight attribute. Duck-type it to return self.
    # Reference: https://github.com/deepbeepmeep/Wan2GP/pull/1750#issuecomment-4387455446
    if not hasattr(_torch.nn.Parameter, "weight"):
        _torch.nn.Parameter.weight = property(lambda self: self)

    # Marker consulted at the top of this function on later calls.
    _torch._mps_cuda_patch_applied = True
    return True  # signal success

def _get_chip_name():
    """Return the chip name reported by ``system_profiler SPDisplaysDataType``.

    The report is scanned for the first line that contains ``Chip`` and has a
    non-empty value after its first colon; that value (stripped) is returned.
    A matching line without a colon or without a value is skipped, so it no
    longer ends the scan.

    Returns:
        The chip name, or ``"Unknown Apple Silicon"`` when the tool cannot be
        run, its output cannot be decoded, or no usable line is found.
    """
    import subprocess

    fallback = "Unknown Apple Silicon"
    try:
        out = subprocess.check_output(
            ['system_profiler', 'SPDisplaysDataType'],
            encoding='utf-8',
            stderr=subprocess.DEVNULL,
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        # Tool missing / not runnable, undecodable output, or non-zero exit.
        return fallback

    for line in out.splitlines():
        if 'Chip' not in line:
            continue
        _, colon, value = line.partition(':')
        value = value.strip()
        if colon and value:
            return value
    return fallback

def _get_system_memory_gb():
    """Return the value of ``sysctl -n hw.memsize`` converted to GiB.

    The command's output is parsed as an integer and divided by 1024**3.
    Any failure (command missing, non-zero exit, unparsable output) yields
    the default of 16.0.
    """
    bytes_per_gib = 1024 ** 3
    try:
        import subprocess
        raw = subprocess.check_output(['sysctl', '-n', 'hw.memsize'], encoding='utf-8')
        mem_bytes = int(raw.strip())
    except Exception:
        return 16.0
    return mem_bytes / bytes_per_gib

# Auto-apply on import if on macOS with MPS
import torch as _torch
_is_mps = sys.platform == 'darwin' and hasattr(_torch.backends, 'mps') and _torch.backends.mps.is_available()

# Patch torch._C missing C extension functions
_C = _torch._C
if not hasattr(_C, '_cuda_getDefaultStream'):
    def _cuda_getDefaultStream_stub(device_index=0):
        return (0, device_index, 0)
    _C._cuda_getDefaultStream = _cuda_getDefaultStream_stub

# Add missing torch.mps attributes
if not hasattr(_torch.mps, 'current_device'):
    _torch.mps.current_device = lambda: 0
if not hasattr(_torch.mps, 'device_count'):
    _torch.mps.device_count = lambda: 1
if not hasattr(_torch.mps, 'set_device'):
    _torch.mps.set_device = lambda device: None

# Force SDPA math backend on MPS to avoid Metal command buffer double-commit crash
# Reference: [IOGPUMetalCommandBuffer validate]:214: failed assertion `commit an already committed command buffer'
#
# Root cause: MPS fallback ops (CPU fallback from PYTORCH_ENABLE_MPS_FALLBACK=1)
# corrupt Metal command buffers when mixed with native MPS SDPA. This affects:
#   - Wan 2.2 5B (quanto mbf16 quantization)
#   - Wan 2.1 1.3B (standard safetensors, but ops still fallback)
#   - Any model where CPU-fallback ops precede an SDPA call
#
# Fix strategy (defense in depth):
#   1. Synchronize MPS before SDPA to flush pending fallback ops
#   2. Force MATH backend (avoids MPS-native SDPA bugs on some macOS versions)
#   3. If SDPA fails with Metal error, fall back to manual attention (matmul + softmax)
#   4. Periodic empty_cache to prevent memory fragmentation
if _is_mps:
    _orig_sdpa = _torch.nn.functional.scaled_dot_product_attention
    # One-element list so the nested function can mutate the counter in place.
    _sdpa_call_count = [0]

    def _manual_sdpa_fallback(query, key, value, attn_mask=None, dropout_p=0.0,
                              is_causal=False, scale=None, enable_gqa=False):
        """Manual attention fallback: matmul + softmax, no Metal SDPA.

        The parameter list mirrors ``scaled_dot_product_attention`` so that the
        wrapper can forward its ``*args, **kwargs`` unchanged, whether the
        caller used positional or keyword arguments.

        Args:
            query, key, value: attention inputs; sequence length is read from
                dim -2 and head size from dim -1.
            attn_mask: optional mask. A boolean mask marks positions that take
                part in attention; any other dtype is added to the scores.
            dropout_p: dropout probability applied to the attention weights.
            is_causal: when true, position i may only attend to positions <= i.
            scale: score multiplier; defaults to ``head_size ** -0.5``.
            enable_gqa: when true, key/value heads (dim -3) are repeated to
                match the number of query heads.

        Returns:
            The attention output, ``softmax(scores) @ value``.
        """
        L = query.size(-2)
        S = key.size(-2)
        if scale is None:
            scale = query.size(-1) ** -0.5

        if enable_gqa:
            key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
            value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        attn_weights = _torch.matmul(query, key.transpose(-2, -1)) * scale
        if attn_mask is not None:
            if attn_mask.dtype == _torch.bool:
                # True means "attend": block everything the mask leaves out.
                attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))
            else:
                attn_weights = attn_weights + attn_mask
        if is_causal:
            # L x S (not L x L) so query and key lengths may differ.
            causal_mask = _torch.triu(
                _torch.ones(L, S, device=query.device, dtype=_torch.bool), diagonal=1
            )
            attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
        attn_weights = _torch.nn.functional.softmax(attn_weights, dim=-1)
        if dropout_p > 0.0:
            attn_weights = _torch.nn.functional.dropout(attn_weights, p=dropout_p)
        return _torch.matmul(attn_weights, value)

    def _patched_sdpa(*args, **kwargs):
        """Drop-in ``scaled_dot_product_attention`` wrapper used on MPS."""
        # Flush pending MPS fallback ops before SDPA
        _torch.mps.synchronize()

        # Periodic cache cleanup every 256 SDPA calls
        _sdpa_call_count[0] += 1
        if _sdpa_call_count[0] % 256 == 0:
            _torch.mps.empty_cache()

        try:
            with _torch.nn.attention.sdpa_kernel([_torch.nn.attention.SDPBackend.MATH]):
                return _orig_sdpa(*args, **kwargs)
        except Exception:
            # Deliberately broad best-effort: whatever made the math-backend
            # call fail (e.g. Metal command buffer corruption that synchronize
            # did not prevent), retry with the pure matmul + softmax path.
            return _manual_sdpa_fallback(*args, **kwargs)

    _torch.nn.functional.scaled_dot_product_attention = _patched_sdpa

# Apply the patch automatically at import time when MPS was detected. A
# failure is reported with its traceback but never propagates to the importer.
if _is_mps:
    try:
        apply_mps_patch()
    except Exception as patch_error:
        import traceback
        print(f"[MPS Patch] Failed to apply patch: {patch_error}")
        traceback.print_exc()