Oysiyl Claude Sonnet 4.5 commited on
Commit
a3a5777
·
1 Parent(s): 151a131

Add gradient color filter to quantization feature

Browse files

Integrate gradient filter as optional enhancement to color quantization:
- Add apply_color_quantization() with gradient mode support
- Preserve QR colors (1-2) while applying gradients to background (3-4)
- Add UI controls for gradient strength and variation steps
- Update both Standard and Artistic pipelines
- Include gradient parameters in settings export/import
- Fix event handler placement for proper component references

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +770 -78
app.py CHANGED
@@ -15,13 +15,13 @@ import gradio as gr
15
  import numpy as np
16
  import spaces
17
  import torch
 
 
18
 
19
  # ComfyUI imports (after HF hub downloads)
20
  from comfy import model_management
21
  from comfy.cli_args import args
22
  from comfy_extras.nodes_freelunch import FreeU_V2
23
- from huggingface_hub import hf_hub_download
24
- from PIL import Image
25
 
26
  # Suppress torchsde floating-point precision warnings (cosmetic only, no functional impact)
27
  warnings.filterwarnings("ignore", message="Should have tb<=t1 but got")
@@ -358,6 +358,7 @@ def _apply_torch_compile_optimizations():
358
 
359
  # Increase cache limit to handle batch size variations (CFG uses batch 1 and 2)
360
  import torch._dynamo.config
 
361
  torch._dynamo.config.cache_size_limit = 64 # Allow more cached graphs
362
 
363
  # Compile standard pipeline model (DreamShaper 3.32)
@@ -408,8 +409,8 @@ def compile_models_with_aoti():
408
  TEST_SEED = 12345
409
 
410
  try:
411
- from spaces import aoti_capture, aoti_compile, aoti_apply
412
  import torch.export
 
413
 
414
  print(" Attempting AOT compilation...\n")
415
 
@@ -421,20 +422,22 @@ def compile_models_with_aoti():
421
 
422
  # Capture example run
423
  with aoti_capture(standard_model.model.diffusion_model) as call_standard:
424
- list(_pipeline_standard(
425
- prompt=TEST_PROMPT,
426
- qr_text=TEST_TEXT,
427
- input_type="URL",
428
- image_size=512,
429
- border_size=4,
430
- error_correction="Medium (15%)",
431
- module_size=12,
432
- module_drawer="Square",
433
- seed=TEST_SEED,
434
- enable_upscale=False,
435
- controlnet_strength_first=1.5,
436
- controlnet_strength_final=0.9,
437
- ))
 
 
438
 
439
  # Export and compile
440
  exported_standard = torch.export.export(
@@ -454,27 +457,29 @@ def compile_models_with_aoti():
454
 
455
  # Capture example run
456
  with aoti_capture(artistic_model.model.diffusion_model) as call_artistic:
457
- list(_pipeline_artistic(
458
- prompt=TEST_PROMPT,
459
- qr_text=TEST_TEXT,
460
- input_type="URL",
461
- image_size=640,
462
- border_size=4,
463
- error_correction="Medium (15%)",
464
- module_size=12,
465
- module_drawer="Square",
466
- seed=TEST_SEED,
467
- enable_upscale=False,
468
- controlnet_strength_first=1.5,
469
- controlnet_strength_final=0.9,
470
- freeu_b1=1.3,
471
- freeu_b2=1.4,
472
- freeu_s1=0.9,
473
- freeu_s2=0.2,
474
- enable_sag=True,
475
- sag_scale=0.75,
476
- sag_blur_sigma=2.0,
477
- ))
 
 
478
 
479
  # Export and compile
480
  exported_artistic = torch.export.export(
@@ -498,53 +503,61 @@ def compile_models_with_aoti():
498
  _apply_torch_compile_optimizations()
499
 
500
  # Run warmup inference to trigger torch.compile compilation
501
- print("🔥 Running warmup inference to compile models (this takes 2-3 minutes)...")
 
 
502
 
503
  try:
504
  # Warmup standard pipeline @ 512px
505
  print(" [1/2] Warming up standard pipeline...")
506
- list(_pipeline_standard(
507
- prompt=TEST_PROMPT,
508
- qr_text=TEST_TEXT,
509
- input_type="URL",
510
- image_size=512,
511
- border_size=4,
512
- error_correction="Medium (15%)",
513
- module_size=12,
514
- module_drawer="Square",
515
- seed=TEST_SEED,
516
- enable_upscale=False,
517
- controlnet_strength_first=1.5,
518
- controlnet_strength_final=0.9,
519
- ))
 
 
520
  print(" ✓ Standard pipeline compiled")
521
 
522
  # Warmup artistic pipeline @ 640px
523
  print(" [2/2] Warming up artistic pipeline...")
524
- list(_pipeline_artistic(
525
- prompt=TEST_PROMPT,
526
- qr_text=TEST_TEXT,
527
- input_type="URL",
528
- image_size=640,
529
- border_size=4,
530
- error_correction="Medium (15%)",
531
- module_size=12,
532
- module_drawer="Square",
533
- seed=TEST_SEED,
534
- enable_upscale=False,
535
- controlnet_strength_first=1.5,
536
- controlnet_strength_final=0.9,
537
- freeu_b1=1.3,
538
- freeu_b2=1.4,
539
- freeu_s1=0.9,
540
- freeu_s2=0.2,
541
- enable_sag=True,
542
- sag_scale=0.75,
543
- sag_blur_sigma=2.0,
544
- ))
 
 
545
  print(" ✓ Artistic pipeline compiled")
546
 
547
- print("\n✅ torch.compile warmup complete! Models ready for fast inference.\n")
 
 
548
  return True
549
 
550
  except Exception as warmup_error:
@@ -578,6 +591,15 @@ def generate_qr_code_unified(
578
  controlnet_strength_final: float = 0.7,
579
  controlnet_strength_standard_first: float = 0.45,
580
  controlnet_strength_standard_final: float = 1.0,
 
 
 
 
 
 
 
 
 
581
  progress=gr.Progress(),
582
  ):
583
  # Only manipulate the text if it's a URL input type
@@ -606,6 +628,15 @@ def generate_qr_code_unified(
606
  enable_upscale,
607
  controlnet_strength_standard_first,
608
  controlnet_strength_standard_final,
 
 
 
 
 
 
 
 
 
609
  progress,
610
  )
611
  else: # artistic
@@ -629,10 +660,208 @@ def generate_qr_code_unified(
629
  sag_blur_sigma,
630
  controlnet_strength_first,
631
  controlnet_strength_final,
 
 
 
 
 
 
 
 
 
632
  progress,
633
  )
634
 
635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
  def generate_standard_qr(
637
  prompt: str,
638
  text_input: str,
@@ -648,6 +877,15 @@ def generate_standard_qr(
648
  enable_freeu: bool = False,
649
  controlnet_strength_standard_first: float = 0.45,
650
  controlnet_strength_standard_final: float = 1.0,
 
 
 
 
 
 
 
 
 
651
  progress=gr.Progress(),
652
  ):
653
  """Wrapper function for standard QR generation"""
@@ -671,6 +909,15 @@ def generate_standard_qr(
671
  "enable_freeu": enable_freeu,
672
  "controlnet_strength_standard_first": controlnet_strength_standard_first,
673
  "controlnet_strength_standard_final": controlnet_strength_standard_final,
 
 
 
 
 
 
 
 
 
674
  }
675
  settings_json = generate_settings_json(settings_dict)
676
 
@@ -690,6 +937,15 @@ def generate_standard_qr(
690
  enable_upscale=enable_upscale,
691
  controlnet_strength_standard_first=controlnet_strength_standard_first,
692
  controlnet_strength_standard_final=controlnet_strength_standard_final,
 
 
 
 
 
 
 
 
 
693
  progress=progress,
694
  )
695
 
@@ -734,6 +990,15 @@ def generate_artistic_qr(
734
  sag_blur_sigma: float = 0.5,
735
  controlnet_strength_first: float = 0.45,
736
  controlnet_strength_final: float = 0.70,
 
 
 
 
 
 
 
 
 
737
  progress=gr.Progress(),
738
  ):
739
  """Wrapper function for artistic QR generation with FreeU and SAG parameters"""
@@ -764,6 +1029,15 @@ def generate_artistic_qr(
764
  "sag_blur_sigma": sag_blur_sigma,
765
  "controlnet_strength_first": controlnet_strength_first,
766
  "controlnet_strength_final": controlnet_strength_final,
 
 
 
 
 
 
 
 
 
767
  }
768
  settings_json = generate_settings_json(settings_dict)
769
 
@@ -790,6 +1064,15 @@ def generate_artistic_qr(
790
  sag_blur_sigma=sag_blur_sigma,
791
  controlnet_strength_first=controlnet_strength_first,
792
  controlnet_strength_final=controlnet_strength_final,
 
 
 
 
 
 
 
 
 
793
  progress=progress,
794
  )
795
 
@@ -867,6 +1150,15 @@ def load_settings_from_json_standard(json_string: str):
867
  gr.update(),
868
  gr.update(),
869
  gr.update(),
 
 
 
 
 
 
 
 
 
870
  gr.update(value=error_msg, visible=True),
871
  )
872
 
@@ -889,6 +1181,15 @@ def load_settings_from_json_standard(json_string: str):
889
  controlnet_strength_standard_final = params.get(
890
  "controlnet_strength_standard_final", 1.0
891
  )
 
 
 
 
 
 
 
 
 
892
 
893
  success_msg = "✅ Settings loaded successfully!"
894
  return (
@@ -906,6 +1207,15 @@ def load_settings_from_json_standard(json_string: str):
906
  enable_freeu,
907
  controlnet_strength_standard_first,
908
  controlnet_strength_standard_final,
 
 
 
 
 
 
 
 
 
909
  gr.update(value=success_msg, visible=True),
910
  )
911
 
@@ -926,6 +1236,15 @@ def load_settings_from_json_standard(json_string: str):
926
  gr.update(),
927
  gr.update(),
928
  gr.update(),
 
 
 
 
 
 
 
 
 
929
  gr.update(value=error_msg, visible=True),
930
  )
931
  except Exception as e:
@@ -945,6 +1264,15 @@ def load_settings_from_json_standard(json_string: str):
945
  gr.update(),
946
  gr.update(),
947
  gr.update(),
 
 
 
 
 
 
 
 
 
948
  gr.update(value=error_msg, visible=True),
949
  )
