File size: 4,027 Bytes
70be616
 
f1659ba
70be616
 
769c814
70be616
 
 
 
769c814
70be616
 
769c814
 
 
 
70be616
 
 
 
 
 
 
769c814
 
70be616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bd4fea
70be616
 
 
 
 
6bd4fea
70be616
6bd4fea
769c814
 
 
6bd4fea
 
769c814
6bd4fea
769c814
6bd4fea
70be616
769c814
70be616
 
769c814
70be616
 
f1659ba
70be616
6bd4fea
 
f1659ba
 
 
 
 
 
 
6bd4fea
 
 
 
70be616
 
 
 
 
6bd4fea
70be616
 
 
6bd4fea
70be616
 
 
6bd4fea
70be616
6bd4fea
70be616
 
 
6bd4fea
 
 
 
70be616
 
 
769c814
 
70be616
 
6bd4fea
70be616
 
 
 
 
 
 
 
 
 
 
 
 
 
6bd4fea
70be616
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# diffqrcoder_wrapper.py
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from PIL import Image
import qrcode
from huggingface_hub import hf_hub_download

from diffqrcoder import DiffQRCoderPipeline

# ---- Defaults taken from run_diffqrcoder.py ---- #

# Hub repo id of the QR-code ControlNet loaded with ControlNetModel.from_pretrained()
# in load_pipeline().
CONTROLNET_CKPT = "monster-labs/control_v1p_sd15_qrcode_monster"

# Hub repo + file name of the single-file base checkpoint that load_pipeline()
# downloads and passes to DiffQRCoderPipeline.from_single_file().
PIPE_REPO_ID = "fp16-guy/Cetus-Mix_Whalefall_fp16_cleaned"
PIPE_FILENAME = "cetusMix_Whalefall2_fp16.safetensors"

# Device of the torch.Generator created in generate_qr_art().
# NOTE(review): load_pipeline() never moves the pipeline to this device;
# presumably the caller does `.to(DEVICE)` before inference — confirm.
DEVICE = "cuda"

# Module-level caches, filled lazily by load_pipeline().
_controlnet = None
_pipe = None


def _make_qr_image(
    data: str,
    box_size: int = 20,
    border: int = 4,
) -> Image.Image:
    """Render *data* as a black-on-white QR code and return it as an RGB image.

    The QR version is chosen automatically (the smallest one that fits) and
    the highest error-correction level (H) is always used.

    Args:
        data: Text or URL to encode.
        box_size: Pixel size of a single QR module.
        border: Width of the quiet zone around the code, in modules.

    Returns:
        A PIL image in RGB mode.
    """
    code = qrcode.QRCode(
        version=None,  # let make(fit=True) pick the smallest fitting version
        error_correction=qrcode.constants.ERROR_CORRECT_H,
        box_size=box_size,
        border=border,
    )
    code.add_data(data)
    code.make(fit=True)

    rendered = code.make_image(fill_color="black", back_color="white")
    # Always hand back a 3-channel image to downstream consumers.
    return rendered.convert("RGB")


def load_pipeline():
    """
    Lazily load ControlNet + DiffQRCoderPipeline and cache them at module level.

    On the first successful call this:
      1. loads the QR ControlNet (``CONTROLNET_CKPT``) in float16, unless it is
         already cached from an earlier, partially failed attempt;
      2. downloads the base checkpoint (``PIPE_REPO_ID``/``PIPE_FILENAME``) into
         the local ``models`` directory;
      3. builds a float16 ``DiffQRCoderPipeline`` from that single file with the
         ControlNet attached;
      4. swaps the scheduler for ``DPMSolverMultistepScheduler``;
      5. tries (best effort) to enable xFormers memory-efficient attention.

    Later calls return the cached pipeline immediately.

    The pipeline is NOT moved to ``DEVICE`` here. NOTE(review): presumably the
    caller moves it to the GPU before inference — confirm at the call site.

    Returns:
        The cached ``DiffQRCoderPipeline`` instance.
    """
    global _controlnet, _pipe

    if _pipe is not None:
        return _pipe

    # The ControlNet is cached on its own so that a failure later in this
    # function does not force it to be reloaded on the next attempt.
    if _controlnet is None:
        print("🔧 Loading ControlNet...")
        _controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_CKPT,
            torch_dtype=torch.float16,
        )
        print("✅ ControlNet loaded.")

    print("🔧 Downloading base model checkpoint from Hub...")
    ckpt_path = hf_hub_download(
        repo_id=PIPE_REPO_ID,
        filename=PIPE_FILENAME,
        local_dir="models",
        # NOTE(review): deprecated/ignored in recent huggingface_hub releases;
        # kept for compatibility with older versions.
        local_dir_use_symlinks=False,
    )
    print("✅ Base model checkpoint at:", ckpt_path)

    print("🔧 Building DiffQRCoderPipeline from checkpoint...")
    pipe = DiffQRCoderPipeline.from_single_file(
        ckpt_path,
        controlnet=_controlnet,
        torch_dtype=torch.float16,
        # uses the Space's HF token. NOTE(review): newer diffusers name this
        # kwarg `token`; it may also be unused since ckpt_path is a local file.
        use_auth_token=True,
    )

    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    # Memory helper – cheaper attention. Best effort: xFormers is optional, so
    # any failure is logged and the default attention implementation is kept.
    # (pipe.enable_attention_slicing() is an alternative if memory is tight.)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print("⚠️ xFormers not available:", repr(e))
    else:
        print("✅ xFormers attention enabled.")

    print("✅ Pipeline constructed on CPU.")
    _pipe = pipe
    return _pipe


