Qrverse committed on
Commit
95cf55f
·
verified ·
1 Parent(s): 75920a8

Initial handler: SD 1.5 + QR Monster v2, adaptive 2/3 pass pipeline

Browse files
Files changed (2) hide show
  1. handler.py +295 -0
  2. requirements.txt +8 -0
handler.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ QR-Verse AI Art Generator β€” HuggingFace Inference Endpoint Handler
3
+
4
+ Adaptive multi-pass pipeline:
5
+ Pass 1 (ART): txt2img + ControlNet at category-specific cn_weight β†’ creative art
6
+ Pass 2 (QR FORCE): img2img + ControlNet at higher scale β†’ embed QR pattern
7
+ Pass 3 (RESCUE, optional): img2img + ControlNet at max scale β†’ force scannable QR
8
+
9
+ Models:
10
+ - Checkpoint: SG161222/Realistic_Vision_V5.1_noVAE (SD 1.5)
11
+ - ControlNet: monster-labs/control_v1p_sd15_qrcode_monster (v2)
12
+
13
+ Key differentiator vs Replicate:
14
+ - control_guidance_start/end support (0.05 / 0.85)
15
+ - Category-aware cn_weight (1.38 geometric vs 1.80 texture)
16
+ - Adaptive pass count based on category difficulty
17
+ """
18
+
19
+ import base64
20
+ import io
21
+ import logging
22
+ import time
23
+ from typing import Any
24
+
25
+ import torch
26
+ from diffusers import (
27
+ ControlNetModel,
28
+ StableDiffusionControlNetPipeline,
29
+ StableDiffusionControlNetImg2ImgPipeline,
30
+ UniPCMultistepScheduler,
31
+ )
32
+ from PIL import Image
33
+
34
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Category parameter presets (extracted from 71K ChromaDB generation learnings)
# ---------------------------------------------------------------------------
# Two cn_weight clusters:
#   1.80 -> high-texture categories (food, luxury, wedding, sports)
#   1.38 -> geometric/structural categories (architecture, nature, tech)
# Categories with <35% accept rate get 3 passes instead of 2.


def _preset(cn_weight: float, steps: int, passes: int, cfg: float = 7.5) -> dict:
    """Build one category preset dict (all presets share cfg=7.5)."""
    return {"cn_weight": cn_weight, "cfg": cfg, "steps": steps, "passes": passes}


CATEGORY_PARAMS = {
    # High-texture cluster (cn_weight=1.80, 50 steps, 2 passes)
    "food": _preset(1.80, 50, 2),
    "luxury": _preset(1.80, 50, 2),
    "wedding": _preset(1.80, 50, 2),
    "sports": _preset(1.80, 50, 2),
    "restaurant": _preset(1.80, 50, 2),
    "retail": _preset(1.80, 50, 2),
    "professional": _preset(1.80, 50, 2),
    "real_estate": _preset(1.80, 50, 2),
    # Geometric cluster (cn_weight=1.38, 40 steps, 2-3 passes)
    "architecture": _preset(1.38, 40, 3),
    "nature": _preset(1.38, 40, 2),
    "social": _preset(1.38, 40, 3),
    "seasonal": _preset(1.59, 40, 3),
    "tech": _preset(1.38, 40, 2),
    "world_wonders": _preset(1.38, 40, 2),
    "medieval": _preset(1.38, 40, 2),
    # Default fallback
    "default": _preset(1.50, 40, 2),
}
65
+
66
+
67
class EndpointHandler:
    """Custom handler for HuggingFace Inference Endpoints.

    Runs the adaptive multi-pass QR-art pipeline:

      Pass 1 (ART):    txt2img + ControlNet at the category cn_weight
      Pass 2 (FORCE):  img2img + ControlNet at cn_weight + 0.4
      Pass 3 (RESCUE): img2img + ControlNet at cn_weight + 0.8
    """

    # SD 1.5's VAE downsamples by a factor of 8, so generation dimensions
    # must be multiples of 8; requested sizes are snapped down to comply.
    _DIM_MULTIPLE = 8
    # The pipeline implements at most three passes (art / force / rescue).
    _MAX_PASSES = 3

    def __init__(self, path: str = ""):
        """Load ControlNet and both SD 1.5 pipelines on endpoint startup.

        Args:
            path: Model directory supplied by the endpoint runtime
                (unused here; models are fetched from the Hub by repo id).
        """
        logger.info("Loading QR Art Generator pipeline...")
        start = time.time()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 halves VRAM on GPU; CPU inference requires fp32.
        dtype = torch.float16 if device == "cuda" else torch.float32

        # QR Monster ControlNet v2 (the "v2" subfolder of the repo).
        self.controlnet = ControlNetModel.from_pretrained(
            "monster-labs/control_v1p_sd15_qrcode_monster",
            subfolder="v2",
            torch_dtype=dtype,
        )

        # SD 1.5 txt2img + ControlNet pipeline (Pass 1).
        # Safety checker deliberately disabled (safety_checker=None).
        self.pipe_txt2img = StableDiffusionControlNetPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V5.1_noVAE",
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        self.pipe_txt2img.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe_txt2img.scheduler.config
        )

        # SD 1.5 img2img + ControlNet pipeline (Pass 2/3). Shares the same
        # ControlNet instance, so the model weights are loaded only once.
        self.pipe_img2img = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V5.1_noVAE",
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        self.pipe_img2img.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe_img2img.scheduler.config
        )

        # Move to device + optimize.
        self.pipe_txt2img.to(device)
        self.pipe_img2img.to(device)

        if device == "cuda":
            # xformers is an optional acceleration; fall back quietly to the
            # default attention implementation when it is not installed.
            try:
                self.pipe_txt2img.enable_xformers_memory_efficient_attention()
                self.pipe_img2img.enable_xformers_memory_efficient_attention()
                logger.info("xformers memory-efficient attention enabled")
            except Exception:
                logger.warning("xformers not available, using default attention")

        self.device = device
        self.dtype = dtype
        elapsed = time.time() - start
        logger.info("Pipeline loaded in %.1fs on %s", elapsed, device)

    def _refine_pass(
        self,
        pass_idx: int,
        num_passes: int,
        prompt: str,
        negative_prompt: str,
        art_image: Image.Image,
        qr_image: Image.Image,
        cn: float,
        cfg: float,
        strength: float,
        steps: int,
        control_start: float,
        control_end: float,
        generator: "torch.Generator",
    ) -> Image.Image:
        """Run one img2img + ControlNet refinement pass and return the image."""
        logger.info(
            "Pass %s/%s: img2img cn=%s cfg=%s strength=%s steps=%s",
            pass_idx, num_passes, cn, cfg, strength, steps,
        )
        result = self.pipe_img2img(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=art_image,
            control_image=qr_image,
            strength=strength,
            guidance_scale=cfg,
            controlnet_conditioning_scale=cn,
            control_guidance_start=control_start,
            control_guidance_end=control_end,
            num_inference_steps=steps,
            generator=generator,
        )
        return result.images[0]

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        Generate QR art from input parameters.

        Input JSON:
          {
            "inputs": {
              "prompt": str,                   # Required
              "negative_prompt": str,          # Optional
              "qr_code_image": str,            # Required - base64 PNG of QR code
              "category": str,                 # Optional - maps to CATEGORY_PARAMS
              "seed": int,                     # Optional - -1 for random
              "width": int,                    # Optional - default 768 (snapped to /8)
              "height": int,                   # Optional - default 768 (snapped to /8)
              "num_passes": int,               # Optional - override auto pass count (clamped 1-3)
              "controlnet_scale": float,       # Optional - override category cn_weight
              "guidance_scale": float,         # Optional - override category cfg
              "num_inference_steps": int,      # Optional - override category steps
              "control_guidance_start": float, # Optional - default 0.05
              "control_guidance_end": float,   # Optional - default 0.85
            }
          }

        Output JSON:
          {
            "image": str,            # base64 PNG
            "passes_run": int,       # passes actually executed
            "seed": int,             # seed used (echo for reproducibility)
            "parameters": dict,      # actual parameters used
            "time_seconds": float,
          }

        Errors are reported as {"error": str} rather than raised, so the
        endpoint returns a JSON body instead of a 500.
        """
        start = time.time()

        # Endpoint payloads wrap parameters under "inputs"; accept bare dicts too.
        inputs = data.get("inputs", data)
        prompt = inputs.get("prompt", "")
        negative_prompt = inputs.get(
            "negative_prompt",
            "ugly, disfigured, low quality, blurry, nsfw, text, watermark",
        )
        qr_b64 = inputs.get("qr_code_image", "")

        if not prompt:
            return {"error": "prompt is required"}
        if not qr_b64:
            return {"error": "qr_code_image (base64 PNG) is required"}

        # Decode QR code image.
        try:
            qr_image = Image.open(io.BytesIO(base64.b64decode(qr_b64))).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to decode qr_code_image: {e}"}

        # Resolve parameters: category preset first, explicit inputs override.
        category = inputs.get("category", "default")
        params = CATEGORY_PARAMS.get(category, CATEGORY_PARAMS["default"])

        cn_weight = inputs.get("controlnet_scale", params["cn_weight"])
        cfg = inputs.get("guidance_scale", params["cfg"])
        steps = inputs.get("num_inference_steps", params["steps"])
        # FIX: clamp to the implemented 1-3 range so "passes_run" in the
        # response always matches the passes actually executed (previously
        # num_passes=0 still ran pass 1, and >3 was over-reported).
        num_passes = max(
            1, min(int(inputs.get("num_passes", params["passes"])), self._MAX_PASSES)
        )
        # FIX: snap dimensions down to multiples of 8 (SD 1.5 VAE constraint);
        # defaults of 768 are unaffected.
        width = max(
            self._DIM_MULTIPLE,
            int(inputs.get("width", 768)) // self._DIM_MULTIPLE * self._DIM_MULTIPLE,
        )
        height = max(
            self._DIM_MULTIPLE,
            int(inputs.get("height", 768)) // self._DIM_MULTIPLE * self._DIM_MULTIPLE,
        )
        control_start = inputs.get("control_guidance_start", 0.05)
        control_end = inputs.get("control_guidance_end", 0.85)

        # Seed: -1 requests a fresh random seed; the seed actually used is
        # echoed back in the response so results can be reproduced.
        seed = inputs.get("seed", -1)
        if seed == -1:
            generator = torch.Generator(device=self.device)
            seed = generator.seed()  # seeds from entropy and returns the value
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        # Resize QR code to target dimensions.
        qr_image = qr_image.resize((width, height), Image.LANCZOS)

        # ---- Pass 1: txt2img + ControlNet (ART pass) ----
        logger.info(
            "Pass 1/%s: txt2img cn=%s cfg=%s steps=%s",
            num_passes, cn_weight, cfg, steps,
        )
        result = self.pipe_txt2img(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=qr_image,
            width=width,
            height=height,
            guidance_scale=cfg,
            controlnet_conditioning_scale=cn_weight,
            control_guidance_start=control_start,
            control_guidance_end=control_end,
            num_inference_steps=steps,
            generator=generator,
        )
        art_image = result.images[0]

        # ---- Pass 2: img2img + ControlNet (QR FORCE pass) ----
        if num_passes >= 2:
            art_image = self._refine_pass(
                2, num_passes, prompt, negative_prompt, art_image, qr_image,
                cn_weight + 0.4, 10.0, 0.35, 30,
                control_start, control_end, generator,
            )

        # ---- Pass 3: img2img + ControlNet (RESCUE pass) ----
        if num_passes >= 3:
            art_image = self._refine_pass(
                3, num_passes, prompt, negative_prompt, art_image, qr_image,
                cn_weight + 0.8, 13.0, 0.45, 25,
                control_start, control_end, generator,
            )

        # Encode result to base64 PNG.
        buf = io.BytesIO()
        art_image.save(buf, format="PNG")
        result_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        elapsed = time.time() - start

        return {
            "image": result_b64,
            "passes_run": num_passes,
            "seed": seed,
            "parameters": {
                "category": category,
                "controlnet_scale_p1": cn_weight,
                "guidance_scale_p1": cfg,
                "steps_p1": steps,
                "control_guidance_start": control_start,
                "control_guidance_end": control_end,
                "width": width,
                "height": height,
            },
            "time_seconds": round(elapsed, 2),
        }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ diffusers>=0.27.0
2
+ transformers>=4.38.0
3
+ accelerate>=0.27.0
4
+ torch>=2.1.0
5
+ xformers>=0.0.23
6
+ safetensors
7
+ Pillow
8
+ controlnet-aux