950
 
@@ -983,6 +1311,15 @@ def load_settings_from_json_artistic(json_string: str):
983
  gr.update(),
984
  gr.update(),
985
  gr.update(),
 
 
 
 
 
 
 
 
 
986
  gr.update(value=error_msg, visible=True),
987
  )
988
 
@@ -1008,6 +1345,15 @@ def load_settings_from_json_artistic(json_string: str):
1008
  sag_blur_sigma = params.get("sag_blur_sigma", 0.5)
1009
  controlnet_strength_first = params.get("controlnet_strength_first", 0.45)
1010
  controlnet_strength_final = params.get("controlnet_strength_final", 0.7)
 
 
 
 
 
 
 
 
 
1011
 
1012
  success_msg = "✅ Settings loaded successfully!"
1013
  return (
@@ -1032,6 +1378,15 @@ def load_settings_from_json_artistic(json_string: str):
1032
  sag_blur_sigma,
1033
  controlnet_strength_first,
1034
  controlnet_strength_final,
 
 
 
 
 
 
 
 
 
1035
  gr.update(value=success_msg, visible=True),
1036
  )
1037
 
@@ -1059,6 +1414,15 @@ def load_settings_from_json_artistic(json_string: str):
1059
  gr.update(),
1060
  gr.update(),
1061
  gr.update(),
 
 
 
 
 
 
 
 
 
1062
  gr.update(value=error_msg, visible=True),
1063
  )
1064
  except Exception as e:
@@ -1085,6 +1449,15 @@ def load_settings_from_json_artistic(json_string: str):
1085
  gr.update(),
1086
  gr.update(),
1087
  gr.update(),
 
 
 
 
 
 
 
 
 
1088
  gr.update(value=error_msg, visible=True),
1089
  )
1090
 
@@ -1192,6 +1565,15 @@ def _pipeline_standard(
1192
  enable_upscale: bool = False,
1193
  controlnet_strength_first: float = 0.45,
1194
  controlnet_strength_final: float = 1.0,
 
 
 
 
 
 
 
 
 
1195
  gr_progress=None,
1196
  ):
