Oysiyl Claude Sonnet 4.5 committed on
Commit
1d3c29b
·
1 Parent(s): d8f6a5a

Fix MPS device issues: gradient colors, device placement, and Gradio hot reload

Browse files

This commit fixes three critical issues on Apple Silicon (MPS):

1. MPS Device Placement Bug (Second Pass)
- Enhanced model (FreeU + SAG) had weights on CPU instead of MPS
- Caused "RuntimeError: Tensor for argument weight is on cpu but expected on mps"
- Fix: Recreate enhanced model before second refinement pass on MPS devices
- Locations: app.py:2506-2531, 2539, 2574

2. Gradient Filter Color Quantization
- Gradient filter showed black/white instead of blue/yellow on MPS
- PIL incorrectly auto-detected grayscale mode during tensor conversion
- Fix: Ensure RGB array shape before Image.fromarray() calls
- No deprecated 'mode' parameter (future-proof for Pillow 13)
- Locations: app.py:967-981, 2172-2178, 2207-2213, 2639-2645, 2668-2674

3. Gradio Hot Reload Support
- Import errors: "No module named 'utils.extra_config'"
- Missing 'demo' attribute error during reload
- Fix: Graceful import fallback, move app to module level, init guard
- Locations: app.py:144-167, 2717-2858, 4185-4198

All fixes are device-agnostic and don't affect CUDA/CPU performance.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +102 -23
app.py CHANGED
@@ -141,7 +141,14 @@ def add_extra_model_paths() -> None:
141
  print(
142
  "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
143
  )
144
- from utils.extra_config import load_extra_path_config
 
 
 
 
 
 
 
145
 
146
  extra_model_paths = find_path("extra_model_paths.yaml")
147
 
@@ -151,8 +158,13 @@ def add_extra_model_paths() -> None:
151
  print("Could not find the extra_model_paths config file.")
152
 
153
 
154
- add_comfyui_directory_to_sys_path()
155
- add_extra_model_paths()
 
 
 
 
 
156
 
157
 
158
  def import_custom_nodes() -> None:
