fix: gray QR preprocessing + boost cn_weights for scannability
Browse files- handler.py +37 -24
handler.py
CHANGED
|
@@ -22,6 +22,7 @@ import logging
|
|
| 22 |
import time
|
| 23 |
from typing import Any
|
| 24 |
|
|
|
|
| 25 |
import torch
|
| 26 |
from diffusers import (
|
| 27 |
ControlNetModel,
|
|
@@ -42,25 +43,27 @@ logger = logging.getLogger(__name__)
|
|
| 42 |
# Categories with <35% accept rate get 3 passes instead of 2.
|
| 43 |
|
| 44 |
CATEGORY_PARAMS = {
|
| 45 |
-
# High-texture cluster (
|
| 46 |
-
|
| 47 |
-
"
|
| 48 |
-
"
|
| 49 |
-
"
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
-
"
|
|
|
|
|
|
|
| 62 |
# Default fallback
|
| 63 |
-
"default": {"cn_weight":
|
| 64 |
}
|
| 65 |
|
| 66 |
|
|
@@ -175,6 +178,16 @@ class EndpointHandler:
|
|
| 175 |
except Exception as e:
|
| 176 |
return {"error": f"Failed to decode qr_code_image: {e}"}
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
# Resolve parameters
|
| 179 |
category = inputs.get("category", "default")
|
| 180 |
params = CATEGORY_PARAMS.get(category, CATEGORY_PARAMS["default"])
|
|
@@ -220,9 +233,9 @@ class EndpointHandler:
|
|
| 220 |
|
| 221 |
# ---- Pass 2: img2img + ControlNet (QR FORCE pass) ----
|
| 222 |
if num_passes >= 2:
|
| 223 |
-
p2_cn = cn_weight + 0.
|
| 224 |
-
p2_cfg =
|
| 225 |
-
p2_strength = 0.
|
| 226 |
p2_steps = 30
|
| 227 |
|
| 228 |
logger.info(
|
|
@@ -246,9 +259,9 @@ class EndpointHandler:
|
|
| 246 |
|
| 247 |
# ---- Pass 3: img2img + ControlNet (RESCUE pass) ----
|
| 248 |
if num_passes >= 3:
|
| 249 |
-
p3_cn = cn_weight +
|
| 250 |
-
p3_cfg =
|
| 251 |
-
p3_strength = 0.
|
| 252 |
p3_steps = 25
|
| 253 |
|
| 254 |
logger.info(
|
|
|
|
| 22 |
import time
|
| 23 |
from typing import Any
|
| 24 |
|
| 25 |
+
import numpy as np
|
| 26 |
import torch
|
| 27 |
from diffusers import (
|
| 28 |
ControlNetModel,
|
|
|
|
| 43 |
# Categories with <35% accept rate get 3 passes instead of 2.
|
| 44 |
|
| 45 |
CATEGORY_PARAMS = {
|
| 46 |
+
# High-texture cluster (2 passes)
|
| 47 |
+
# Boosted cn_weight from 1.80 → 2.15 to compensate for 0.05→0.85 guidance window
|
| 48 |
+
"food": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 49 |
+
"luxury": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 50 |
+
"wedding": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 51 |
+
"sports": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 52 |
+
"restaurant": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 53 |
+
"retail": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 54 |
+
# Geometric cluster (2-3 passes)
|
| 55 |
+
# Boosted cn_weight from 1.38 → 1.85
|
| 56 |
+
"architecture": {"cn_weight": 1.85, "cfg": 7.5, "steps": 40, "passes": 3},
|
| 57 |
+
"nature": {"cn_weight": 1.85, "cfg": 7.5, "steps": 40, "passes": 2},
|
| 58 |
+
"social": {"cn_weight": 1.85, "cfg": 7.5, "steps": 40, "passes": 3},
|
| 59 |
+
"seasonal": {"cn_weight": 2.00, "cfg": 7.5, "steps": 40, "passes": 3},
|
| 60 |
+
"tech": {"cn_weight": 1.85, "cfg": 7.5, "steps": 40, "passes": 2},
|
| 61 |
+
"world_wonders": {"cn_weight": 1.85, "cfg": 7.5, "steps": 40, "passes": 2},
|
| 62 |
+
"medieval": {"cn_weight": 1.85, "cfg": 7.5, "steps": 40, "passes": 2},
|
| 63 |
+
"professional": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 64 |
+
"real_estate": {"cn_weight": 2.15, "cfg": 7.5, "steps": 50, "passes": 2},
|
| 65 |
# Default fallback
|
| 66 |
+
"default": {"cn_weight": 2.00, "cfg": 7.5, "steps": 40, "passes": 2},
|
| 67 |
}
|
| 68 |
|
| 69 |
|
|
|
|
| 178 |
except Exception as e:
|
| 179 |
return {"error": f"Failed to decode qr_code_image: {e}"}
|
| 180 |
|
| 181 |
+
# CRITICAL: Preprocess QR — ensure gray background (#808080)
|
| 182 |
+
# Monster v2 ControlNet was trained on gray-background QR codes.
|
| 183 |
+
# White background gives wrong contrast signals and breaks scannability.
|
| 184 |
+
qr_array = np.array(qr_image)
|
| 185 |
+
white_mask = np.all(qr_array > 200, axis=2)
|
| 186 |
+
if np.sum(white_mask) > 0:
|
| 187 |
+
logger.info("Converting white QR background to gray (#808080)")
|
| 188 |
+
qr_array[white_mask] = [128, 128, 128]
|
| 189 |
+
qr_image = Image.fromarray(qr_array)
|
| 190 |
+
|
| 191 |
# Resolve parameters
|
| 192 |
category = inputs.get("category", "default")
|
| 193 |
params = CATEGORY_PARAMS.get(category, CATEGORY_PARAMS["default"])
|
|
|
|
| 233 |
|
| 234 |
# ---- Pass 2: img2img + ControlNet (QR FORCE pass) ----
|
| 235 |
if num_passes >= 2:
|
| 236 |
+
p2_cn = cn_weight + 0.6
|
| 237 |
+
p2_cfg = 12.0
|
| 238 |
+
p2_strength = 0.45
|
| 239 |
p2_steps = 30
|
| 240 |
|
| 241 |
logger.info(
|
|
|
|
| 259 |
|
| 260 |
# ---- Pass 3: img2img + ControlNet (RESCUE pass) ----
|
| 261 |
if num_passes >= 3:
|
| 262 |
+
p3_cn = cn_weight + 1.0
|
| 263 |
+
p3_cfg = 14.0
|
| 264 |
+
p3_strength = 0.55
|
| 265 |
p3_steps = 25
|
| 266 |
|
| 267 |
logger.info(
|