1197
  emptylatentimage_5 = emptylatentimage.generate(
@@ -1387,7 +1769,9 @@ def _pipeline_standard(
1387
  if enable_upscale:
1388
  # Show pre-upscale result
1389
  pre_upscale_tensor = get_value_at_index(vaedecode_21, 0)
1390
- pre_upscale_np = (pre_upscale_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
 
 
1391
  pre_upscale_np = pre_upscale_np[0]
1392
  pre_upscale_pil = Image.fromarray(pre_upscale_np)
1393
  msg = "Enhancement complete (step 3/4)... upscaling image"
@@ -1405,6 +1789,18 @@ def _pipeline_standard(
1405
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
1406
  image_np = image_np[0]
1407
  pil_image = Image.fromarray(image_np)
 
 
 
 
 
 
 
 
 
 
 
 
1408
  msg = "No errors, all good! Final QR art generated and upscaled. (step 4/4)"
1409
  log_progress(msg, gr_progress, 1.0)
1410
  yield (pil_image, msg)
@@ -1414,6 +1810,18 @@ def _pipeline_standard(
1414
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
1415
  image_np = image_np[0]
1416
  pil_image = Image.fromarray(image_np)
 
 
 
 
 
 
 
 
 
 
 
 
1417
  msg = "No errors, all good! Final QR art generated."
1418
  log_progress(msg, gr_progress, 1.0)
1419
  yield pil_image, msg
@@ -1439,6 +1847,15 @@ def _pipeline_artistic(
1439
  sag_blur_sigma: float = 0.5,
1440
  controlnet_strength_first: float = 0.45,
1441
  controlnet_strength_final: float = 0.7,
 
 
 
 
 
 
 
 
 
1442
  gr_progress=None,
1443
  ):
1444
  # Generate QR code
@@ -1497,7 +1914,9 @@ def _pipeline_artistic(
1497
  )
1498
 
1499
  # Show the noisy QR so you can see the border cubic pattern effect
1500
- noisy_qr_np = (qr_with_border_noise.detach().cpu().numpy() * 255).astype(np.uint8)
 
 
1501
  noisy_qr_np = noisy_qr_np[0]
1502
  noisy_qr_pil = Image.fromarray(noisy_qr_np)
1503
  msg = f"Added QR-like cubics to border... enhancing with AI (step {current_step}/{total_steps})"
@@ -1693,7 +2112,9 @@ def _pipeline_artistic(
1693
  if enable_upscale:
1694
  # Show result before upscaling
1695
  pre_upscale_tensor = get_value_at_index(final_decoded, 0)
1696
- pre_upscale_np = (pre_upscale_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
 
 
1697
  pre_upscale_np = pre_upscale_np[0]
1698
  pre_upscale_pil = Image.fromarray(pre_upscale_np)
1699
  msg = f"Final refinement complete (step {current_step}/{total_steps})... upscaling image"
@@ -1713,6 +2134,18 @@ def _pipeline_artistic(
1713
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
1714
  image_np = image_np[0]
1715
  final_image = Image.fromarray(image_np)
 
 
 
 
 
 
 
 
 
 
 
 
1716
  msg = f"No errors, all good! Final artistic QR code generated and upscaled. (step {current_step}/{total_steps})"
1717
  log_progress(msg, gr_progress, 1.0)
1718
  yield (final_image, msg)
@@ -1722,10 +2155,23 @@ def _pipeline_artistic(
1722
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
1723
  image_np = image_np[0]
1724
  final_image = Image.fromarray(image_np)
 
 
 
 
 
 
 
 
 
 
 
 
1725
  msg = f"No errors, all good! Final artistic QR code generated. (step {current_step}/{total_steps})"
1726
  log_progress(msg, gr_progress, 1.0)
1727
  yield (final_image, msg)
1728
 
 
1729
  if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
1730
  # Call AOT compilation during startup (only on CUDA, not MPS)
1731
  # Must be called after module init but before Gradio app launch
@@ -1939,6 +2385,113 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
1939
  info="Enable upscaling with RealESRGAN for higher quality output (enabled by default for artistic pipeline)",
1940
  )
1941
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1942
  # Add seed controls for artistic QR
1943
  artistic_use_custom_seed = gr.Checkbox(
1944
  label="Use Custom Seed",
@@ -2094,6 +2647,15 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2094
  sag_blur_sigma,
2095
  controlnet_strength_first,
2096
  controlnet_strength_final,
 
 
 
 
 
 
 
 
 
2097
  ],
2098
  outputs=[
2099
  artistic_output_image,
@@ -2129,6 +2691,15 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2129
  sag_blur_sigma,
2130
  controlnet_strength_first,
2131
  controlnet_strength_final,
 
 
 
 
 
 
 
 
 
2132
  import_status_artistic,
2133
  ],
2134
  )
@@ -2729,6 +3300,109 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2729
  info="Enable FreeU quality enhancement (disabled by default for standard pipeline)",
2730
  )
2731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2732
  # Add seed controls
2733
  use_custom_seed = gr.Checkbox(
2734
  label="Use Custom Seed",
@@ -2811,6 +3485,15 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2811
  enable_freeu_standard,
2812
  controlnet_strength_standard_first,
2813
  controlnet_strength_standard_final,
 
 
 
 
 
 
 
 
 
2814
  ],
2815
  outputs=[
2816
  output_image,
@@ -2839,6 +3522,15 @@ if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2839
  enable_freeu_standard,
2840
  controlnet_strength_standard_first,
2841
  controlnet_strength_standard_final,
 
 
 
 
 
 
 
 
 
2842
  import_status_standard,
2843
  ],
2844
  )
 
15
  import numpy as np
16
  import spaces
17
  import torch
18
+ from huggingface_hub import hf_hub_download
19
+ from PIL import Image
20
 
21
  # ComfyUI imports (after HF hub downloads)
22
  from comfy import model_management
23
  from comfy.cli_args import args
24
  from comfy_extras.nodes_freelunch import FreeU_V2
 
 
25
 
26
  # Suppress torchsde floating-point precision warnings (cosmetic only, no functional impact)
27
  warnings.filterwarnings("ignore", message="Should have tb<=t1 but got")
 
358
 
359
  # Increase cache limit to handle batch size variations (CFG uses batch 1 and 2)
360
  import torch._dynamo.config
361
+
362
  torch._dynamo.config.cache_size_limit = 64 # Allow more cached graphs
363
 
364
  # Compile standard pipeline model (DreamShaper 3.32)
 
409
  TEST_SEED = 12345
410
 
411
  try:
 
412
  import torch.export
413
+ from spaces import aoti_apply, aoti_capture, aoti_compile
414
 
415
  print(" Attempting AOT compilation...\n")
416
 
 
422
 
423
  # Capture example run
424
  with aoti_capture(standard_model.model.diffusion_model) as call_standard:
425
+ list(
426
+ _pipeline_standard(
427
+ prompt=TEST_PROMPT,
428
+ qr_text=TEST_TEXT,
429
+ input_type="URL",
430
+ image_size=512,
431
+ border_size=4,
432
+ error_correction="Medium (15%)",
433
+ module_size=12,
434
+ module_drawer="Square",
435
+ seed=TEST_SEED,
436
+ enable_upscale=False,
437
+ controlnet_strength_first=1.5,
438
+ controlnet_strength_final=0.9,
439
+ )
440
+ )
441
 
442
  # Export and compile
443
  exported_standard = torch.export.export(
 
457
 
458
  # Capture example run
459
  with aoti_capture(artistic_model.model.diffusion_model) as call_artistic:
460
+ list(
461
+ _pipeline_artistic(
462
+ prompt=TEST_PROMPT,
463
+ qr_text=TEST_TEXT,
464
+ input_type="URL",
465
+ image_size=640,
466
+ border_size=4,
467
+ error_correction="Medium (15%)",
468
+ module_size=12,
469
+ module_drawer="Square",
470
+ seed=TEST_SEED,
471
+ enable_upscale=False,
472
+ controlnet_strength_first=1.5,
473
+ controlnet_strength_final=0.9,
474
+ freeu_b1=1.3,
475
+ freeu_b2=1.4,
476
+ freeu_s1=0.9,
477
+ freeu_s2=0.2,
478
+ enable_sag=True,
479
+ sag_scale=0.75,
480
+ sag_blur_sigma=2.0,
481
+ )
482
+ )
483
 
484
  # Export and compile
485
  exported_artistic = torch.export.export(
 
503
  _apply_torch_compile_optimizations()
504
 
505
  # Run warmup inference to trigger torch.compile compilation
506
+ print(
507
+ "🔥 Running warmup inference to compile models (this takes 2-3 minutes)..."
508
+ )
509
 
510
  try:
511
  # Warmup standard pipeline @ 512px
512
  print(" [1/2] Warming up standard pipeline...")
513
+ list(
514
+ _pipeline_standard(
515
+ prompt=TEST_PROMPT,
516
+ qr_text=TEST_TEXT,
517
+ input_type="URL",
518
+ image_size=512,
519
+ border_size=4,
520
+ error_correction="Medium (15%)",
521
+ module_size=12,
522
+ module_drawer="Square",
523
+ seed=TEST_SEED,
524
+ enable_upscale=False,
525
+ controlnet_strength_first=1.5,
526
+ controlnet_strength_final=0.9,
527
+ )
528
+ )
529
  print(" ✓ Standard pipeline compiled")
530
 
531
  # Warmup artistic pipeline @ 640px
532
  print(" [2/2] Warming up artistic pipeline...")
533
+ list(
534
+ _pipeline_artistic(
535
+ prompt=TEST_PROMPT,
536
+ qr_text=TEST_TEXT,
537
+ input_type="URL",
538
+ image_size=640,
539
+ border_size=4,
540
+ error_correction="Medium (15%)",
541
+ module_size=12,
542
+ module_drawer="Square",
543
+ seed=TEST_SEED,
544
+ enable_upscale=False,
545
+ controlnet_strength_first=1.5,
546
+ controlnet_strength_final=0.9,
547
+ freeu_b1=1.3,
548
+ freeu_b2=1.4,
549
+ freeu_s1=0.9,
550
+ freeu_s2=0.2,
551
+ enable_sag=True,
552
+ sag_scale=0.75,
553
+ sag_blur_sigma=2.0,
554
+ )
555
+ )
556
  print(" ✓ Artistic pipeline compiled")
557
 
558
+ print(
559
+ "\n✅ torch.compile warmup complete! Models ready for fast inference.\n"
560
+ )
561
  return True
562
 
563
  except Exception as warmup_error:
 
591
  controlnet_strength_final: float = 0.7,
592
  controlnet_strength_standard_first: float = 0.45,
593
  controlnet_strength_standard_final: float = 1.0,
594
+ enable_color_quantization: bool = False,
595
+ num_colors: int = 4,
596
+ color_1: str = "#000000",
597
+ color_2: str = "#FFFFFF",
598
+ color_3: str = "#FF0000",
599
+ color_4: str = "#00FF00",
600
+ apply_gradient_filter: bool = False,
601
+ gradient_strength: float = 0.3,
602
+ variation_steps: int = 5,
603
  progress=gr.Progress(),
604
  ):
605
  # Only manipulate the text if it's a URL input type
 
628
  enable_upscale,
629
  controlnet_strength_standard_first,
630
  controlnet_strength_standard_final,
631
+ enable_color_quantization,
632
+ num_colors,
633
+ color_1,
634
+ color_2,
635
+ color_3,
636
+ color_4,
637
+ apply_gradient_filter,
638
+ gradient_strength,
639
+ variation_steps,
640
  progress,
641
  )
642
  else: # artistic
 
660
  sag_blur_sigma,
661
  controlnet_strength_first,
662
  controlnet_strength_final,
663
+ enable_color_quantization,
664
+ num_colors,
665
+ color_1,
666
+ color_2,
667
+ color_3,
668
+ color_4,
669
+ apply_gradient_filter,
670
+ gradient_strength,
671
+ variation_steps,
672
  progress,
673
  )
674
 
675
 
676
+ def apply_color_quantization(
677
+ image: Image.Image,
678
+ colors: list[str],
679
+ num_colors: int = 4,
680
+ apply_gradients: bool = False,
681
+ gradient_strength: float = 0.3,
682
+ variation_steps: int = 5,
683
+ ) -> Image.Image:
684
+ """
685
+ Apply color quantization to an image using nearest-color mapping.
686
+ Optionally apply gradient filter for artistic effect while preserving QR scannability.
687
+
688
+ Args:
689
+ image: PIL Image to quantize
690
+ colors: List of hex color strings (e.g., ["#FF0000", "#00FF00", "#0000FF", "#FFFFFF"])
691
+ num_colors: Number of colors to use from the colors list (2-4)
692
+ apply_gradients: If True, create gradient variations around base colors
693
+ gradient_strength: How much brightness variation to allow (0.0-1.0), e.g. 0.3 = ±30%
694
+ variation_steps: Number of gradient steps for each color (1-10)
695
+
696
+ Returns:
697
+ Quantized PIL Image (with optional gradient effect)
698
+
699
+ Note:
700
+ When gradients are enabled, first 2 colors are always preserved (no gradients)
701
+ to maintain QR code scannability. Only colors 3-4 get gradient variations.
702
+ """
703
+ # Validate num_colors
704
+ if num_colors < 2:
705
+ num_colors = 2
706
+ if num_colors > len(colors):
707
+ num_colors = len(colors)
708
+
709
+ # Parse colors with error handling (supports both hex and rgba formats)
710
+ palette = []
711
+ for color_str in colors[:num_colors]:
712
+ try:
713
+ # Check if it's an rgba string (from Gradio ColorPicker)
714
+ if color_str.startswith("rgba("):
715
+ # Extract RGB values from "rgba(r, g, b, a)" format
716
+ rgb_part = color_str[5:-1] # Remove "rgba(" and ")"
717
+ values = [float(v.strip()) for v in rgb_part.split(",")]
718
+ r = int(values[0])
719
+ g = int(values[1])
720
+ b = int(values[2])
721
+ palette.append((r, g, b))
722
+ else:
723
+ # Assume hex format
724
+ color_hex = color_str.lstrip("#")
725
+ r = int(color_hex[0:2], 16)
726
+ g = int(color_hex[2:4], 16)
727
+ b = int(color_hex[4:6], 16)
728
+ palette.append((r, g, b))
729
+ except (ValueError, IndexError, AttributeError):
730
+ # Fallback to black for invalid colors
731
+ palette.append((0, 0, 0))
732
+
733
+ # Ensure at least 2 colors
734
+ if len(palette) < 2:
735
+ palette = [(0, 0, 0), (255, 255, 255)] # Default to black & white
736
+
737
+ # Convert PIL Image to numpy array
738
+ img_array = np.array(image)
739
+
740
+ # Handle RGBA images by converting to RGB
741
+ if img_array.shape[2] == 4:
742
+ img_array = img_array[:, :, :3]
743
+
744
+ h, w, c = img_array.shape
745
+ pixels = img_array.reshape(h * w, c).astype(np.float32)
746
+
747
+ # ============================================================
748
+ # GRADIENT FILTER MODE: Create gradient variations
749
+ # ============================================================
750
+ if apply_gradients:
751
+ # Always preserve first 2 colors (black/white for QR scannability)
752
+ preserve_colors = [0, 1]
753
+
754
+ # Create gradient palette
755
+ palette_with_gradients = []
756
+ color_family_map = [] # Track which base color each gradient belongs to
757
+
758
+ for base_idx, base_color in enumerate(palette):
759
+ r, g, b = base_color
760
+
761
+ # Check if this color should be preserved (no gradients)
762
+ if base_idx in preserve_colors:
763
+ # Keep this color pure - only add the base color once
764
+ palette_with_gradients.append((r, g, b))
765
+ color_family_map.append(base_idx)
766
+ else:
767
+ # Create variations from dark to light
768
+ for step in range(variation_steps):
769
+ # Calculate brightness multiplier
770
+ if variation_steps == 1:
771
+ multiplier = 1.0 # Only use base color when steps=1
772
+ else:
773
+ multiplier = 1.0 + gradient_strength * (
774
+ 2 * step / (variation_steps - 1) - 1
775
+ )
776
+
777
+ # Apply multiplier and clamp to valid range
778
+ varied_r = int(np.clip(r * multiplier, 0, 255))
779
+ varied_g = int(np.clip(g * multiplier, 0, 255))
780
+ varied_b = int(np.clip(b * multiplier, 0, 255))
781
+
782
+ palette_with_gradients.append((varied_r, varied_g, varied_b))
783
+ color_family_map.append(base_idx)
784
+
785
+ gradient_palette_array = np.array(palette_with_gradients, dtype=np.float32)
786
+ base_palette_array = np.array(palette, dtype=np.float32)
787
+
788
+ # Calculate original pixel brightness for gradient selection
789
+ pixel_brightness = np.mean(pixels, axis=1)
790
+
791
+ # Step 1: Find nearest BASE color for each pixel
792
+ distances_to_base = np.sqrt(
793
+ np.sum((pixels[:, None, :] - base_palette_array[None, :, :]) ** 2, axis=2)
794
+ )
795
+ nearest_base_idx = np.argmin(distances_to_base, axis=1)
796
+
797
+ # Step 2: Fully vectorized gradient assignment
798
+ # Create mapping from base color index to gradient range
799
+ gradient_ranges = {}
800
+ for base_idx in range(len(palette)):
801
+ family_indices = [
802
+ i for i, fam in enumerate(color_family_map) if fam == base_idx
803
+ ]
804
+ gradient_ranges[base_idx] = np.array(family_indices)
805
+
806
+ # Initialize result
807
+ result_indices = np.zeros(len(pixels), dtype=int)
808
+
809
+ # For each base color family, compute gradient indices
810
+ for base_idx in range(len(palette)):
811
+ mask = nearest_base_idx == base_idx
812
+ if not np.any(mask):
813
+ continue
814
+
815
+ family_indices = gradient_ranges[base_idx]
816
+ masked_brightness = pixel_brightness[mask]
817
+
818
+ # Normalize brightness within this family
819
+ min_b, max_b = masked_brightness.min(), masked_brightness.max()
820
+ if max_b > min_b:
821
+ norm_bright = (masked_brightness - min_b) / (max_b - min_b)
822
+ else:
823
+ norm_bright = np.full(len(masked_brightness), 0.5)
824
+
825
+ # Map to gradient steps
826
+ steps = (norm_bright * (len(family_indices) - 1)).astype(int)
827
+ steps = np.clip(steps, 0, len(family_indices) - 1)
828
+
829
+ # Assign palette indices
830
+ result_indices[mask] = family_indices[steps]
831
+
832
+ # Final color assignment
833
+ result_pixels = gradient_palette_array[result_indices].astype(np.uint8)
834
+ quantized_image = result_pixels.reshape(h, w, c)
835
+
836
+ # ============================================================
837
+ # STRICT QUANTIZATION MODE: No gradients
838
+ # ============================================================
839
+ else:
840
+ # Convert palette to numpy array
841
+ palette_array = np.array(palette, dtype=np.uint8)
842
+
843
+ # Calculate Euclidean distance from each pixel to each palette color
844
+ distances = np.sqrt(
845
+ np.sum(
846
+ (pixels[:, None, :] - palette_array[None, :, :].astype(np.float32))
847
+ ** 2,
848
+ axis=2,
849
+ )
850
+ )
851
+
852
+ # Find index of nearest color for each pixel
853
+ nearest_indices = np.argmin(distances, axis=1)
854
+
855
+ # Map each pixel to its nearest palette color
856
+ quantized = palette_array[nearest_indices]
857
+
858
+ # Reshape back to image dimensions
859
+ quantized_image = quantized.reshape(h, w, c).astype(np.uint8)
860
+
861
+ # Convert back to PIL Image
862
+ return Image.fromarray(quantized_image)
863
+
864
+
865
  def generate_standard_qr(
866
  prompt: str,
867
  text_input: str,
 
877
  enable_freeu: bool = False,
878
  controlnet_strength_standard_first: float = 0.45,
879
  controlnet_strength_standard_final: float = 1.0,
880
+ enable_color_quantization: bool = False,
881
+ num_colors: int = 4,
882
+ color_1: str = "#000000",
883
+ color_2: str = "#FFFFFF",
884
+ color_3: str = "#FF0000",
885
+ color_4: str = "#00FF00",
886
+ apply_gradient_filter: bool = False,
887
+ gradient_strength: float = 0.3,
888
+ variation_steps: int = 5,
889
  progress=gr.Progress(),
890
  ):
891
  """Wrapper function for standard QR generation"""
 
909
  "enable_freeu": enable_freeu,
910
  "controlnet_strength_standard_first": controlnet_strength_standard_first,
911
  "controlnet_strength_standard_final": controlnet_strength_standard_final,
912
+ "enable_color_quantization": enable_color_quantization,
913
+ "num_colors": num_colors,
914
+ "color_1": color_1,
915
+ "color_2": color_2,
916
+ "color_3": color_3,
917
+ "color_4": color_4,
918
+ "apply_gradient_filter": apply_gradient_filter,
919
+ "gradient_strength": gradient_strength,
920
+ "variation_steps": variation_steps,
921
  }
922
  settings_json = generate_settings_json(settings_dict)
923
 
 
937
  enable_upscale=enable_upscale,
938
  controlnet_strength_standard_first=controlnet_strength_standard_first,
939
  controlnet_strength_standard_final=controlnet_strength_standard_final,
940
+ enable_color_quantization=enable_color_quantization,
941
+ num_colors=num_colors,
942
+ color_1=color_1,
943
+ color_2=color_2,
944
+ color_3=color_3,
945
+ color_4=color_4,
946
+ apply_gradient_filter=apply_gradient_filter,
947
+ gradient_strength=gradient_strength,
948
+ variation_steps=variation_steps,
949
  progress=progress,
950
  )
951
 
 
990
  sag_blur_sigma: float = 0.5,
991
  controlnet_strength_first: float = 0.45,
992
  controlnet_strength_final: float = 0.70,
993
+ enable_color_quantization: bool = False,
994
+ num_colors: int = 4,
995
+ color_1: str = "#000000",
996
+ color_2: str = "#FFFFFF",
997
+ color_3: str = "#FF0000",
998
+ color_4: str = "#00FF00",
999
+ apply_gradient_filter: bool = False,
1000
+ gradient_strength: float = 0.3,
1001
+ variation_steps: int = 5,
1002
  progress=gr.Progress(),
1003
  ):
1004
  """Wrapper function for artistic QR generation with FreeU and SAG parameters"""
 
1029
  "sag_blur_sigma": sag_blur_sigma,
1030
  "controlnet_strength_first": controlnet_strength_first,
1031
  "controlnet_strength_final": controlnet_strength_final,
1032
+ "enable_color_quantization": enable_color_quantization,
1033
+ "num_colors": num_colors,
1034
+ "color_1": color_1,
1035
+ "color_2": color_2,
1036
+ "color_3": color_3,
1037
+ "color_4": color_4,
1038
+ "apply_gradient_filter": apply_gradient_filter,
1039
+ "gradient_strength": gradient_strength,
1040
+ "variation_steps": variation_steps,
1041
  }
