Spaces:
Sleeping
Fix MPS device issues: gradient colors, device placement, and Gradio hot reload
This commit fixes three critical issues on Apple Silicon (MPS):
1. MPS Device Placement Bug (Second Pass)
- Enhanced model (FreeU + SAG) had weights on CPU instead of MPS
- Caused "RuntimeError: Tensor for argument weight is on cpu but expected on mps"
- Fix: Recreate enhanced model before second refinement pass on MPS devices
- Locations: app.py:2506-2531, 2539, 2574
2. Gradient Filter Color Quantization
- Gradient filter showed black/white instead of blue/yellow on MPS
- PIL incorrectly auto-detected grayscale mode during tensor conversion
- Fix: Ensure RGB array shape before Image.fromarray() calls
- No deprecated 'mode' parameter (future-proof for Pillow 13)
- Locations: app.py:967-981, 2172-2178, 2207-2213, 2639-2645, 2668-2674
3. Gradio Hot Reload Support
- Import errors: "No module named 'utils.extra_config'"
- Missing 'demo' attribute error during reload
- Fix: Graceful import fallback, move app to module level, init guard
- Locations: app.py:144-167, 2717-2858, 4185-4198
All fixes are device-agnostic and don't affect CUDA/CPU performance.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
|
@@ -141,7 +141,14 @@ def add_extra_model_paths() -> None:
|
|
| 141 |
print(
|
| 142 |
"Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
|
| 143 |
)
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
extra_model_paths = find_path("extra_model_paths.yaml")
|
| 147 |
|
|
@@ -151,8 +158,13 @@ def add_extra_model_paths() -> None:
|
|
| 151 |
print("Could not find the extra_model_paths config file.")
|
| 152 |
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
|
| 158 |
def import_custom_nodes() -> None:
|
|
@@ -964,12 +976,21 @@ def apply_color_quantization(
|
|
| 964 |
if len(palette) < 2:
|
| 965 |
palette = [(0, 0, 0), (255, 255, 255)] # Default to black & white
|
| 966 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 967 |
# Convert PIL Image to numpy array
|
| 968 |
img_array = np.array(image)
|
| 969 |
|
| 970 |
-
# Handle RGBA images by converting to RGB
|
| 971 |
-
if img_array.shape[2] == 4:
|
| 972 |
img_array = img_array[:, :, :3]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 973 |
|
| 974 |
h, w, c = img_array.shape
|
| 975 |
pixels = img_array.reshape(h * w, c).astype(np.float32)
|
|
@@ -2159,6 +2180,13 @@ def _pipeline_standard(
|
|
| 2159 |
image_tensor = get_value_at_index(upscaled, 0)
|
| 2160 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2161 |
image_np = image_np[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2162 |
pil_image = Image.fromarray(image_np)
|
| 2163 |
current_step += 1
|
| 2164 |
|
|
@@ -2187,6 +2215,13 @@ def _pipeline_standard(
|
|
| 2187 |
image_tensor = get_value_at_index(vaedecode_21, 0)
|
| 2188 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2189 |
image_np = image_np[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2190 |
pil_image = Image.fromarray(image_np)
|
| 2191 |
current_step = 3
|
| 2192 |
|
|
@@ -2503,6 +2538,33 @@ def _pipeline_artistic(
|
|
| 2503 |
# Final sampling pass
|
| 2504 |
log_progress("Second pass (refinement)...", gr_progress, 0.6)
|
| 2505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2506 |
# Use animation-enabled sampler if requested
|
| 2507 |
if animation_handler and enable_animation:
|
| 2508 |
# Run ksampler in thread to allow real-time image yielding
|
|
@@ -2510,7 +2572,7 @@ def _pipeline_artistic(
|
|
| 2510 |
|
| 2511 |
def run_ksampler():
|
| 2512 |
result_container[0] = ksampler_with_animation(
|
| 2513 |
-
model=
|
| 2514 |
seed=seed + 1,
|
| 2515 |
steps=30,
|
| 2516 |
cfg=7,
|
|
@@ -2545,7 +2607,7 @@ def _pipeline_artistic(
|
|
| 2545 |
sampler_name="dpmpp_3m_sde",
|
| 2546 |
scheduler="karras",
|
| 2547 |
denoise=0.8,
|
| 2548 |
-
model=
|
| 2549 |
positive=get_value_at_index(controlnet_apply_final, 0),
|
| 2550 |
negative=get_value_at_index(controlnet_apply_final, 1),
|
| 2551 |
latent_image=get_value_at_index(upscaled_latent, 0),
|
|
@@ -2597,6 +2659,13 @@ def _pipeline_artistic(
|
|
| 2597 |
image_tensor = get_value_at_index(upscaled, 0)
|
| 2598 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2599 |
image_np = image_np[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2600 |
final_image = Image.fromarray(image_np)
|
| 2601 |
|
| 2602 |
# Apply color quantization if enabled
|
|
@@ -2619,6 +2688,13 @@ def _pipeline_artistic(
|
|
| 2619 |
image_tensor = get_value_at_index(final_decoded, 0)
|
| 2620 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2621 |
image_np = image_np[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2622 |
final_image = Image.fromarray(image_np)
|
| 2623 |
|
| 2624 |
# Apply color quantization if enabled
|
|
@@ -2638,16 +2714,8 @@ def _pipeline_artistic(
|
|
| 2638 |
return # Explicit return to cleanly exit generator
|
| 2639 |
|
| 2640 |
|
| 2641 |
-
|
| 2642 |
-
|
| 2643 |
-
# Must be called after module init but before Gradio app launch
|
| 2644 |
-
if not torch.backends.mps.is_available():
|
| 2645 |
-
compile_models_with_aoti()
|
| 2646 |
-
else:
|
| 2647 |
-
print("ℹ️ AOT compilation skipped on MPS (MacBook) - using eager mode\n")
|
| 2648 |
-
|
| 2649 |
-
# Define artistic examples data
|
| 2650 |
-
ARTISTIC_EXAMPLES = [
|
| 2651 |
{
|
| 2652 |
"image": "examples/artistic/japanese_temple.jpg",
|
| 2653 |
"label": "Japanese Temple",
|
|
@@ -2783,11 +2851,11 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
|
|
| 2783 |
"seed": 718313,
|
| 2784 |
"sag_blur_sigma": 1.5,
|
| 2785 |
},
|
| 2786 |
-
|
| 2787 |
|
| 2788 |
-
|
| 2789 |
-
|
| 2790 |
-
|
| 2791 |
# Add a title and description
|
| 2792 |
gr.Markdown("# QR Code Art Generator")
|
| 2793 |
gr.Markdown("""
|
|
@@ -4113,7 +4181,18 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
|
|
| 4113 |
)
|
| 4114 |
|
| 4115 |
# ARTISTIC QR TAB
|
| 4116 |
-
|
| 4117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4118 |
# Note: Automatic file cleanup via delete_cache not available in Gradio 5.49.1
|
| 4119 |
# Files will be cleaned up when the server is restarted
|
|
|
|
| 141 |
print(
|
| 142 |
"Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
|
| 143 |
)
|
| 144 |
+
try:
|
| 145 |
+
from utils.extra_config import load_extra_path_config
|
| 146 |
+
except (ImportError, ModuleNotFoundError) as e:
|
| 147 |
+
print(
|
| 148 |
+
f"Could not import load_extra_path_config from utils.extra_config either: {e}"
|
| 149 |
+
)
|
| 150 |
+
print("Skipping extra model paths configuration (this is OK for Gradio hot reload).")
|
| 151 |
+
return
|
| 152 |
|
| 153 |
extra_model_paths = find_path("extra_model_paths.yaml")
|
| 154 |
|
|
|
|
| 158 |
print("Could not find the extra_model_paths config file.")
|
| 159 |
|
| 160 |
|
| 161 |
+
# Only run initialization on first load, not during Gradio hot reload
|
| 162 |
+
if not hasattr(__builtins__, '_comfy_initialized'):
|
| 163 |
+
__builtins__._comfy_initialized = True
|
| 164 |
+
add_comfyui_directory_to_sys_path()
|
| 165 |
+
add_extra_model_paths()
|
| 166 |
+
else:
|
| 167 |
+
print("Skipping ComfyUI initialization (Gradio hot reload detected)")
|
| 168 |
|
| 169 |
|
| 170 |
def import_custom_nodes() -> None:
|
|
|
|
| 976 |
if len(palette) < 2:
|
| 977 |
palette = [(0, 0, 0), (255, 255, 255)] # Default to black & white
|
| 978 |
|
| 979 |
+
# Ensure image is in RGB mode (fixes MPS grayscale conversion bug)
|
| 980 |
+
# On MPS devices, PIL might incorrectly interpret the image as grayscale
|
| 981 |
+
if image.mode != 'RGB':
|
| 982 |
+
image = image.convert('RGB')
|
| 983 |
+
|
| 984 |
# Convert PIL Image to numpy array
|
| 985 |
img_array = np.array(image)
|
| 986 |
|
| 987 |
+
# Handle RGBA images by converting to RGB (though we already converted above)
|
| 988 |
+
if len(img_array.shape) == 3 and img_array.shape[2] == 4:
|
| 989 |
img_array = img_array[:, :, :3]
|
| 990 |
+
# Handle grayscale images that slipped through
|
| 991 |
+
elif len(img_array.shape) == 2:
|
| 992 |
+
# Convert grayscale to RGB by repeating the channel
|
| 993 |
+
img_array = np.stack([img_array, img_array, img_array], axis=2)
|
| 994 |
|
| 995 |
h, w, c = img_array.shape
|
| 996 |
pixels = img_array.reshape(h * w, c).astype(np.float32)
|
|
|
|
| 2180 |
image_tensor = get_value_at_index(upscaled, 0)
|
| 2181 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2182 |
image_np = image_np[0]
|
| 2183 |
+
# Ensure RGB array shape to prevent MPS grayscale conversion bug
|
| 2184 |
+
if len(image_np.shape) == 2:
|
| 2185 |
+
# Convert grayscale (H, W) to RGB (H, W, 3)
|
| 2186 |
+
image_np = np.stack([image_np, image_np, image_np], axis=2)
|
| 2187 |
+
elif image_np.shape[2] == 1:
|
| 2188 |
+
# Convert (H, W, 1) to (H, W, 3)
|
| 2189 |
+
image_np = np.repeat(image_np, 3, axis=2)
|
| 2190 |
pil_image = Image.fromarray(image_np)
|
| 2191 |
current_step += 1
|
| 2192 |
|
|
|
|
| 2215 |
image_tensor = get_value_at_index(vaedecode_21, 0)
|
| 2216 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2217 |
image_np = image_np[0]
|
| 2218 |
+
# Ensure RGB array shape to prevent MPS grayscale conversion bug
|
| 2219 |
+
if len(image_np.shape) == 2:
|
| 2220 |
+
# Convert grayscale (H, W) to RGB (H, W, 3)
|
| 2221 |
+
image_np = np.stack([image_np, image_np, image_np], axis=2)
|
| 2222 |
+
elif image_np.shape[2] == 1:
|
| 2223 |
+
# Convert (H, W, 1) to (H, W, 3)
|
| 2224 |
+
image_np = np.repeat(image_np, 3, axis=2)
|
| 2225 |
pil_image = Image.fromarray(image_np)
|
| 2226 |
current_step = 3
|
| 2227 |
|
|
|
|
| 2538 |
# Final sampling pass
|
| 2539 |
log_progress("Second pass (refinement)...", gr_progress, 0.6)
|
| 2540 |
|
| 2541 |
+
# MPS device workaround: Recreate enhanced model for second pass to avoid device placement issues
|
| 2542 |
+
# After the first threaded sampling pass, some model weights can end up on CPU instead of MPS
|
| 2543 |
+
# This happens due to threading interaction with MPS backend + SAG making additional model calls
|
| 2544 |
+
if torch.backends.mps.is_available():
|
| 2545 |
+
# Recreate FreeU enhanced model from base model
|
| 2546 |
+
freeu_model_second = freeu.patch(
|
| 2547 |
+
model=base_model,
|
| 2548 |
+
b1=freeu_b1,
|
| 2549 |
+
b2=freeu_b2,
|
| 2550 |
+
s1=freeu_s1,
|
| 2551 |
+
s2=freeu_s2,
|
| 2552 |
+
)[0]
|
| 2553 |
+
|
| 2554 |
+
# Reapply SAG if enabled
|
| 2555 |
+
if enable_sag:
|
| 2556 |
+
smoothed_energy_second = NODE_CLASS_MAPPINGS["SelfAttentionGuidance"]()
|
| 2557 |
+
enhanced_model_second = smoothed_energy_second.patch(
|
| 2558 |
+
model=freeu_model_second,
|
| 2559 |
+
scale=sag_scale,
|
| 2560 |
+
blur_sigma=sag_blur_sigma,
|
| 2561 |
+
)[0]
|
| 2562 |
+
else:
|
| 2563 |
+
enhanced_model_second = freeu_model_second
|
| 2564 |
+
else:
|
| 2565 |
+
# On non-MPS devices, reuse the same enhanced model
|
| 2566 |
+
enhanced_model_second = enhanced_model
|
| 2567 |
+
|
| 2568 |
# Use animation-enabled sampler if requested
|
| 2569 |
if animation_handler and enable_animation:
|
| 2570 |
# Run ksampler in thread to allow real-time image yielding
|
|
|
|
| 2572 |
|
| 2573 |
def run_ksampler():
|
| 2574 |
result_container[0] = ksampler_with_animation(
|
| 2575 |
+
model=enhanced_model_second, # Using recreated FreeU + SAG enhanced model (MPS fix)
|
| 2576 |
seed=seed + 1,
|
| 2577 |
steps=30,
|
| 2578 |
cfg=7,
|
|
|
|
| 2607 |
sampler_name="dpmpp_3m_sde",
|
| 2608 |
scheduler="karras",
|
| 2609 |
denoise=0.8,
|
| 2610 |
+
model=enhanced_model_second, # Using recreated FreeU + SAG enhanced model (MPS fix)
|
| 2611 |
positive=get_value_at_index(controlnet_apply_final, 0),
|
| 2612 |
negative=get_value_at_index(controlnet_apply_final, 1),
|
| 2613 |
latent_image=get_value_at_index(upscaled_latent, 0),
|
|
|
|
| 2659 |
image_tensor = get_value_at_index(upscaled, 0)
|
| 2660 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2661 |
image_np = image_np[0]
|
| 2662 |
+
# Ensure RGB array shape to prevent MPS grayscale conversion bug
|
| 2663 |
+
if len(image_np.shape) == 2:
|
| 2664 |
+
# Convert grayscale (H, W) to RGB (H, W, 3)
|
| 2665 |
+
image_np = np.stack([image_np, image_np, image_np], axis=2)
|
| 2666 |
+
elif image_np.shape[2] == 1:
|
| 2667 |
+
# Convert (H, W, 1) to (H, W, 3)
|
| 2668 |
+
image_np = np.repeat(image_np, 3, axis=2)
|
| 2669 |
final_image = Image.fromarray(image_np)
|
| 2670 |
|
| 2671 |
# Apply color quantization if enabled
|
|
|
|
| 2688 |
image_tensor = get_value_at_index(final_decoded, 0)
|
| 2689 |
image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
|
| 2690 |
image_np = image_np[0]
|
| 2691 |
+
# Ensure RGB array shape to prevent MPS grayscale conversion bug
|
| 2692 |
+
if len(image_np.shape) == 2:
|
| 2693 |
+
# Convert grayscale (H, W) to RGB (H, W, 3)
|
| 2694 |
+
image_np = np.stack([image_np, image_np, image_np], axis=2)
|
| 2695 |
+
elif image_np.shape[2] == 1:
|
| 2696 |
+
# Convert (H, W, 1) to (H, W, 3)
|
| 2697 |
+
image_np = np.repeat(image_np, 3, axis=2)
|
| 2698 |
final_image = Image.fromarray(image_np)
|
| 2699 |
|
| 2700 |
# Apply color quantization if enabled
|
|
|
|
| 2714 |
return # Explicit return to cleanly exit generator
|
| 2715 |
|
| 2716 |
|
| 2717 |
+
# Define artistic examples data (at module level for hot reload)
|
| 2718 |
+
ARTISTIC_EXAMPLES = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2719 |
{
|
| 2720 |
"image": "examples/artistic/japanese_temple.jpg",
|
| 2721 |
"label": "Japanese Temple",
|
|
|
|
| 2851 |
"seed": 718313,
|
| 2852 |
"sag_blur_sigma": 1.5,
|
| 2853 |
},
|
| 2854 |
+
]
|
| 2855 |
|
| 2856 |
+
# Start your Gradio app with automatic cache cleanup (at module level for hot reload)
|
| 2857 |
+
# delete_cache=(3600, 3600) means: check every hour and delete files older than 1 hour
|
| 2858 |
+
with gr.Blocks(delete_cache=(3600, 3600)) as demo:
|
| 2859 |
# Add a title and description
|
| 2860 |
gr.Markdown("# QR Code Art Generator")
|
| 2861 |
gr.Markdown("""
|
|
|
|
| 4181 |
)
|
| 4182 |
|
| 4183 |
# ARTISTIC QR TAB
|
| 4184 |
+
|
| 4185 |
+
# Queue is required for gr.Progress() to work!
|
| 4186 |
+
demo.queue()
|
| 4187 |
+
|
| 4188 |
+
# Launch the app when run directly (not during hot reload)
|
| 4189 |
+
if __name__ == "__main__":
|
| 4190 |
+
# Call AOT compilation during startup (only on CUDA, not MPS)
|
| 4191 |
+
if not torch.backends.mps.is_available() and not os.environ.get("QR_TESTING_MODE"):
|
| 4192 |
+
compile_models_with_aoti()
|
| 4193 |
+
else:
|
| 4194 |
+
print("ℹ️ AOT compilation skipped (MPS or testing mode)\n")
|
| 4195 |
+
|
| 4196 |
+
demo.launch(share=False, mcp_server=True)
|
| 4197 |
# Note: Automatic file cleanup via delete_cache not available in Gradio 5.49.1
|
| 4198 |
# Files will be cleaned up when the server is restarted
|