possible to run whole thing on Mac Silicon
- app.py +63 -0
- comfy/model_management.py +18 -0
app.py
CHANGED
@@ -163,6 +163,69 @@ latentupscaleby = NODE_CLASS_MAPPINGS["LatentUpscaleBy"]()
 
 from comfy import model_management
 
+# MPS (Apple Silicon) comprehensive workaround for black QR code bug
+# Issue: PyTorch 2.6+ FP16 handling on MPS causes black images in samplers
+# Additional issue: MPS tensor operations can produce NaN/inf values (PyTorch bug #84364)
+# Solution: Monkey-patch dtype functions to force fp32, enable MPS fallback
+# References: https://civitai.com/articles/11106, https://github.com/pytorch/pytorch/issues/84364
+
+import os
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+
+from comfy.cli_args import args
+
+if torch.backends.mps.is_available():
+    print(f"MPS device detected (PyTorch {torch.__version__})")
+
+    # Store the original dtype functions
+    _original_unet_dtype = model_management.unet_dtype
+    _original_vae_dtype = model_management.vae_dtype
+    _original_text_encoder_dtype = model_management.text_encoder_dtype
+
+    # Monkey-patch the dtype functions to force fp32 for MPS
+    def mps_safe_unet_dtype(device=None, *args_inner, **kwargs):
+        if device is not None and model_management.is_device_mps(device):
+            return torch.float32
+        if model_management.mps_mode():
+            return torch.float32
+        return _original_unet_dtype(device, *args_inner, **kwargs)
+
+    def mps_safe_vae_dtype(device=None, *args_inner, **kwargs):
+        if device is not None and model_management.is_device_mps(device):
+            return torch.float32
+        if model_management.mps_mode():
+            return torch.float32
+        return _original_vae_dtype(device, *args_inner, **kwargs)
+
+    def mps_safe_text_encoder_dtype(device=None, *args_inner, **kwargs):
+        if device is not None and model_management.is_device_mps(device):
+            return torch.float32
+        if model_management.mps_mode():
+            return torch.float32
+        return _original_text_encoder_dtype(device, *args_inner, **kwargs)
+
+    # Replace the functions in the model_management module
+    model_management.unet_dtype = mps_safe_unet_dtype
+    model_management.vae_dtype = mps_safe_vae_dtype
+    model_management.text_encoder_dtype = mps_safe_text_encoder_dtype
+
+    # Set args for additional stability
+    args.force_fp32 = True
+    args.fp32_vae = True
+    args.fp32_unet = True
+    args.force_upcast_attention = True
+
+    # Performance settings: Tune these for speed vs stability
+    # Try flipping these one at a time for better speed:
+    args.lowvram = False  # False is FASTER (try this first!)
+    args.use_split_cross_attention = False  # False is even FASTER (might cause black images)
+
+    lowvram_status = "enabled" if args.lowvram else "disabled (faster)"
+    split_attn_status = "enabled" if args.use_split_cross_attention else "disabled (faster)"
+    print("  ✓ Enabled global fp32 dtype enforcement (monkey-patched)")
+    print("  ✓ Enabled MPS fallback mode")
+    print(f"  ✓ lowvram: {lowvram_status}, split-cross-attention: {split_attn_status}")
+
 # Add all the models that load a safetensors file
 model_loaders = [checkpointloadersimple_4, checkpointloadersimple_artistic]
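To sanity-check that the patch took effect, here is a minimal verification sketch (hypothetical, not part of the commit): it assumes a ComfyUI checkout where comfy.model_management is importable and the app.py block above has already run.

# Verification sketch (hypothetical): confirm the monkey-patched
# dtype functions report fp32 when queried for an MPS device.
import torch
from comfy import model_management

if torch.backends.mps.is_available():
    dev = torch.device("mps")
    assert model_management.unet_dtype(dev) == torch.float32
    assert model_management.vae_dtype(dev) == torch.float32
    assert model_management.text_encoder_dtype(dev) == torch.float32
    print("fp32 enforcement active on MPS")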
comfy/model_management.py
CHANGED
@@ -711,6 +711,12 @@ def maximum_vram_for_weights(device=None):
     return (get_total_memory(device) * 0.88 - minimum_inference_memory())
 
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
+    # MPS workaround: Force fp32 for stability (PyTorch 2.6+ MPS bug)
+    if device is not None and is_device_mps(device):
+        return torch.float32
+    if mps_mode():
+        return torch.float32
+
     if model_params < 0:
         model_params = 1000000000000000000000
     if args.fp32_unet:

@@ -819,6 +825,12 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
     return offload_device
 
 def text_encoder_dtype(device=None):
+    # MPS workaround: Force fp32 for stability (PyTorch 2.6+ MPS bug)
+    if device is not None and is_device_mps(device):
+        return torch.float32
+    if mps_mode():
+        return torch.float32
+
     if args.fp8_e4m3fn_text_enc:
         return torch.float8_e4m3fn
     elif args.fp8_e5m2_text_enc:

@@ -854,6 +866,12 @@ def vae_offload_device():
     return torch.device("cpu")
 
 def vae_dtype(device=None, allowed_dtypes=[]):
+    # MPS workaround: Force fp32 for stability (PyTorch 2.6+ MPS bug)
+    if device is not None and is_device_mps(device):
+        return torch.float32
+    if mps_mode():
+        return torch.float32
+
     if args.fp16_vae:
         return torch.float16
     elif args.bf16_vae:
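For context on why these guards force fp32, a quick probe (illustrative only, not part of the commit) compares an fp16 op on MPS against fp32 on CPU; large differences or NaNs are the class of issue the workaround targets, though a clean result here does not rule out the sampler bug.

# Illustrative probe: fp16 on MPS vs fp32 on CPU.
import torch

if torch.backends.mps.is_available():
    x = torch.randn(1024, 1024)
    y16 = torch.nn.functional.softmax(x.to("mps").half(), dim=-1)
    y32 = torch.nn.functional.softmax(x, dim=-1)
    print("NaNs in fp16/MPS:", torch.isnan(y16).any().item())
    print("max |diff| vs fp32/CPU:", (y16.float().cpu() - y32).abs().max().item())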