1042
  settings_json = generate_settings_json(settings_dict)
1043
 
 
1064
  sag_blur_sigma=sag_blur_sigma,
1065
  controlnet_strength_first=controlnet_strength_first,
1066
  controlnet_strength_final=controlnet_strength_final,
1067
+ enable_color_quantization=enable_color_quantization,
1068
+ num_colors=num_colors,
1069
+ color_1=color_1,
1070
+ color_2=color_2,
1071
+ color_3=color_3,
1072
+ color_4=color_4,
1073
+ apply_gradient_filter=apply_gradient_filter,
1074
+ gradient_strength=gradient_strength,
1075
+ variation_steps=variation_steps,
1076
  progress=progress,
1077
  )
1078
 
 
1150
  gr.update(),
1151
  gr.update(),
1152
  gr.update(),
1153
+ gr.update(),
1154
+ gr.update(),
1155
+ gr.update(),
1156
+ gr.update(),
1157
+ gr.update(),
1158
+ gr.update(),
1159
+ gr.update(),
1160
+ gr.update(),
1161
+ gr.update(),
1162
  gr.update(value=error_msg, visible=True),
1163
  )
1164
 
 
1181
  controlnet_strength_standard_final = params.get(
1182
  "controlnet_strength_standard_final", 1.0
1183
  )
