Oysiyl committed on
Commit
67c7b1f
·
1 Parent(s): edb13a9

possible to run the whole thing on Apple Silicon

Files changed (2)
  1. app.py +63 -0
  2. comfy/model_management.py +18 -0
app.py CHANGED
@@ -163,6 +163,69 @@ latentupscaleby = NODE_CLASS_MAPPINGS["LatentUpscaleBy"]()
 
 from comfy import model_management
 
+# MPS (Apple Silicon) comprehensive workaround for the black QR code bug
+# Issue: PyTorch 2.6+ FP16 handling on MPS causes black images in samplers
+# Additional issue: MPS tensor operations can produce NaN/inf values (PyTorch bug #84364)
+# Solution: monkey-patch the dtype functions to force fp32, enable the MPS fallback
+# References: https://civitai.com/articles/11106, https://github.com/pytorch/pytorch/issues/84364
+
+import os
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+
+from comfy.cli_args import args
+
+if torch.backends.mps.is_available():
+    print(f"MPS device detected (PyTorch {torch.__version__})")
+
+    # Store the original dtype functions
+    _original_unet_dtype = model_management.unet_dtype
+    _original_vae_dtype = model_management.vae_dtype
+    _original_text_encoder_dtype = model_management.text_encoder_dtype
+
+    # Monkey-patch the dtype functions to force fp32 on MPS
+    def mps_safe_unet_dtype(device=None, *args_inner, **kwargs):
+        if device is not None and model_management.is_device_mps(device):
+            return torch.float32
+        if model_management.mps_mode():
+            return torch.float32
+        return _original_unet_dtype(device, *args_inner, **kwargs)
+
+    def mps_safe_vae_dtype(device=None, *args_inner, **kwargs):
+        if device is not None and model_management.is_device_mps(device):
+            return torch.float32
+        if model_management.mps_mode():
+            return torch.float32
+        return _original_vae_dtype(device, *args_inner, **kwargs)
+
+    def mps_safe_text_encoder_dtype(device=None, *args_inner, **kwargs):
+        if device is not None and model_management.is_device_mps(device):
+            return torch.float32
+        if model_management.mps_mode():
+            return torch.float32
+        return _original_text_encoder_dtype(device, *args_inner, **kwargs)
+
+    # Replace the functions in the model_management module
+    model_management.unet_dtype = mps_safe_unet_dtype
+    model_management.vae_dtype = mps_safe_vae_dtype
+    model_management.text_encoder_dtype = mps_safe_text_encoder_dtype
+
+    # Set args for additional stability
+    args.force_fp32 = True
+    args.fp32_vae = True
+    args.fp32_unet = True
+    args.force_upcast_attention = True
+
+    # Performance settings: tune these for speed vs. stability.
+    # Flip either flag to True if you hit black images or memory pressure:
+    args.lowvram = False  # False is faster
+    args.use_split_cross_attention = False  # False is even faster, but may cause black images
+
+    lowvram_status = "enabled" if args.lowvram else "disabled (faster)"
+    split_attn_status = "enabled" if args.use_split_cross_attention else "disabled (faster)"
+    print("  ✓ Enabled global fp32 dtype enforcement (monkey-patched)")
+    print("  ✓ Enabled MPS fallback mode")
+    print(f"  ✓ lowvram: {lowvram_status}, split-cross-attention: {split_attn_status}")
+
 # Add all the models that load a safetensors file
 model_loaders = [checkpointloadersimple_4, checkpointloadersimple_artistic]
 
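A quick way to sanity-check this patch after app.py has run: each patched dtype helper should now report fp32 for an MPS device. The following is a minimal sketch, not part of the commit, assuming the monkey-patches above are already installed; note also that PyTorch generally reads PYTORCH_ENABLE_MPS_FALLBACK when it initializes, so setting it before the first `import torch` is the safer ordering.

```python
# Sanity check: run on Apple Silicon after the patched app.py has executed.
import torch
from comfy import model_management

if torch.backends.mps.is_available():
    mps = torch.device("mps")
    # Every patched helper should force fp32 for MPS.
    for name, fn in [("unet", model_management.unet_dtype),
                     ("vae", model_management.vae_dtype),
                     ("text_encoder", model_management.text_encoder_dtype)]:
        dtype = fn(mps)
        assert dtype == torch.float32, f"{name}_dtype returned {dtype}"
        print(f"{name}_dtype(mps) -> {dtype}")
else:
    print("No MPS device; the patch is a no-op here.")
```

Patching at the module level works because callers that look the functions up via `model_management.unet_dtype` at call time see the replacements; code that imported the functions directly (`from comfy.model_management import unet_dtype`) would still see the originals, which is presumably why the same guards are duplicated inside comfy/model_management.py below.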
comfy/model_management.py CHANGED
@@ -711,6 +711,12 @@ def maximum_vram_for_weights(device=None):
     return (get_total_memory(device) * 0.88 - minimum_inference_memory())
 
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
+    # MPS workaround: force fp32 for stability (PyTorch 2.6+ MPS bug)
+    if device is not None and is_device_mps(device):
+        return torch.float32
+    if mps_mode():
+        return torch.float32
+
     if model_params < 0:
         model_params = 1000000000000000000000
     if args.fp32_unet:
@@ -819,6 +825,12 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
     return offload_device
 
 def text_encoder_dtype(device=None):
+    # MPS workaround: force fp32 for stability (PyTorch 2.6+ MPS bug)
+    if device is not None and is_device_mps(device):
+        return torch.float32
+    if mps_mode():
+        return torch.float32
+
     if args.fp8_e4m3fn_text_enc:
         return torch.float8_e4m3fn
     elif args.fp8_e5m2_text_enc:
@@ -854,6 +866,12 @@ def vae_offload_device():
     return torch.device("cpu")
 
 def vae_dtype(device=None, allowed_dtypes=[]):
+    # MPS workaround: force fp32 for stability (PyTorch 2.6+ MPS bug)
+    if device is not None and is_device_mps(device):
+        return torch.float32
+    if mps_mode():
+        return torch.float32
+
     if args.fp16_vae:
         return torch.float16
     elif args.bf16_vae:
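These early returns duplicate the app-level patch inside the module itself, so even direct callers get fp32 on MPS. To check whether a given PyTorch build still shows the fp16 misbehaviour the guards work around, a probe along these lines can help; it is an illustrative sketch, not part of the commit.

```python
import torch

def probe_mps_fp16(shape=(64, 64)):
    """Compare the same matmul in fp16 and fp32 on MPS.

    NaN/inf or a large divergence in the fp16 result is the symptom
    the fp32 guards in unet_dtype/vae_dtype/text_encoder_dtype avoid.
    """
    if not torch.backends.mps.is_available():
        return "no MPS device"
    dev = torch.device("mps")
    x32 = torch.randn(shape, device=dev)
    y32 = x32 @ x32
    y16 = (x32.half() @ x32.half()).float()
    if not torch.isfinite(y16).all():
        return "fp16 produced NaN/inf -> keep the fp32 guards"
    return f"max |fp32 - fp16| = {(y32 - y16).abs().max().item():.4f}"

print(probe_mps_fp16())
```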