# DiffQRCode / diffqrcoder_wrapper.py
# Last commit: 1548bbe — "Changed starting params." (author: sayshara)
# diffqrcoder_wrapper.py
import torch
from diffusers import ControlNetModel, DDIMScheduler
from PIL import Image
import qrcode
from huggingface_hub import hf_hub_download
from diffqrcoder import DiffQRCoderPipeline
# ---- Defaults taken from run_diffqrcoder.py ---- #
CONTROLNET_CKPT = "monster-labs/control_v1p_sd15_qrcode_monster"
PIPE_REPO_ID = "fp16-guy/Cetus-Mix_Whalefall_fp16_cleaned"
PIPE_FILENAME = "cetusMix_Whalefall2_fp16.safetensors"
DEVICE = "cuda"
_controlnet = None
_pipe = None
def _make_qr_image(
    data: str,
    box_size: int = 20,
    border: int = 4,
) -> Image.Image:
    """Render *data* as a black-on-white RGB QR code image.

    Error correction is pinned at the highest level (H) so downstream
    diffusion edits can perturb pixels while the code stays scannable;
    ``fit=True`` lets the library choose the smallest QR version that
    holds the payload.

    Args:
        data: Payload to encode (URL or arbitrary text).
        box_size: Pixel size of a single QR module.
        border: Quiet-zone width, in modules.

    Returns:
        The rendered QR code as an RGB ``PIL.Image.Image``.
    """
    code = qrcode.QRCode(
        version=None,  # auto-select smallest version that fits
        error_correction=qrcode.constants.ERROR_CORRECT_H,
        box_size=box_size,
        border=border,
    )
    code.add_data(data)
    code.make(fit=True)
    rendered = code.make_image(fill_color="black", back_color="white")
    return rendered.convert("RGB")
def load_pipeline():
    """Lazily build and cache the ControlNet + DiffQRCoderPipeline pair.

    Heavy objects are kept in the module-level globals ``_controlnet`` and
    ``_pipe`` so repeated calls (e.g. per-request from a web handler)
    reuse the already-constructed pipeline instead of reloading weights.

    Returns:
        DiffQRCoderPipeline: the cached pipeline. It is left on CPU here;
        callers are responsible for any device placement.
    """
    global _controlnet, _pipe
    if _pipe is not None:
        return _pipe

    print("πŸ”§ Loading ControlNet...")
    if _controlnet is None:
        _controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_CKPT,
            torch_dtype=torch.float16,
        )
    print("βœ… ControlNet loaded.")

    print("πŸ”§ Downloading base model checkpoint from Hub...")
    # NOTE(review): `local_dir_use_symlinks` is deprecated (and ignored) in
    # recent huggingface_hub releases — confirm the pinned version still
    # honors it if a real copy in ./models is required.
    ckpt_path = hf_hub_download(
        repo_id=PIPE_REPO_ID,
        filename=PIPE_FILENAME,
        local_dir="models",
        local_dir_use_symlinks=False,
    )
    print("βœ… Base model checkpoint at:", ckpt_path)

    print("πŸ”§ Building DiffQRCoderPipeline from checkpoint...")
    # NOTE(review): `use_auth_token` is deprecated in newer diffusers in
    # favor of `token` — verify against the pinned diffusers version.
    pipe = DiffQRCoderPipeline.from_single_file(
        ckpt_path,
        controlnet=_controlnet,
        torch_dtype=torch.float16,
        use_auth_token=True,  # uses the Space's HF token
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Memory helper: xFormers attention is optional, so failure to enable it
    # is logged and ignored rather than aborting the load. (The previous
    # doubly-nested try/except was redundant: the inner handler already
    # caught Exception, leaving the outer handler dead code.)
    try:
        pipe.enable_xformers_memory_efficient_attention()
        print("βœ… xFormers attention enabled.")
    except Exception as e:
        print("⚠️ xFormers not available:", repr(e))

    print("βœ… Pipeline constructed on CPU.")
    _pipe = pipe
    return _pipe
def generate_qr_art(
    pipe: DiffQRCoderPipeline,
    url_or_text: str,
    prompt: str,
    neg_prompt: str = "harsh edges, high contrast QR blockiness, noise, muddy colors, ugly, disfigured, low quality, blurry, nsfw",
    num_inference_steps: int = 14,  # gentler default
    qrcode_module_size: int = 40,
    qrcode_padding: int = 78,
    controlnet_conditioning_scale: float = 1.35,
    scanning_robust_guidance_scale: float = 300.0,  # softer default
    perceptual_guidance_scale: float = 2.0,
    srmpgd_num_iteration: int | None = 0,  # 0 = disable SR-MPGD by default
    srmpgd_lr: float = 0.1,
    seed: int = 42,
) -> Image.Image:
    """Generate a stylized, scannable QR image via DiffQRCoderPipeline.

    Renders *url_or_text* as a plain QR code, then runs the diffusion
    pipeline conditioned on it with the given prompt and guidance scales.

    Args:
        pipe: A pipeline previously built by ``load_pipeline()``.
        url_or_text: Payload to encode in the QR code.
        prompt: Positive text prompt for the diffusion model.
        neg_prompt: Negative prompt steering away from artifacts.
        num_inference_steps: Number of denoising steps.
        qrcode_module_size: Pixel size of one QR module (also forwarded to
            the pipeline so it knows the module grid).
        qrcode_padding: Padding (in pixels) forwarded to the pipeline.
        controlnet_conditioning_scale: Strength of the ControlNet signal.
        scanning_robust_guidance_scale: Scan-robustness guidance weight.
        perceptual_guidance_scale: Perceptual guidance weight.
        srmpgd_num_iteration: SR-MPGD post-optimization iterations
            (0 disables it).
        srmpgd_lr: Learning rate for SR-MPGD.
        seed: RNG seed for reproducible generations.

    Returns:
        The first generated image.

    Raises:
        RuntimeError: If *pipe* is None (pipeline not loaded yet).
    """
    # `assert` is stripped under `python -O`, so validate explicitly.
    if pipe is None:
        raise RuntimeError("Pipeline must be loaded before calling generate_qr_art")
    print("✨ generate_qr_art() starting...")
    # NOTE(review): seeds the generator on DEVICE ("cuda") — assumes a GPU
    # is present; confirm for CPU-only deployments.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    qrcode_img = _make_qr_image(
        data=url_or_text,
        box_size=qrcode_module_size,
        border=4,
    )
    print("✨ Starting DiffQRCoder forward pass...")
    result = pipe(
        prompt=prompt,
        qrcode=qrcode_img,
        qrcode_module_size=qrcode_module_size,
        qrcode_padding=qrcode_padding,
        negative_prompt=neg_prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        scanning_robust_guidance_scale=scanning_robust_guidance_scale,
        perceptual_guidance_scale=perceptual_guidance_scale,
        srmpgd_num_iteration=srmpgd_num_iteration,
        srmpgd_lr=srmpgd_lr,
    )
    print("βœ… DiffQRCoder forward pass finished.")
    return result.images[0]