1184
+ enable_color_quantization = params.get("enable_color_quantization", False)
1185
+ num_colors = params.get("num_colors", 4)
1186
+ color_1 = params.get("color_1", "#000000")
1187
+ color_2 = params.get("color_2", "#FFFFFF")
1188
+ color_3 = params.get("color_3", "#FF0000")
1189
+ color_4 = params.get("color_4", "#00FF00")
1190
+ apply_gradient_filter = params.get("apply_gradient_filter", False)
1191
+ gradient_strength = params.get("gradient_strength", 0.3)
1192
+ variation_steps = params.get("variation_steps", 5)
1193
 
1194
  success_msg = "✅ Settings loaded successfully!"
1195
  return (
 
1207
  enable_freeu,
1208
  controlnet_strength_standard_first,
1209
  controlnet_strength_standard_final,
1210
+ enable_color_quantization,
1211
+ num_colors,
1212
+ color_1,
1213
+ color_2,
1214
+ color_3,
1215
+ color_4,
1216
+ apply_gradient_filter,
1217
+ gradient_strength,
1218
+ variation_steps,
1219
  gr.update(value=success_msg, visible=True),
1220
  )
1221
 
 
1236
  gr.update(),
1237
  gr.update(),
1238
  gr.update(),
1239
+ gr.update(),
1240
+ gr.update(),
1241
+ gr.update(),
1242
+ gr.update(),
1243
+ gr.update(),
1244
+ gr.update(),
1245
+ gr.update(),
1246
+ gr.update(),
1247
+ gr.update(),
1248
  gr.update(value=error_msg, visible=True),