def generate_qr_art(
    pipe: DiffQRCoderPipeline,
    url_or_text: str,
    prompt: str,
    neg_prompt: str = "easynegative",
    num_inference_steps: int = 20,          # gentler default
    qrcode_module_size: int = 20,
    qrcode_padding: int = 78,
    controlnet_conditioning_scale: float = 1.35,
    scanning_robust_guidance_scale: float = 300.0,  # softer default
    perceptual_guidance_scale: float = 2.0,
    srmpgd_num_iteration: int | None = 0,   # 0 = disable SR-MPGD by default
    srmpgd_lr: float = 0.1,
    seed: int = 1,
) -> Image.Image:
    """Generate a stylised QR code for *url_or_text* with a DiffQRCoder pipeline.

    A plain QR code is rendered with ``_make_qr_image`` (4-module border,
    ``qrcode_module_size`` pixels per module). It is then passed, together with
    the prompts and guidance settings below, to a single forward pass of *pipe*.

    Args:
        pipe: Pipeline returned by ``load_pipeline()``.
        url_or_text: Payload encoded in the QR code.
        prompt: Positive text prompt.
        neg_prompt: Negative text prompt.
        num_inference_steps: Number of denoising steps.
        qrcode_module_size: Pixels per QR module; used both to render the QR
            image and as the pipeline's ``qrcode_module_size``.
        qrcode_padding: Forwarded unchanged to the pipeline.
        controlnet_conditioning_scale: ControlNet conditioning strength.
        scanning_robust_guidance_scale: Forwarded unchanged to the pipeline.
        perceptual_guidance_scale: Forwarded unchanged to the pipeline.
        srmpgd_num_iteration: Forwarded unchanged to the pipeline.
        srmpgd_lr: Forwarded unchanged to the pipeline.
        seed: Seed for the ``torch.Generator`` created on ``DEVICE``.

    Returns:
        The first image in the pipeline output's ``images``.

    Raises:
        ValueError: If *pipe* is ``None`` (``load_pipeline()`` not called yet).
    """
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if pipe is None:
        raise ValueError("Pipeline must be loaded before calling generate_qr_art")

    print("✨ generate_qr_art() starting...")
    # NOTE(review): the generator lives on DEVICE; the pipeline is presumably
    # already on that device by the time this runs — confirm with the caller.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    qrcode_img = _make_qr_image(
        data=url_or_text,
        box_size=qrcode_module_size,
        border=4,
    )

    print("✨ Starting DiffQRCoder forward pass...")
    result = pipe(
        prompt=prompt,
        qrcode=qrcode_img,
        qrcode_module_size=qrcode_module_size,
        qrcode_padding=qrcode_padding,
        negative_prompt=neg_prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        scanning_robust_guidance_scale=scanning_robust_guidance_scale,
        perceptual_guidance_scale=perceptual_guidance_scale,
        srmpgd_num_iteration=srmpgd_num_iteration,
        srmpgd_lr=srmpgd_lr,
    )
    print("✅ DiffQRCoder forward pass finished.")
    return result.images[0]