@@ -964,12 +976,21 @@ def apply_color_quantization(
964
  if len(palette) < 2:
965
  palette = [(0, 0, 0), (255, 255, 255)] # Default to black & white
966
 
 
 
 
 
 
967
  # Convert PIL Image to numpy array
968
  img_array = np.array(image)
969
 
970
- # Handle RGBA images by converting to RGB
971
- if img_array.shape[2] == 4:
972
  img_array = img_array[:, :, :3]
 
 
 
 
973
 
974
  h, w, c = img_array.shape
975
  pixels = img_array.reshape(h * w, c).astype(np.float32)
@@ -2159,6 +2180,13 @@ def _pipeline_standard(
2159
  image_tensor = get_value_at_index(upscaled, 0)
2160
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2161
  image_np = image_np[0]
 
 
 
 
 
 
 
2162
  pil_image = Image.fromarray(image_np)
2163
  current_step += 1
2164
 
@@ -2187,6 +2215,13 @@ def _pipeline_standard(
2187
  image_tensor = get_value_at_index(vaedecode_21, 0)
2188
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2189
  image_np = image_np[0]
 
 
 
 
 
 
 
2190
  pil_image = Image.fromarray(image_np)
2191
  current_step = 3
2192
 
@@ -2503,6 +2538,33 @@ def _pipeline_artistic(
2503
  # Final sampling pass
2504
  log_progress("Second pass (refinement)...", gr_progress, 0.6)
2505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2506
  # Use animation-enabled sampler if requested
2507
  if animation_handler and enable_animation:
2508
  # Run ksampler in thread to allow real-time image yielding
@@ -2510,7 +2572,7 @@ def _pipeline_artistic(
2510
 
2511
  def run_ksampler():
2512
  result_container[0] = ksampler_with_animation(
2513
- model=enhanced_model, # Using FreeU + SAG enhanced model
2514
  seed=seed + 1,
2515
  steps=30,
2516
  cfg=7,
@@ -2545,7 +2607,7 @@ def _pipeline_artistic(
2545
  sampler_name="dpmpp_3m_sde",
2546
  scheduler="karras",
2547
  denoise=0.8,
2548
- model=enhanced_model, # Using FreeU + SAG enhanced model
2549
  positive=get_value_at_index(controlnet_apply_final, 0),
2550
  negative=get_value_at_index(controlnet_apply_final, 1),
2551
  latent_image=get_value_at_index(upscaled_latent, 0),
@@ -2597,6 +2659,13 @@ def _pipeline_artistic(
2597
  image_tensor = get_value_at_index(upscaled, 0)
2598
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2599
  image_np = image_np[0]
 
 
 
 
 
 
 
2600
  final_image = Image.fromarray(image_np)
2601
 
2602
  # Apply color quantization if enabled
@@ -2619,6 +2688,13 @@ def _pipeline_artistic(
2619
  image_tensor = get_value_at_index(final_decoded, 0)
2620
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2621
  image_np = image_np[0]
 
 
 
 
 
 
 
2622
  final_image = Image.fromarray(image_np)
2623
 
2624
  # Apply color quantization if enabled
@@ -2638,16 +2714,8 @@ def _pipeline_artistic(
2638
  return # Explicit return to cleanly exit generator
2639
 
2640
 
2641
- if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2642
- # Call AOT compilation during startup (only on CUDA, not MPS)
2643
- # Must be called after module init but before Gradio app launch
2644
- if not torch.backends.mps.is_available():
2645
- compile_models_with_aoti()
2646
- else:
2647
- print("ℹ️ AOT compilation skipped on MPS (MacBook) - using eager mode\n")
2648
-
2649
- # Define artistic examples data
2650
- ARTISTIC_EXAMPLES = [
2651
  {
2652
  "image": "examples/artistic/japanese_temple.jpg",
2653
  "label": "Japanese Temple",
@@ -2783,11 +2851,11 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2783
  "seed": 718313,
2784
  "sag_blur_sigma": 1.5,
2785
  },
2786
- ]
2787
 
2788
- # Start your Gradio app with automatic cache cleanup
2789
- # delete_cache=(3600, 3600) means: check every hour and delete files older than 1 hour
2790
- with gr.Blocks(delete_cache=(3600, 3600)) as app:
2791
  # Add a title and description
2792
  gr.Markdown("# QR Code Art Generator")
2793
  gr.Markdown("""
@@ -4113,7 +4181,18 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
4113
  )
4114
 
4115
  # ARTISTIC QR TAB
4116
- app.queue() # Required for gr.Progress() to work!
4117
- app.launch(share=False, mcp_server=True)
 
 
 
 
 
 
 
 
 
 
 
4118
  # Note: Automatic file cleanup via delete_cache not available in Gradio 5.49.1
4119
  # Files will be cleaned up when the server is restarted
 
141
  print(
142
  "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
143
  )
144
+ try:
145
+ from utils.extra_config import load_extra_path_config
146
+ except (ImportError, ModuleNotFoundError) as e:
147
+ print(
148
+ f"Could not import load_extra_path_config from utils.extra_config either: {e}"
149
+ )
150
+ print("Skipping extra model paths configuration (this is OK for Gradio hot reload).")
151
+ return
152
 
153
  extra_model_paths = find_path("extra_model_paths.yaml")
154
 
 
158
  print("Could not find the extra_model_paths config file.")
159
 
160
 
161
+ # Only run initialization on first load, not during Gradio hot reload
162
+ if not hasattr(__builtins__, '_comfy_initialized'):
163
+ __builtins__._comfy_initialized = True
164
+ add_comfyui_directory_to_sys_path()
165
+ add_extra_model_paths()
166
+ else:
167
+ print("Skipping ComfyUI initialization (Gradio hot reload detected)")
168
 
169
 
170
  def import_custom_nodes() -> None:
 
976
  if len(palette) < 2:
977
  palette = [(0, 0, 0), (255, 255, 255)] # Default to black & white
978
 
979
+ # Ensure image is in RGB mode (fixes MPS grayscale conversion bug)
980
+ # On MPS devices, PIL might incorrectly interpret the image as grayscale
981
+ if image.mode != 'RGB':
982
+ image = image.convert('RGB')
983
+
984
  # Convert PIL Image to numpy array
985
  img_array = np.array(image)
986
 
987
+ # Handle RGBA images by converting to RGB (though we already converted above)
988
+ if len(img_array.shape) == 3 and img_array.shape[2] == 4:
989
  img_array = img_array[:, :, :3]
990
+ # Handle grayscale images that slipped through
991
+ elif len(img_array.shape) == 2:
992
+ # Convert grayscale to RGB by repeating the channel
993
+ img_array = np.stack([img_array, img_array, img_array], axis=2)
994
 
995
  h, w, c = img_array.shape
996
  pixels = img_array.reshape(h * w, c).astype(np.float32)
 
2180
  image_tensor = get_value_at_index(upscaled, 0)
2181
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2182
  image_np = image_np[0]
2183
+ # Ensure RGB array shape to prevent MPS grayscale conversion bug
2184
+ if len(image_np.shape) == 2:
2185
+ # Convert grayscale (H, W) to RGB (H, W, 3)
2186
+ image_np = np.stack([image_np, image_np, image_np], axis=2)
2187
+ elif image_np.shape[2] == 1:
2188
+ # Convert (H, W, 1) to (H, W, 3)
2189
+ image_np = np.repeat(image_np, 3, axis=2)
2190
  pil_image = Image.fromarray(image_np)
2191
  current_step += 1
2192
 
 
2215
  image_tensor = get_value_at_index(vaedecode_21, 0)
2216
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2217
  image_np = image_np[0]
2218
+ # Ensure RGB array shape to prevent MPS grayscale conversion bug
2219
+ if len(image_np.shape) == 2:
2220
+ # Convert grayscale (H, W) to RGB (H, W, 3)
2221
+ image_np = np.stack([image_np, image_np, image_np], axis=2)
2222
+ elif image_np.shape[2] == 1:
2223
+ # Convert (H, W, 1) to (H, W, 3)
2224
+ image_np = np.repeat(image_np, 3, axis=2)
2225
  pil_image = Image.fromarray(image_np)
2226
  current_step = 3
2227
 
 
2538
  # Final sampling pass
2539
  log_progress("Second pass (refinement)...", gr_progress, 0.6)
2540
 
2541
+ # MPS device workaround: Recreate enhanced model for second pass to avoid device placement issues
2542
+ # After the first threaded sampling pass, some model weights can end up on CPU instead of MPS
2543
+ # This happens due to threading interaction with MPS backend + SAG making additional model calls
2544
+ if torch.backends.mps.is_available():
2545
+ # Recreate FreeU enhanced model from base model
2546
+ freeu_model_second = freeu.patch(
2547
+ model=base_model,
2548
+ b1=freeu_b1,
2549
+ b2=freeu_b2,
2550
+ s1=freeu_s1,
2551
+ s2=freeu_s2,
2552
+ )[0]
2553
+
2554
+ # Reapply SAG if enabled
2555
+ if enable_sag:
2556
+ smoothed_energy_second = NODE_CLASS_MAPPINGS["SelfAttentionGuidance"]()
2557
+ enhanced_model_second = smoothed_energy_second.patch(
2558
+ model=freeu_model_second,
2559
+ scale=sag_scale,
2560
+ blur_sigma=sag_blur_sigma,
2561
+ )[0]
2562
+ else:
2563
+ enhanced_model_second = freeu_model_second
2564
+ else:
2565
+ # On non-MPS devices, reuse the same enhanced model
2566
+ enhanced_model_second = enhanced_model
2567
+
2568
  # Use animation-enabled sampler if requested
2569
  if animation_handler and enable_animation:
2570
  # Run ksampler in thread to allow real-time image yielding
 
2572
 
2573
  def run_ksampler():
2574
  result_container[0] = ksampler_with_animation(
2575
+ model=enhanced_model_second, # Using recreated FreeU + SAG enhanced model (MPS fix)
2576
  seed=seed + 1,
2577
  steps=30,
2578
  cfg=7,
 
2607
  sampler_name="dpmpp_3m_sde",
2608
  scheduler="karras",
2609
  denoise=0.8,
2610
+ model=enhanced_model_second, # Using recreated FreeU + SAG enhanced model (MPS fix)
2611
  positive=get_value_at_index(controlnet_apply_final, 0),
2612
  negative=get_value_at_index(controlnet_apply_final, 1),
2613
  latent_image=get_value_at_index(upscaled_latent, 0),
 
2659
  image_tensor = get_value_at_index(upscaled, 0)
2660
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2661
  image_np = image_np[0]
2662
+ # Ensure RGB array shape to prevent MPS grayscale conversion bug
2663
+ if len(image_np.shape) == 2:
2664
+ # Convert grayscale (H, W) to RGB (H, W, 3)
2665
+ image_np = np.stack([image_np, image_np, image_np], axis=2)
2666
+ elif image_np.shape[2] == 1:
2667
+ # Convert (H, W, 1) to (H, W, 3)
2668
+ image_np = np.repeat(image_np, 3, axis=2)
2669
  final_image = Image.fromarray(image_np)
2670
 
2671
  # Apply color quantization if enabled
 
2688
  image_tensor = get_value_at_index(final_decoded, 0)
2689
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2690
  image_np = image_np[0]
2691
+ # Ensure RGB array shape to prevent MPS grayscale conversion bug
2692
+ if len(image_np.shape) == 2:
2693
+ # Convert grayscale (H, W) to RGB (H, W, 3)
2694
+ image_np = np.stack([image_np, image_np, image_np], axis=2)
2695
+ elif image_np.shape[2] == 1:
2696
+ # Convert (H, W, 1) to (H, W, 3)
2697
+ image_np = np.repeat(image_np, 3, axis=2)
2698
  final_image = Image.fromarray(image_np)
2699
 
2700
  # Apply color quantization if enabled
 
2714
  return # Explicit return to cleanly exit generator
2715
 
2716
 
2717
+ # Define artistic examples data (at module level for hot reload)
2718
+ ARTISTIC_EXAMPLES = [
 
 
 
 
 
 
 
 
2719
  {
2720
  "image": "examples/artistic/japanese_temple.jpg",
2721
  "label": "Japanese Temple",
 
2851
  "seed": 718313,
2852
  "sag_blur_sigma": 1.5,
2853
  },
2854
+ ]
2855
 
2856
+ # Start your Gradio app with automatic cache cleanup (at module level for hot reload)
2857
+ # delete_cache=(3600, 3600) means: check every hour and delete files older than 1 hour
2858
+ with gr.Blocks(delete_cache=(3600, 3600)) as demo:
2859
  # Add a title and description
2860
  gr.Markdown("# QR Code Art Generator")
2861
  gr.Markdown("""
 
4181
  )
4182
 
4183
  # ARTISTIC QR TAB
4184
+
4185
+ # Queue is required for gr.Progress() to work!
4186
+ demo.queue()
4187
+
4188
+ # Launch the app when run directly (not during hot reload)
4189
+ if __name__ == "__main__":
4190
+ # Call AOT compilation during startup (only on CUDA, not MPS)
4191
+ if not torch.backends.mps.is_available() and not os.environ.get("QR_TESTING_MODE"):
4192
+ compile_models_with_aoti()
4193
+ else:
4194
+ print("ℹ️ AOT compilation skipped (MPS or testing mode)\n")
4195
+
4196
+ demo.launch(share=False, mcp_server=True)
4197
  # Note: Automatic file cleanup via delete_cache not available in Gradio 5.49.1
4198
  # Files will be cleaned up when the server is restarted