1249
  )
1250
  except Exception as e:
 
1264
  gr.update(),
1265
  gr.update(),
1266
  gr.update(),
1267
+ gr.update(),
1268
+ gr.update(),
1269
+ gr.update(),
1270
+ gr.update(),
1271
+ gr.update(),
1272
+ gr.update(),
1273
+ gr.update(),
1274
+ gr.update(),
1275
+ gr.update(),
1276
  gr.update(value=error_msg, visible=True),
1277
  )
1278
 
 
1311
  gr.update(),
1312
  gr.update(),
1313
  gr.update(),
1314
+ gr.update(),
1315
+ gr.update(),
1316
+ gr.update(),
1317
+ gr.update(),
1318
+ gr.update(),
1319
+ gr.update(),
1320
+ gr.update(),
1321
+ gr.update(),
1322
+ gr.update(),
1323
  gr.update(value=error_msg, visible=True),
1324
  )
1325
 
 
1345
  sag_blur_sigma = params.get("sag_blur_sigma", 0.5)
1346
  controlnet_strength_first = params.get("controlnet_strength_first", 0.45)
1347
  controlnet_strength_final = params.get("controlnet_strength_final", 0.7)
1348
+ enable_color_quantization = params.get("enable_color_quantization", False)
1349
+ num_colors = params.get("num_colors", 4)
1350
+ color_1 = params.get("color_1", "#000000")
1351
+ color_2 = params.get("color_2", "#FFFFFF")
1352
+ color_3 = params.get("color_3", "#FF0000")
1353
+ color_4 = params.get("color_4", "#00FF00")
1354
+ apply_gradient_filter = params.get("apply_gradient_filter", False)
1355
+ gradient_strength = params.get("gradient_strength", 0.3)
1356
+ variation_steps = params.get("variation_steps", 5)
1357
 
1358
  success_msg = "✅ Settings loaded successfully!"
1359
  return (
 
1378
  sag_blur_sigma,
1379
  controlnet_strength_first,
1380
  controlnet_strength_final,
1381
+ enable_color_quantization,
1382
+ num_colors,
1383
+ color_1,
1384
+ color_2,
1385
+ color_3,
1386
+ color_4,
1387
+ apply_gradient_filter,
1388
+ gradient_strength,
1389
+ variation_steps,
1390
  gr.update(value=success_msg, visible=True),
1391
  )
1392
 
 
1414
  gr.update(),
1415
  gr.update(),
1416
  gr.update(),
1417
+ gr.update(),
1418
+ gr.update(),
1419
+ gr.update(),
1420
+ gr.update(),
1421
+ gr.update(),
1422
+ gr.update(),
1423
+ gr.update(),
1424
+ gr.update(),
1425
+ gr.update(),
1426
  gr.update(value=error_msg, visible=True),
1427
  )
1428
  except Exception as e:
 
1449
  gr.update(),
1450
  gr.update(),
1451
  gr.update(),
1452
+ gr.update(),
1453
+ gr.update(),
1454
+ gr.update(),
1455
+ gr.update(),
1456
+ gr.update(),
1457
+ gr.update(),
1458
+ gr.update(),
1459
+ gr.update(),
1460
+ gr.update(),
1461
  gr.update(value=error_msg, visible=True),
1462
  )
1463
 
 
1565
  enable_upscale: bool = False,
1566
  controlnet_strength_first: float = 0.45,
1567
  controlnet_strength_final: float = 1.0,
1568
+ enable_color_quantization: bool = False,
1569
+ num_colors: int = 4,
1570
+ color_1: str = "#000000",
1571
+ color_2: str = "#FFFFFF",
1572
+ color_3: str = "#FF0000",
1573
+ color_4: str = "#00FF00",
1574
+ apply_gradient_filter: bool = False,
1575
+ gradient_strength: float = 0.3,
1576
+ variation_steps: int = 5,
1577
  gr_progress=None,
1578
  ):
1579
  emptylatentimage_5 = emptylatentimage.generate(
 
1769
  if enable_upscale:
1770
  # Show pre-upscale result
1771
  pre_upscale_tensor = get_value_at_index(vaedecode_21, 0)
1772
+ pre_upscale_np = (pre_upscale_tensor.detach().cpu().numpy() * 255).astype(
1773
+ np.uint8
1774
+ )
1775
  pre_upscale_np = pre_upscale_np[0]
1776
  pre_upscale_pil = Image.fromarray(pre_upscale_np)
1777
  msg = "Enhancement complete (step 3/4)... upscaling image"
 
1789
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
1790
  image_np = image_np[0]
1791
  pil_image = Image.fromarray(image_np)
1792
+
1793
+ # Apply color quantization if enabled
1794
+ if enable_color_quantization:
1795
+ pil_image = apply_color_quantization(
1796
+ pil_image,
1797
+ colors=[color_1, color_2, color_3, color_4],
1798
+ num_colors=num_colors,
1799
+ apply_gradients=apply_gradient_filter,
1800
+ gradient_strength=gradient_strength,
1801
+ variation_steps=variation_steps,
1802
+ )
1803
+
1804
  msg = "No errors, all good! Final QR art generated and upscaled. (step 4/4)"
1805
  log_progress(msg, gr_progress, 1.0)
1806
  yield (pil_image, msg)
 
1810
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
1811
  image_np = image_np[0]
1812
  pil_image = Image.fromarray(image_np)
1813
+
1814
+ # Apply color quantization if enabled
1815
+ if enable_color_quantization:
1816
+ pil_image = apply_color_quantization(
1817
+ pil_image,
1818
+ colors=[color_1, color_2, color_3, color_4],
1819
+ num_colors=num_colors,
1820
+ apply_gradients=apply_gradient_filter,
1821
+ gradient_strength=gradient_strength,
1822
+ variation_steps=variation_steps,
1823
+ )
1824
+
1825
  msg = "No errors, all good! Final QR art generated."
1826
  log_progress(msg, gr_progress, 1.0)
1827
  yield pil_image, msg
 
1847
  sag_blur_sigma: float = 0.5,
1848
  controlnet_strength_first: float = 0.45,
1849
  controlnet_strength_final: float = 0.7,
1850
+ enable_color_quantization: bool = False,
1851
+ num_colors: int = 4,
1852
+ color_1: str = "#000000",
1853
+ color_2: str = "#FFFFFF",
1854
+ color_3: str = "#FF0000",
1855
+ color_4: str = "#00FF00",
1856
+ apply_gradient_filter: bool = False,
1857
+ gradient_strength: float = 0.3,
1858
+ variation_steps: int = 5,
1859
  gr_progress=None,
1860
  ):
1861
  # Generate QR code
 
1914
  )
1915
 
1916
  # Show the noisy QR so you can see the border cubic pattern effect
1917
+ noisy_qr_np = (qr_with_border_noise.detach().cpu().numpy() * 255).astype(
1918
+ np.uint8
1919
+ )
1920
  noisy_qr_np = noisy_qr_np[0]
1921
  noisy_qr_pil = Image.fromarray(noisy_qr_np)
1922
  msg = f"Added QR-like cubics to border... enhancing with AI (step {current_step}/{total_steps})"
 
2112
  if enable_upscale:
2113
  # Show result before upscaling
2114
  pre_upscale_tensor = get_value_at_index(final_decoded, 0)
2115
+ pre_upscale_np = (pre_upscale_tensor.detach().cpu().numpy() * 255).astype(
2116
+ np.uint8
2117
+ )
2118
  pre_upscale_np = pre_upscale_np[0]
2119
  pre_upscale_pil = Image.fromarray(pre_upscale_np)
2120
  msg = f"Final refinement complete (step {current_step}/{total_steps})... upscaling image"
 
2134
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2135
  image_np = image_np[0]
2136
  final_image = Image.fromarray(image_np)
2137
+
2138
+ # Apply color quantization if enabled
2139
+ if enable_color_quantization:
2140
+ final_image = apply_color_quantization(
2141
+ final_image,
2142
+ colors=[color_1, color_2, color_3, color_4],
2143
+ num_colors=num_colors,
2144
+ apply_gradients=apply_gradient_filter,
2145
+ gradient_strength=gradient_strength,
2146
+ variation_steps=variation_steps,
2147
+ )
2148
+
2149
  msg = f"No errors, all good! Final artistic QR code generated and upscaled. (step {current_step}/{total_steps})"
2150
  log_progress(msg, gr_progress, 1.0)
2151
  yield (final_image, msg)
 
2155
  image_np = (image_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
2156
  image_np = image_np[0]
2157
  final_image = Image.fromarray(image_np)
2158
+
2159
+ # Apply color quantization if enabled
2160
+ if enable_color_quantization:
2161
+ final_image = apply_color_quantization(
2162
+ final_image,
2163
+ colors=[color_1, color_2, color_3, color_4],
2164
+ num_colors=num_colors,
2165
+ apply_gradients=apply_gradient_filter,
2166
+ gradient_strength=gradient_strength,
2167
+ variation_steps=variation_steps,
2168
+ )
2169
+
2170
  msg = f"No errors, all good! Final artistic QR code generated. (step {current_step}/{total_steps})"
2171
  log_progress(msg, gr_progress, 1.0)
2172
  yield (final_image, msg)
2173
 
2174
+
2175
  if __name__ == "__main__" and not os.environ.get("QR_TESTING_MODE"):
2176
  # Call AOT compilation during startup (only on CUDA, not MPS)
2177
  # Must be called after module init but before Gradio app launch
 
2385
  info="Enable upscaling with RealESRGAN for higher quality output (enabled by default for artistic pipeline)",
2386
  )
2387
 
2388
+ # Color Quantization Section
2389
+ gr.Markdown("### Color Quantization (Optional)")
2390
+ artistic_enable_color_quantization = gr.Checkbox(
2391
+ label="Enable Color Quantization",
2392
+ value=False,
2393
+ info="Apply a custom color palette to the generated image",
2394
+ )
2395
+
2396
+ artistic_num_colors = gr.Slider(
2397
+ minimum=2,
2398
+ maximum=4,
2399
+ step=1,
2400
+ value=4,
2401
+ label="Number of Colors",
2402
+ info="How many colors to use from the palette (2-4)",
2403
+ visible=False,
2404
+ )
2405
+
2406
+ # Colors 1 & 2 (QR code colors - hidden when gradient enabled)
2407
+ with gr.Row(
2408
+ visible=False
2409
+ ) as artistic_color_pickers_row_1_2:
2410
+ artistic_color_1 = gr.ColorPicker(
2411
+ label="Color 1 (QR Dark)",
2412
+ value="#000000",
2413
+ info="Preserved when using gradients",
2414
+ )
2415
+ artistic_color_2 = gr.ColorPicker(
2416
+ label="Color 2 (QR Light)",
2417
+ value="#FFFFFF",
2418
+ info="Preserved when using gradients",
2419
+ )
2420
+
2421
+ # Colors 3 & 4 (Background colors - always editable)
2422
+ with gr.Row(
2423
+ visible=False
2424
+ ) as artistic_color_pickers_row_3_4:
2425
+ artistic_color_3 = gr.ColorPicker(
2426
+ label="Color 3 (Background)", value="#FF0000"
2427
+ )
2428
+ artistic_color_4 = gr.ColorPicker(
2429
+ label="Color 4 (Background)", value="#00FF00"
2430
+ )
2431
+
2432
+ # Gradient Filter Section (nested under color quantization)
2433
+ artistic_apply_gradient_filter = gr.Checkbox(
2434
+ label="Apply Gradient Filter",
2435
+ value=False,
2436
+ visible=False,
2437
+ elem_id="artistic_gradient_checkbox",
2438
+ info="Create gradient variations around colors 3-4 while preserving colors 1-2 for QR scannability",
2439
+ )
2440
+
2441
+ artistic_gradient_strength = gr.Slider(
2442
+ minimum=0.1,
2443
+ maximum=1.0,
2444
+ step=0.1,
2445
+ value=0.3,
2446
+ label="Gradient Strength",
2447
+ info="Brightness variation (0.3 = ±30%)",
2448
+ visible=False,
2449
+ )
2450
+
2451
+ artistic_variation_steps = gr.Slider(
2452
+ minimum=1,
2453
+ maximum=10,
2454
+ step=1,
2455
+ value=5,
2456
+ label="Variation Steps",
2457
+ info="Number of gradient steps (higher = smoother)",
2458
+ visible=False,
2459
+ )
2460
+
2461
+ # Visibility toggle for gradient filter
2462
+ artistic_apply_gradient_filter.change(
2463
+ fn=lambda gradient_enabled: (
2464
+ gr.update(visible=gradient_enabled),
2465
+ gr.update(visible=gradient_enabled),
2466
+ gr.update(
2467
+ visible=not gradient_enabled
2468
+ ), # Hide colors 1&2 when gradient ON
2469
+ ),
2470
+ inputs=[artistic_apply_gradient_filter],
2471
+ outputs=[
2472
+ artistic_gradient_strength,
2473
+ artistic_variation_steps,
2474
+ artistic_color_pickers_row_1_2,
2475
+ ],
2476
+ )
2477
+
2478
+ # Visibility toggle for color quantization
2479
+ artistic_enable_color_quantization.change(
2480
+ fn=lambda enabled: (
2481
+ gr.update(visible=enabled),
2482
+ gr.update(visible=enabled),
2483
+ gr.update(visible=enabled),
2484
+ gr.update(visible=enabled),
2485
+ ),
2486
+ inputs=[artistic_enable_color_quantization],
2487
+ outputs=[
2488
+ artistic_num_colors,
2489
+ artistic_color_pickers_row_1_2,
2490
+ artistic_color_pickers_row_3_4,
2491
+ artistic_apply_gradient_filter,
2492
+ ],
2493
+ )
2494
+
2495
  # Add seed controls for artistic QR
2496
  artistic_use_custom_seed = gr.Checkbox(
2497
  label="Use Custom Seed",
 
2647
  sag_blur_sigma,
2648
  controlnet_strength_first,
2649
  controlnet_strength_final,
2650
+ artistic_enable_color_quantization,
2651
+ artistic_num_colors,
2652
+ artistic_color_1,
2653
+ artistic_color_2,
2654
+ artistic_color_3,
2655
+ artistic_color_4,
2656
+ artistic_apply_gradient_filter,
2657
+ artistic_gradient_strength,
2658
+ artistic_variation_steps,
2659
  ],
2660
  outputs=[
2661
  artistic_output_image,
 
2691
  sag_blur_sigma,
2692
  controlnet_strength_first,
2693
  controlnet_strength_final,
2694
+ artistic_enable_color_quantization,
2695
+ artistic_num_colors,
2696
+ artistic_color_1,
2697
+ artistic_color_2,
2698
+ artistic_color_3,
2699
+ artistic_color_4,
2700
+ artistic_apply_gradient_filter,
2701
+ artistic_gradient_strength,
2702
+ artistic_variation_steps,
2703
  import_status_artistic,
2704
  ],
2705
  )
 
3300
  info="Enable FreeU quality enhancement (disabled by default for standard pipeline)",
3301
  )
3302
 
3303
+ # Color Quantization Section
3304
+ gr.Markdown("### Color Quantization (Optional)")
3305
+ enable_color_quantization = gr.Checkbox(
3306
+ label="Enable Color Quantization",
3307
+ value=False,
3308
+ info="Apply a custom color palette to the generated image",
3309
+ )
3310
+
3311
+ num_colors = gr.Slider(
3312
+ minimum=2,
3313
+ maximum=4,
3314
+ step=1,
3315
+ value=4,
3316
+ label="Number of Colors",
3317
+ info="How many colors to use from the palette (2-4)",
3318
+ visible=False,
3319
+ )
3320
+
3321
+ # Colors 1 & 2 (QR code colors - hidden when gradient enabled)
3322
+ with gr.Row(visible=False) as color_pickers_row_1_2:
3323
+ color_1 = gr.ColorPicker(
3324
+ label="Color 1 (QR Dark)",
3325
+ value="#000000",
3326
+ info="Preserved when using gradients",
3327
+ )
3328
+ color_2 = gr.ColorPicker(
3329
+ label="Color 2 (QR Light)",
3330
+ value="#FFFFFF",
3331
+ info="Preserved when using gradients",
3332
+ )
3333
+
3334
+ # Colors 3 & 4 (Background colors - always editable)
3335
+ with gr.Row(visible=False) as color_pickers_row_3_4:
3336
+ color_3 = gr.ColorPicker(
3337
+ label="Color 3 (Background)", value="#FF0000"
3338
+ )
3339
+ color_4 = gr.ColorPicker(
3340
+ label="Color 4 (Background)", value="#00FF00"
3341
+ )
3342
+
3343
+ # Gradient Filter Section (nested under color quantization)
3344
+ apply_gradient_filter = gr.Checkbox(
3345
+ label="Apply Gradient Filter",
3346
+ value=False,
3347
+ visible=False,
3348
+ elem_id="gradient_checkbox",
3349
+ info="Create gradient variations around colors 3-4 while preserving colors 1-2 for QR scannability",
3350
+ )
3351
+
3352
+ gradient_strength = gr.Slider(
3353
+ minimum=0.1,
3354
+ maximum=1.0,
3355
+ step=0.1,
3356
+ value=0.3,
3357
+ label="Gradient Strength",
3358
+ info="Brightness variation (0.3 = ±30%)",
3359
+ visible=False,
3360
+ )
3361
+
3362
+ variation_steps = gr.Slider(
3363
+ minimum=1,
3364
+ maximum=10,
3365
+ step=1,
3366
+ value=5,
3367
+ label="Variation Steps",
3368
+ info="Number of gradient steps (higher = smoother)",
3369
+ visible=False,
3370
+ )
3371
+
3372
+ # Visibility toggle for gradient filter
3373
+ apply_gradient_filter.change(
3374
+ fn=lambda gradient_enabled: (
3375
+ gr.update(visible=gradient_enabled),
3376
+ gr.update(visible=gradient_enabled),
3377
+ gr.update(
3378
+ visible=not gradient_enabled
3379
+ ), # Hide colors 1&2 when gradient ON
3380
+ ),
3381
+ inputs=[apply_gradient_filter],
3382
+ outputs=[
3383
+ gradient_strength,
3384
+ variation_steps,
3385
+ color_pickers_row_1_2,
3386
+ ],
3387
+ )
3388
+
3389
+ # Visibility toggle for color quantization
3390
+ enable_color_quantization.change(
3391
+ fn=lambda enabled: (
3392
+ gr.update(visible=enabled),
3393
+ gr.update(visible=enabled),
3394
+ gr.update(visible=enabled),
3395
+ gr.update(visible=enabled),
3396
+ ),
3397
+ inputs=[enable_color_quantization],
3398
+ outputs=[
3399
+ num_colors,
3400
+ color_pickers_row_1_2,
3401
+ color_pickers_row_3_4,
3402
+ apply_gradient_filter,
3403
+ ],
3404
+ )
3405
+
3406
  # Add seed controls
3407
  use_custom_seed = gr.Checkbox(
3408
  label="Use Custom Seed",
 
3485
  enable_freeu_standard,
3486
  controlnet_strength_standard_first,
3487
  controlnet_strength_standard_final,
3488
+ enable_color_quantization,
3489
+ num_colors,
3490
+ color_1,
3491
+ color_2,
3492
+ color_3,
3493
+ color_4,
3494
+ apply_gradient_filter,
3495
+ gradient_strength,
3496
+ variation_steps,
3497
  ],
3498
  outputs=[
3499
  output_image,
 
3522
  enable_freeu_standard,
3523
  controlnet_strength_standard_first,
3524
  controlnet_strength_standard_final,
3525
+ enable_color_quantization,
3526
+ num_colors,
3527
+ color_1,
3528
+ color_2,
3529
+ color_3,
3530
+ color_4,
3531
+ apply_gradient_filter,
3532
+ gradient_strength,
3533
+ variation_steps,
3534
  import_status_standard,
3535
  ],
3536
  )