sayshara commited on
Commit
70be616
·
1 Parent(s): e65a359

added diffqrcoder_wrapper

Browse files
app.py CHANGED
@@ -1,7 +1,120 @@
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
+ # app.py
2
  import gradio as gr
3
+ import spaces
4
+ from diffqrcoder_wrapper import generate_qr_art
5
 
6
+ DEFAULT_PROMPT = (
7
+ "whimsical biomimetic blueprint, iridescent inks swirling through "
8
+ "mechanical petals, soft gears woven with luminescent filigree"
9
+ )
10
+
11
+ DEFAULT_NEG = "easynegative"
12
+
13
+
14
+ @spaces.GPU # ZeroGPU: attach GPU only for this call
15
+ def infer(
16
+ url_or_text: str,
17
+ prompt: str,
18
+ num_inference_steps: int,
19
+ controlnet_scale: float,
20
+ scanning_robust_guidance_scale: float,
21
+ perceptual_guidance_scale: float,
22
+ srmpgd_iters: int,
23
+ ):
24
+ # Map 0 → None for SR-MPGD iteration (as in original script)
25
+ srmpgd_num_iteration = None if srmpgd_iters == 0 else srmpgd_iters
26
+
27
+ img = generate_qr_art(
28
+ url_or_text=url_or_text,
29
+ prompt=prompt,
30
+ neg_prompt=DEFAULT_NEG,
31
+ num_inference_steps=num_inference_steps,
32
+ qrcode_module_size=20,
33
+ qrcode_padding=78,
34
+ controlnet_conditioning_scale=controlnet_scale,
35
+ scanning_robust_guidance_scale=scanning_robust_guidance_scale,
36
+ perceptual_guidance_scale=perceptual_guidance_scale,
37
+ srmpgd_num_iteration=srmpgd_num_iteration,
38
+ srmpgd_lr=0.1,
39
+ seed=1,
40
+ )
41
+ return img
42
+
43
+
44
+ with gr.Blocks() as demo:
45
+ gr.Markdown(
46
+ r"""
47
+ # DiffQRCoder – ZeroGPU demo
48
+
49
+ Generate aesthetic, scanning-robust QR codes using the **DiffQRCoder** pipeline
50
+ ([WACV 2025](https://openaccess.thecvf.com/content/WACV2025/html/Liao_DiffQRCoder_Diffusion-Based_Aesthetic_QR_Code_Generation_with_Scanning_Robustness_Guided_WACV_2025_paper.html)) 🚀
51
+ """
52
+ )
53
+
54
+ with gr.Row():
55
+ url = gr.Textbox(
56
+ label="QR contents (URL or text)",
57
+ value="https://example.com",
58
+ )
59
+
60
+ prompt = gr.Textbox(
61
+ label="Style prompt",
62
+ value=DEFAULT_PROMPT,
63
+ lines=3,
64
+ )
65
+
66
+ with gr.Accordion("Advanced parameters", open=False):
67
+ steps = gr.Slider(
68
+ minimum=10,
69
+ maximum=60,
70
+ value=40,
71
+ step=1,
72
+ label="Diffusion steps (num_inference_steps)",
73
+ )
74
+ control_scale = gr.Slider(
75
+ minimum=0.5,
76
+ maximum=2.0,
77
+ value=1.35,
78
+ step=0.05,
79
+ label="ControlNet conditioning scale",
80
+ )
81
+ srg_scale = gr.Slider(
82
+ minimum=0,
83
+ maximum=800,
84
+ value=500,
85
+ step=10,
86
+ label="Scanning-robust guidance scale (srg)",
87
+ )
88
+ pg_scale = gr.Slider(
89
+ minimum=0,
90
+ maximum=10,
91
+ value=2,
92
+ step=0.5,
93
+ label="Perceptual guidance scale (pg)",
94
+ )
95
+ srmpgd_iters = gr.Slider(
96
+ minimum=0,
97
+ maximum=64,
98
+ value=0,
99
+ step=1,
100
+ label="SR-MPGD iterations (0 = disabled)",
101
+ )
102
+
103
+ btn = gr.Button("Generate QR Art ✨", variant="primary")
104
+ out = gr.Image(label="Output QR art", type="pil")
105
+
106
+ btn.click(
107
+ fn=infer,
108
+ inputs=[
109
+ url,
110
+ prompt,
111
+ steps,
112
+ control_scale,
113
+ srg_scale,
114
+ pg_scale,
115
+ srmpgd_iters,
116
+ ],
117
+ outputs=[out],
118
+ )
119
 
 
120
  demo.launch()
diffqrcoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .pipeline_diffqrcoder import DiffQRCoderPipeline
diffqrcoder/image_processor.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+
6
+ IMAGE_MAX_VAL = 255
7
+
8
+
9
+ def min_max_normalize(x: torch.Tensor) -> torch.Tensor:
10
+ return (x - x.min()) / (x.max() - x.min())
11
+
12
+
13
+ def convert_to_gray(
14
+ images: torch.Tensor,
15
+ cr: float = 0.2999,
16
+ cg: float = 0.587,
17
+ cb: float = 0.1114,
18
+ ) -> torch.Tensor:
19
+
20
+ assert images.shape[1] == 3, \
21
+ f"The channel of color images must be 3 but get {images.shape[1]}. They are not color images."
22
+
23
+ gray_image = cr * images[:, 0] + cg * images[:, 1] + cb * images[:, 2]
24
+ return gray_image.unsqueeze(1)
25
+
26
+
27
+ def image_binarize(
28
+ image: torch.Tensor,
29
+ binary_threshold: Optional[float] = None,
30
+ ) -> torch.Tensor:
31
+
32
+ if image.shape[1] == 3:
33
+ image = convert_to_gray(image)
34
+
35
+ if binary_threshold is None:
36
+ if image.max() <= 1:
37
+ binary_threshold = 0.5
38
+ else:
39
+ binary_threshold = 0.5 * IMAGE_MAX_VAL
40
+ return (image > binary_threshold).to(image.dtype)
41
+
42
+
43
+ def crop_padding(x: torch.Tensor, padding: int):
44
+ return x[:, :, padding:-padding, padding:-padding]
diffqrcoder/losses/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .perceptual_loss import PerceptualLoss
2
+ from .scanning_robust_loss import ScanningRobustLoss
3
+ from .personalized_code_loss import PersonalizedCodeLoss
diffqrcoder/losses/perceptual_loss.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from torchvision.models import vgg16
8
+ from torchvision.transforms import Normalize
9
+
10
+
11
+ class VGGFeatureExtractor(nn.Module):
12
+ def __init__(
13
+ self,
14
+ requires_grad: bool = False,
15
+ pretrained_weights: str = "DEFAULT",
16
+ ) -> None:
17
+
18
+ super().__init__()
19
+ self.norm = Normalize(
20
+ mean=[0.485, 0.456, 0.406],
21
+ std=[0.229, 0.224, 0.225],
22
+ )
23
+ self.slice_indices = [(0, 4), (4, 9), (9, 16), (16, 23)]
24
+ self.slices = nn.ModuleList([nn.Sequential() for _ in range(len(self.slice_indices))])
25
+ self._initialize_slices(pretrained_weights)
26
+ self.features = namedtuple("Outputs", [f"layer{i}" for i in range(len(self.slice_indices))])
27
+
28
+ if not requires_grad:
29
+ for param in self.parameters():
30
+ param.requires_grad = False
31
+
32
+ def lp_norm(self, x: torch.Tensor) -> torch.Tensor:
33
+ return torch.nn.functional.normalize(x, p=2.0, dim=1)
34
+
35
+ def _initialize_slices(self, pretrained_weights: str = "DEFAULT") -> None:
36
+ features = vgg16(weights=pretrained_weights).features
37
+ for slice_idx, (start, end) in enumerate(self.slice_indices):
38
+ for i in range(start, end):
39
+ self.slices[slice_idx].add_module(str(i), features[i])
40
+
41
+ def forward(self, x: torch.Tensor) -> namedtuple:
42
+ outputs = []
43
+ x = self.norm(x)
44
+ for slice_model in self.slices:
45
+ x = self.lp_norm(slice_model(x))
46
+ outputs.append(x)
47
+ return self.features(*outputs)
48
+
49
+
50
+ class PerceptualLoss(nn.Module):
51
+ def __init__(
52
+ self,
53
+ requires_grad: bool = False,
54
+ pretrained_weights: str = "DEFAULT",
55
+ ):
56
+ super(PerceptualLoss, self).__init__()
57
+ self.extractor = VGGFeatureExtractor(
58
+ pretrained_weights=pretrained_weights,
59
+ requires_grad=requires_grad,
60
+ )
61
+
62
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
63
+ return torch.mean(
64
+ torch.tensor(
65
+ [
66
+ torch.nn.functional.mse_loss(fx, fy)
67
+ for fx, fy in zip(self.extractor(x), self.extractor(y))
68
+ ]
69
+ )
70
+ )
diffqrcoder/losses/personalized_code_loss.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .scanning_robust_loss import ScanningRobustLoss
5
+
6
+
7
+ class PersonalizedCodeLoss(nn.Module):
8
+ def __init__(
9
+ self,
10
+ qrcode_image: torch.Tensor,
11
+ content_image: torch.Tensor,
12
+ module_size: int = 16,
13
+ b_thres: float = 50,
14
+ w_thres: float = 200,
15
+ b_soft_value: float = 40 / 255,
16
+ w_soft_value: float = 220 / 255,
17
+ code_weight: float = 1e12,
18
+ content_weight: float = 1e8,
19
+ device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
20
+ ):
21
+ super(PersonalizedCodeLoss, self).__init__()
22
+ self.code_loss = ScanningRobustLoss(
23
+ module_size=module_size,
24
+ ).to(device)
25
+
26
+ self.content_image = content_image
27
+ self.code_weight = code_weight
28
+ self.content_weight = content_weight
29
+ self.qrcode_image = qrcode_image
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ code_loss = self.code_loss(x, self.qrcode_image)
33
+ perceptual_loss = nn.MSELoss()(x, self.content_image)
34
+ total_loss = (
35
+ self.code_weight * code_loss + \
36
+ self.content_weight * perceptual_loss
37
+ )
38
+ return {
39
+ "code": code_loss,
40
+ "perceptual": perceptual_loss,
41
+ "total": total_loss
42
+ }
diffqrcoder/losses/scanning_robust_loss.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+
8
+ from ..image_processor import convert_to_gray, image_binarize, min_max_normalize
9
+
10
+
11
+ class GaussianFilter(nn.Module):
12
+ def __init__(self, module_size: int, filter_thres: float = 0.1) -> None:
13
+ super().__init__()
14
+ self.module_size = module_size
15
+ self.filter_thres = filter_thres
16
+ self.conv = nn.Conv2d(
17
+ in_channels=1,
18
+ out_channels=1,
19
+ kernel_size=module_size,
20
+ stride=module_size,
21
+ padding=0,
22
+ bias=False,
23
+ groups=1,
24
+ )
25
+ self._setup_filter_weights()
26
+
27
+ def _setup_filter_weights(self) -> None:
28
+ filter_1d = cv2.getGaussianKernel(
29
+ ksize=self.module_size,
30
+ sigma=1.5,
31
+ ktype=cv2.CV_32F
32
+ )
33
+ filter_2d = filter_1d * filter_1d.T
34
+ filter_2d = min_max_normalize(filter_2d)
35
+ filter_2d[filter_2d < self.filter_thres] = .0
36
+ gaussian_filter_init = torch.tensor(filter_2d, dtype=torch.float32)
37
+ self.conv.weight = nn.Parameter(
38
+ gaussian_filter_init.reshape(1, 1, *gaussian_filter_init.shape),
39
+ requires_grad=False,
40
+ )
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ return self.conv(x)
44
+
45
+
46
+ class RegionMeanFilter(nn.Module):
47
+ def __init__(self, module_size: int) -> None:
48
+ super().__init__()
49
+ self.module_size = module_size
50
+ self.conv = nn.Conv2d(
51
+ in_channels=1,
52
+ out_channels=1,
53
+ kernel_size=module_size,
54
+ stride=module_size,
55
+ padding=0,
56
+ bias=None,
57
+ groups=1,
58
+ )
59
+ self._setup_kernel_weights()
60
+
61
+ def _setup_kernel_weights(self) -> None:
62
+ module_center = int(self.module_size / 2)
63
+ radius = math.ceil(self.module_size / 6)
64
+ center_filter = torch.zeros((1, 1, self.module_size, self.module_size))
65
+ center_filter[
66
+ :, :,
67
+ module_center-radius : module_center+radius,
68
+ module_center-radius : module_center+radius,
69
+ ] = 1.0
70
+
71
+ self.conv.weight = nn.Parameter(
72
+ center_filter / center_filter.sum(),
73
+ requires_grad=False,
74
+ )
75
+
76
+ @torch.no_grad()
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ return self.conv(x)
79
+
80
+
81
+ class CenterPixelExtractor(nn.Module):
82
+ def __init__(self, module_size: int) -> None:
83
+ super().__init__()
84
+ self.module_size = module_size
85
+ self.conv = nn.Conv2d(
86
+ in_channels=1,
87
+ out_channels=1,
88
+ kernel_size=module_size,
89
+ stride=module_size,
90
+ padding=0,
91
+ bias=None,
92
+ groups=1,
93
+ )
94
+ self._setup_kernel_weights()
95
+
96
+ def _setup_kernel_weights(self) -> None:
97
+ module_center = int(self.module_size / 2) + 1
98
+ center_filter = torch.zeros((1, 1, self.module_size, self.module_size))
99
+ center_filter[:, :, module_center, module_center] = 1.0
100
+ self.conv.weight = nn.Parameter(center_filter, requires_grad=False)
101
+
102
+ @torch.no_grad()
103
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
104
+ return self.conv(x)
105
+
106
+
107
+ class QRCodeErrorExtractor(nn.Module):
108
+ def __init__(self, module_size: int) -> None:
109
+ super().__init__()
110
+ self.module_size = module_size
111
+ self.region_mean_filter = RegionMeanFilter(module_size)
112
+ self.center_pixel_extractor = CenterPixelExtractor(module_size=module_size)
113
+
114
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
115
+ x_center_mean = self.region_mean_filter(x)
116
+ y_center_pixel = self.center_pixel_extractor(y)
117
+ error_mask = (y_center_pixel == 0) & (x_center_mean > 0.45) | \
118
+ (y_center_pixel == 1) & (x_center_mean < 0.65)
119
+ return error_mask.float()
120
+
121
+
122
+ class ScanningRobustLoss(nn.Module):
123
+ def __init__(self, module_size: int) -> None:
124
+ super().__init__()
125
+ self.gaussian_filter = GaussianFilter(module_size=module_size)
126
+ self.center_filter = RegionMeanFilter(module_size=module_size)
127
+ self.module_error_extractor = QRCodeErrorExtractor(module_size=module_size)
128
+
129
+ def _compute_error(self, image: torch.Tensor, qrcode: torch.Tensor) -> torch.Tensor:
130
+ gray_image = convert_to_gray(image)
131
+ error0 = 2 * torch.relu(gray_image - 0.45) * (1 - qrcode)
132
+ error1 = 2 * torch.relu(0.65 - gray_image) * qrcode
133
+ return error0 + error1
134
+
135
+ def _compute_ealy_stopping_mask(self, image: torch.Tensor, qrcode: torch.Tensor) -> torch.Tensor:
136
+ return self.module_error_extractor(
137
+ convert_to_gray(image.clone().detach()),
138
+ image_binarize(qrcode),
139
+ )
140
+
141
+ def forward(self, image: torch.Tensor, qrcode: torch.Tensor) -> torch.Tensor:
142
+ error = self._compute_error(image, qrcode)
143
+ sample_error = self.gaussian_filter(error)
144
+ mask = self._compute_ealy_stopping_mask(image, qrcode)
145
+ return torch.mean(sample_error * mask)
diffqrcoder/pipeline_diffqrcoder.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from diffusers import StableDiffusionControlNetPipeline
5
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
6
+ from diffusers.image_processor import PipelineImageInput
7
+ from diffusers.models import ControlNetModel
8
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
9
+ from diffusers.pipelines.controlnet import MultiControlNetModel
10
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
11
+ from diffusers.utils import is_torch_xla_available, deprecate
12
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
13
+ from tqdm import trange
14
+
15
+ from .image_processor import image_binarize, crop_padding
16
+ from .srpg import ScanningRobustPerceptualGuidance
17
+
18
+
19
+ if is_torch_xla_available():
20
+ import torch_xla.core.xla_model as xm
21
+
22
+ XLA_AVAILABLE = True
23
+ else:
24
+ XLA_AVAILABLE = False
25
+
26
+
27
+ class DiffQRCoderPipeline(StableDiffusionControlNetPipeline):
28
+ def _run_stage1(
29
+ self,
30
+ prompt: Union[str, List[str]] = None,
31
+ qrcode: PipelineImageInput = None,
32
+ height: Optional[int] = None,
33
+ width: Optional[int] = None,
34
+ num_inference_steps: int = 50,
35
+ timesteps: List[int] = None,
36
+ sigmas: List[float] = None,
37
+ guidance_scale: float = 7.5,
38
+ negative_prompt: Optional[Union[str, List[str]]] = None,
39
+ num_images_per_prompt: Optional[int] = 1,
40
+ eta: float = 0.0,
41
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
42
+ latents: Optional[torch.Tensor] = None,
43
+ prompt_embeds: Optional[torch.Tensor] = None,
44
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
45
+ ip_adapter_image: Optional[PipelineImageInput] = None,
46
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
47
+ output_type: Optional[str] = "pil",
48
+ return_dict: bool = True,
49
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
50
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
51
+ guess_mode: bool = False,
52
+ control_guidance_start: Union[float, List[float]] = 0.0,
53
+ control_guidance_end: Union[float, List[float]] = 1.0,
54
+ clip_skip: Optional[int] = None,
55
+ callback_on_step_end: Optional[
56
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
57
+ ] = None,
58
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
59
+ **kwargs,
60
+ ):
61
+ return super().__call__(
62
+ prompt=prompt,
63
+ image=qrcode,
64
+ height=height,
65
+ width=width,
66
+ num_inference_steps=num_inference_steps,
67
+ timesteps=timesteps,
68
+ sigmas=sigmas,
69
+ guidance_scale=guidance_scale,
70
+ negative_prompt=negative_prompt,
71
+ num_images_per_prompt=num_images_per_prompt,
72
+ eta=eta,
73
+ generator=generator,
74
+ latents=latents,
75
+ prompt_embeds=prompt_embeds,
76
+ negative_prompt_embeds=negative_prompt_embeds,
77
+ ip_adapter_image=ip_adapter_image,
78
+ ip_adapter_image_embeds=ip_adapter_image_embeds,
79
+ output_type=output_type,
80
+ return_dict=True,
81
+ cross_attention_kwargs=cross_attention_kwargs,
82
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
83
+ guess_mode=guess_mode,
84
+ control_guidance_start=control_guidance_start,
85
+ control_guidance_end=control_guidance_end,
86
+ clip_skip=clip_skip,
87
+ callback_on_step_end=callback_on_step_end,
88
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
89
+ **kwargs,
90
+ )
91
+
92
+ def _run_stage2(
93
+ self,
94
+ prompt: Union[str, List[str]] = None,
95
+ qrcode: PipelineImageInput = None,
96
+ qrcode_module_size: int = 20,
97
+ qrcode_padding: int = 78,
98
+ ref_image: PipelineImageInput = None,
99
+ height: Optional[int] = None,
100
+ width: Optional[int] = None,
101
+ num_inference_steps: int = 50,
102
+ timesteps: List[int] = None,
103
+ sigmas: List[float] = None,
104
+ guidance_scale: float = 7.5,
105
+ negative_prompt: Optional[Union[str, List[str]]] = None,
106
+ num_images_per_prompt: Optional[int] = 1,
107
+ eta: float = 0.0,
108
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
109
+ latents: Optional[torch.Tensor] = None,
110
+ prompt_embeds: Optional[torch.Tensor] = None,
111
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
112
+ ip_adapter_image: Optional[PipelineImageInput] = None,
113
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
114
+ output_type: Optional[str] = "pil",
115
+ return_dict: bool = True,
116
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
117
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
118
+ guess_mode: bool = False,
119
+ control_guidance_start: Union[float, List[float]] = 0.0,
120
+ control_guidance_end: Union[float, List[float]] = 1.0,
121
+ scanning_robust_guidance_scale: int = 500,
122
+ perceptual_guidance_scale: int = 10,
123
+ srmpgd_num_iteration: Optional[int] = None,
124
+ srmpgd_lr: Optional[float] = 0.1,
125
+ clip_skip: Optional[int] = None,
126
+ callback_on_step_end: Optional[
127
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
128
+ ] = None,
129
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
130
+ **kwargs,
131
+ ):
132
+ self.srpg = ScanningRobustPerceptualGuidance(
133
+ module_size=qrcode_module_size,
134
+ scanning_robust_guidance_scale=scanning_robust_guidance_scale,
135
+ perceptual_guidance_scale=perceptual_guidance_scale,
136
+ ).to(self.device).to(self.dtype)
137
+
138
+ callback = kwargs.pop("callback", None)
139
+ callback_steps = kwargs.pop("callback_steps", None)
140
+
141
+ if callback is not None:
142
+ deprecate(
143
+ "callback",
144
+ "1.0.0",
145
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
146
+ )
147
+ if callback_steps is not None:
148
+ deprecate(
149
+ "callback_steps",
150
+ "1.0.0",
151
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
152
+ )
153
+
154
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
155
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
156
+
157
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
158
+
159
+ # align format for control guidance
160
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
161
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
162
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
163
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
164
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
165
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
166
+ control_guidance_start, control_guidance_end = (
167
+ mult * [control_guidance_start],
168
+ mult * [control_guidance_end],
169
+ )
170
+
171
+ # 1. Check inputs. Raise error if not correct
172
+ self.check_inputs(
173
+ prompt,
174
+ qrcode,
175
+ callback_steps,
176
+ negative_prompt,
177
+ prompt_embeds,
178
+ negative_prompt_embeds,
179
+ ip_adapter_image,
180
+ ip_adapter_image_embeds,
181
+ controlnet_conditioning_scale,
182
+ control_guidance_start,
183
+ control_guidance_end,
184
+ callback_on_step_end_tensor_inputs,
185
+ )
186
+
187
+ self._guidance_scale = guidance_scale
188
+ self._clip_skip = clip_skip
189
+ self._cross_attention_kwargs = cross_attention_kwargs
190
+ self._interrupt = False
191
+
192
+ # 2. Define call parameters
193
+ if prompt is not None and isinstance(prompt, str):
194
+ batch_size = 1
195
+ elif prompt is not None and isinstance(prompt, list):
196
+ batch_size = len(prompt)
197
+ else:
198
+ batch_size = prompt_embeds.shape[0]
199
+
200
+ device = self._execution_device
201
+
202
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
203
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
204
+
205
+ global_pool_conditions = (
206
+ controlnet.config.global_pool_conditions
207
+ if isinstance(controlnet, ControlNetModel)
208
+ else controlnet.nets[0].config.global_pool_conditions
209
+ )
210
+ guess_mode = guess_mode or global_pool_conditions
211
+
212
+ # 3. Encode input prompt
213
+ text_encoder_lora_scale = (
214
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
215
+ )
216
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
217
+ prompt,
218
+ device,
219
+ num_images_per_prompt,
220
+ self.do_classifier_free_guidance,
221
+ negative_prompt,
222
+ prompt_embeds=prompt_embeds,
223
+ negative_prompt_embeds=negative_prompt_embeds,
224
+ lora_scale=text_encoder_lora_scale,
225
+ clip_skip=self.clip_skip,
226
+ )
227
+ # For classifier free guidance, we need to do two forward passes.
228
+ # Here we concatenate the unconditional and text embeddings into a single batch
229
+ # to avoid doing two forward passes
230
+ if self.do_classifier_free_guidance:
231
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
232
+
233
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
234
+ image_embeds = self.prepare_ip_adapter_image_embeds(
235
+ ip_adapter_image,
236
+ ip_adapter_image_embeds,
237
+ device,
238
+ batch_size * num_images_per_prompt,
239
+ self.do_classifier_free_guidance,
240
+ )
241
+
242
+ # 4. Prepare image
243
+ if isinstance(controlnet, ControlNetModel):
244
+ qrcode = self.prepare_image(
245
+ image=qrcode,
246
+ width=width,
247
+ height=height,
248
+ batch_size=batch_size * num_images_per_prompt,
249
+ num_images_per_prompt=num_images_per_prompt,
250
+ device=device,
251
+ dtype=controlnet.dtype,
252
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
253
+ guess_mode=guess_mode,
254
+ )
255
+ height, width = qrcode.shape[-2:]
256
+ elif isinstance(controlnet, MultiControlNetModel):
257
+ qrcodes = []
258
+
259
+ # Nested lists as ControlNet condition
260
+ if isinstance(qrcode[0], list):
261
+ # Transpose the nested image list
262
+ qrcode = [list(t) for t in zip(*qrcode)]
263
+
264
+ for qrcode_ in qrcode:
265
+ qrcode_ = self.prepare_image(
266
+ image=qrcode_,
267
+ width=width,
268
+ height=height,
269
+ batch_size=batch_size * num_images_per_prompt,
270
+ num_images_per_prompt=num_images_per_prompt,
271
+ device=device,
272
+ dtype=controlnet.dtype,
273
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
274
+ guess_mode=guess_mode,
275
+ )
276
+
277
+ qrcodes.append(qrcode_)
278
+
279
+ qrcode = qrcodes
280
+ height, width = qrcode[0].shape[-2:]
281
+ else:
282
+ assert False
283
+
284
+ # 5. Prepare timesteps
285
+ timesteps, num_inference_steps = retrieve_timesteps(
286
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
287
+ )
288
+ self._num_timesteps = len(timesteps)
289
+
290
+ # 6. Prepare latent variables
291
+ num_channels_latents = self.unet.config.in_channels
292
+ latents = self.prepare_latents(
293
+ batch_size * num_images_per_prompt,
294
+ num_channels_latents,
295
+ height,
296
+ width,
297
+ prompt_embeds.dtype,
298
+ device,
299
+ generator,
300
+ latents,
301
+ )
302
+
303
+ # 6.5 Optionally get Guidance Scale Embedding
304
+ timestep_cond = None
305
+ if self.unet.config.time_cond_proj_dim is not None:
306
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
307
+ timestep_cond = self.get_guidance_scale_embedding(
308
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
309
+ ).to(device=device, dtype=latents.dtype)
310
+
311
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
312
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
313
+
314
+ # 7.1 Add image embeds for IP-Adapter
315
+ added_cond_kwargs = (
316
+ {"image_embeds": image_embeds}
317
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
318
+ else None
319
+ )
320
+
321
+ # 7.2 Create tensor stating which controlnets to keep
322
+ controlnet_keep = []
323
+ for i in range(len(timesteps)):
324
+ keeps = [
325
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
326
+ for s, e in zip(control_guidance_start, control_guidance_end)
327
+ ]
328
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
329
+
330
+ # 8. Denoising loop
331
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
332
+ is_unet_compiled = is_compiled_module(self.unet)
333
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
334
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
335
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
336
+ with torch.enable_grad():
337
+ for i, t in enumerate(timesteps):
338
+ if self.interrupt:
339
+ continue
340
+
341
+ # Relevant thread:
342
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
343
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
344
+ torch._inductor.cudagraph_mark_step_begin()
345
+ # expand the latents if we are doing classifier free guidance
346
+ latents = latents.clone().detach().requires_grad_(True)
347
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
348
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
349
+
350
+ # controlnet(s) inference
351
+ if guess_mode and self.do_classifier_free_guidance:
352
+ # Infer ControlNet only for the conditional batch.
353
+ control_model_input = latents
354
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
355
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
356
+ else:
357
+ control_model_input = latent_model_input
358
+ controlnet_prompt_embeds = prompt_embeds
359
+
360
+ if isinstance(controlnet_keep[i], list):
361
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
362
+ else:
363
+ controlnet_cond_scale = controlnet_conditioning_scale
364
+ if isinstance(controlnet_cond_scale, list):
365
+ controlnet_cond_scale = controlnet_cond_scale[0]
366
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
367
+
368
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
369
+ control_model_input,
370
+ t,
371
+ encoder_hidden_states=controlnet_prompt_embeds,
372
+ controlnet_cond=qrcode,
373
+ conditioning_scale=cond_scale,
374
+ guess_mode=guess_mode,
375
+ return_dict=False,
376
+ )
377
+
378
+ if guess_mode and self.do_classifier_free_guidance:
379
+ # Inferred ControlNet only for the conditional batch.
380
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
381
+ # add 0 to the unconditional batch to keep it unchanged.
382
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
383
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
384
+
385
+ # predict the noise residual
386
+ noise_pred = self.unet(
387
+ latent_model_input,
388
+ t,
389
+ encoder_hidden_states=prompt_embeds,
390
+ timestep_cond=timestep_cond,
391
+ cross_attention_kwargs=self.cross_attention_kwargs,
392
+ down_block_additional_residuals=down_block_res_samples,
393
+ mid_block_additional_residual=mid_block_res_sample,
394
+ added_cond_kwargs=added_cond_kwargs,
395
+ return_dict=False,
396
+ )[0]
397
+
398
+ # perform guidance
399
+ if self.do_classifier_free_guidance:
400
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
401
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
402
+
403
+ # compute the original latents x_t -> x_0
404
+ original_latents = self.scheduler.step(
405
+ noise_pred,
406
+ t,
407
+ latents,
408
+ **extra_step_kwargs,
409
+ return_dict=True,
410
+ ).pred_original_sample
411
+
412
+ original_image = self.vae.decode(
413
+ original_latents / self.vae.config.scaling_factor,
414
+ return_dict=True,
415
+ ).sample
416
+
417
+ # compute the score of Scanninig Robust Perceptual Guidance (SRPG)
418
+ score = self.srpg.compute_score(
419
+ latents=latents,
420
+ image=crop_padding(self.image_processor.denormalize(original_image), qrcode_padding),
421
+ qrcode=crop_padding(image_binarize(qrcode[qrcode.size(0) // 2, None]), qrcode_padding),
422
+ ref_image=crop_padding(ref_image, qrcode_padding),
423
+ )
424
+
425
+ timesteps_prev = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
426
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
427
+ beta_prod_t = 1 - alpha_prod_t
428
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[timesteps_prev] if timesteps_prev >= 0 else self.scheduler.final_alpha_cumprod
429
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
430
+
431
+ noise_pred = noise_pred + (beta_prod_t ** 0.5) * score
432
+ original_latents = (latents - (beta_prod_t ** 0.5) * noise_pred) / alpha_prod_t ** 0.5
433
+ latents = (alpha_prod_t_prev ** 0.5) * original_latents + (beta_prod_t_prev ** 0.5) * noise_pred
434
+
435
+ if callback_on_step_end is not None:
436
+ callback_kwargs = {}
437
+ for k in callback_on_step_end_tensor_inputs:
438
+ callback_kwargs[k] = locals()[k]
439
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
440
+
441
+ latents = callback_outputs.pop("latents", latents)
442
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
443
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
444
+
445
+ # call the callback, if provided
446
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
447
+ progress_bar.update()
448
+ if callback is not None and i % callback_steps == 0:
449
+ step_idx = i // getattr(self.scheduler, "order", 1)
450
+ callback(step_idx, t, latents)
451
+
452
+ if XLA_AVAILABLE:
453
+ xm.mark_step()
454
+
455
+ # perform Scanning Robust Manifold Projected Gradient Descent (SR-MPGD)
456
+ if srmpgd_num_iteration is not None:
457
+ with torch.enable_grad():
458
+ latents = latents.clone().detach().requires_grad_(True)
459
+ optimizer = torch.optim.SGD([latents], lr=srmpgd_lr)
460
+
461
+ for i in trange(srmpgd_num_iteration):
462
+ optimizer.zero_grad()
463
+ original_image = self.vae.decode(latents / self.vae.config.scaling_factor,return_dict=False)[0]
464
+ loss = self.srpg.compute_loss(
465
+ image=crop_padding(self.image_processor.denormalize(original_image), qrcode_padding),
466
+ qrcode=crop_padding(image_binarize(qrcode[qrcode.size(0) // 2, None]), qrcode_padding),
467
+ ref_image=crop_padding(ref_image, qrcode_padding),
468
+ )
469
+ loss.backward()
470
+ optimizer.step()
471
+
472
+ # If we do sequential model offloading, let's offload unet and controlnet
473
+ # manually for max memory savings
474
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
475
+ self.unet.to("cpu")
476
+ self.controlnet.to("cpu")
477
+ torch.cuda.empty_cache()
478
+
479
+ if not output_type == "latent":
480
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
481
+ 0
482
+ ]
483
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
484
+ else:
485
+ image = latents
486
+ has_nsfw_concept = None
487
+
488
+ if has_nsfw_concept is None:
489
+ do_denormalize = [True] * image.shape[0]
490
+ else:
491
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
492
+
493
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
494
+
495
+ # Offload all models
496
+ self.maybe_free_model_hooks()
497
+
498
+ if not return_dict:
499
+ return (image, has_nsfw_concept)
500
+
501
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
502
+
503
+ @torch.no_grad()
504
+ def __call__(
505
+ self,
506
+ prompt: Union[str, List[str]] = None,
507
+ qrcode: PipelineImageInput = None,
508
+ qrcode_module_size: int = 20,
509
+ qrcode_padding: int = 78,
510
+ height: Optional[int] = None,
511
+ width: Optional[int] = None,
512
+ num_inference_steps: int = 50,
513
+ timesteps: List[int] = None,
514
+ sigmas: List[float] = None,
515
+ guidance_scale: float = 7.5,
516
+ negative_prompt: Optional[Union[str, List[str]]] = None,
517
+ num_images_per_prompt: Optional[int] = 1,
518
+ eta: float = 0.0,
519
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
520
+ latents: Optional[torch.Tensor] = None,
521
+ prompt_embeds: Optional[torch.Tensor] = None,
522
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
523
+ ip_adapter_image: Optional[PipelineImageInput] = None,
524
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
525
+ output_type: Optional[str] = "pil",
526
+ return_dict: bool = True,
527
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
528
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
529
+ guess_mode: bool = False,
530
+ control_guidance_start: Union[float, List[float]] = 0.0,
531
+ control_guidance_end: Union[float, List[float]] = 1.0,
532
+ scanning_robust_guidance_scale: int = 500,
533
+ perceptual_guidance_scale: int = 10,
534
+ clip_skip: Optional[int] = None,
535
+ srmpgd_num_iteration: Optional[int] = None,
536
+ srmpgd_lr: Optional[float] = 0.1,
537
+ callback_on_step_end: Optional[
538
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
539
+ ] = None,
540
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
541
+ **kwargs,
542
+ ):
543
+ stage1_output = self._run_stage1(
544
+ prompt=prompt,
545
+ qrcode=qrcode,
546
+ height=height,
547
+ width=width,
548
+ num_inference_steps=num_inference_steps,
549
+ timesteps=timesteps,
550
+ sigmas=sigmas,
551
+ guidance_scale=guidance_scale,
552
+ negative_prompt=negative_prompt,
553
+ num_images_per_prompt=num_images_per_prompt,
554
+ eta=eta,
555
+ generator=generator,
556
+ latents=latents,
557
+ prompt_embeds=prompt_embeds,
558
+ negative_prompt_embeds=negative_prompt_embeds,
559
+ ip_adapter_image=ip_adapter_image,
560
+ ip_adapter_image_embeds=ip_adapter_image_embeds,
561
+ output_type="pt",
562
+ return_dict=False,
563
+ cross_attention_kwargs=cross_attention_kwargs,
564
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
565
+ guess_mode=guess_mode,
566
+ control_guidance_start=control_guidance_start,
567
+ control_guidance_end=control_guidance_end,
568
+ clip_skip=clip_skip,
569
+ callback_on_step_end=callback_on_step_end,
570
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
571
+ )
572
+ stage2_output = self._run_stage2(
573
+ prompt=prompt,
574
+ qrcode=qrcode,
575
+ qrcode_module_size=qrcode_module_size,
576
+ qrcode_padding=qrcode_padding,
577
+ ref_image=stage1_output.images,
578
+ height=height,
579
+ width=width,
580
+ num_inference_steps=num_inference_steps,
581
+ timesteps=timesteps,
582
+ sigmas=sigmas,
583
+ guidance_scale=guidance_scale,
584
+ negative_prompt=negative_prompt,
585
+ num_images_per_prompt=num_images_per_prompt,
586
+ eta=eta,
587
+ generator=generator,
588
+ latents=latents,
589
+ prompt_embeds=prompt_embeds,
590
+ negative_prompt_embeds=negative_prompt_embeds,
591
+ ip_adapter_image=ip_adapter_image,
592
+ ip_adapter_image_embeds=ip_adapter_image_embeds,
593
+ output_type=output_type,
594
+ return_dict=return_dict,
595
+ cross_attention_kwargs=cross_attention_kwargs,
596
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
597
+ guess_mode=guess_mode,
598
+ control_guidance_start=control_guidance_start,
599
+ control_guidance_end=control_guidance_end,
600
+ scanning_robust_guidance_scale=scanning_robust_guidance_scale,
601
+ perceptual_guidance_scale=perceptual_guidance_scale,
602
+ clip_skip=clip_skip,
603
+ srmpgd_num_iteration=srmpgd_num_iteration,
604
+ callback_on_step_end=callback_on_step_end,
605
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
606
+ )
607
+ return stage2_output
diffqrcoder/srpg.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from diffqrcoder.losses import PerceptualLoss, ScanningRobustLoss
5
+
6
+
7
+
8
+ GRADIENT_SCALE = 100
9
+
10
+
11
+ class ScanningRobustPerceptualGuidance(nn.Module):
12
+ def __init__(
13
+ self,
14
+ module_size: int = 20,
15
+ scanning_robust_guidance_scale: int = 500,
16
+ perceptual_guidance_scale: int = 2,
17
+ ):
18
+ super().__init__()
19
+ self.module_size = module_size
20
+ self.scanning_robust_guidance_scale = scanning_robust_guidance_scale
21
+ self.perceptual_guidance_scale = perceptual_guidance_scale
22
+ self.scanning_robust_loss_fn = ScanningRobustLoss(module_size=module_size)
23
+ self.perceptual_loss_fn = PerceptualLoss()
24
+
25
+ def compute_loss(self, image: torch.Tensor, qrcode: torch.Tensor, ref_image: torch.Tensor) -> torch.Tensor:
26
+ return (
27
+ self.scanning_robust_guidance_scale * self.scanning_robust_loss_fn(image, qrcode) + \
28
+ self.perceptual_guidance_scale * self.perceptual_loss_fn(image, ref_image)
29
+ ) * GRADIENT_SCALE
30
+
31
+ def compute_score(self, latents: torch.Tensor, image: torch.Tensor, qrcode: torch.Tensor, ref_image: torch.Tensor) -> torch.Tensor:
32
+ loss = self.compute_loss(image, qrcode, ref_image)
33
+ return torch.autograd.grad(loss, latents)[0] / GRADIENT_SCALE
diffqrcoder/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision.transforms.functional import pil_to_tensor
5
+
6
+
7
+ def convert_pil_to_normalized_tensor(image: Image) -> torch.Tensor:
8
+ return pil_to_tensor(image).unsqueeze(0).float()
9
+
10
+
11
+ def convert_normalized_tensor_to_np_image(image: torch.Tensor) -> np.array:
12
+ return image.clip(0, 1).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
13
+
14
+
15
+ def add_position_pattern(
16
+ x: torch.Tensor,
17
+ y: torch.Tensor,
18
+ module_num: int,
19
+ module_size: int
20
+ ) -> torch.Tensor:
21
+
22
+ x[: 8 * module_size - 1, : 8 * module_size - 1, :] = \
23
+ y[: 8 * module_size - 1, : 8 * module_size - 1, :]
24
+
25
+ x[
26
+ (module_num - 8) * module_size + 1 : module_num * module_size,
27
+ : 8 * module_size - 1,
28
+ :
29
+ ] = y[
30
+ (module_num - 8) * module_size + 1 : module_num * module_size,
31
+ : 8 * module_size - 1,
32
+ :
33
+ ]
34
+
35
+ x[
36
+ : 8 * module_size - 1,
37
+ (module_num - 8) * module_size + 1 : module_num * module_size,
38
+ :
39
+ ] = y[
40
+ : 8 * module_size - 1,
41
+ (module_num - 8) * module_size + 1 : module_num * module_size,
42
+ :
43
+ ]
44
+
45
+ x[
46
+ (module_num - 9) * module_size : (module_num - 4) * module_size - 1,
47
+ (module_num - 9) * module_size : (module_num - 4) * module_size - 1,
48
+ :
49
+ ] = y[
50
+ (module_num - 9) * module_size : (module_num - 4) * module_size - 1,
51
+ (module_num - 9) * module_size : (module_num - 4) * module_size - 1,
52
+ :
53
+ ]
54
+ return x
diffqrcoder_wrapper.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # diffqrcoder_wrapper.py
2
+ import torch
3
+ from diffusers import ControlNetModel, DDIMScheduler
4
+ from PIL import Image
5
+ import qrcode
6
+
7
+ from diffqrcoder import DiffQRCoderPipeline
8
+
9
+ # ---- Defaults taken from run_diffqrcoder.py ---- #
10
+ CONTROLNET_CKPT = "monster-labs/control_v1p_sd15_qrcode_monster"
11
+ # Original used a direct file URL; we can keep that:
12
+ PIPE_CKPT = (
13
+ "https://huggingface.co/fp16-guy/Cetus-Mix_Whalefall_fp16_cleaned/"
14
+ "resolve/main/cetusMix_Whalefall2_fp16.safetensors"
15
+ )
16
+ # You can also upload that file to the Space and use a local path.
17
+
18
+ DEVICE = "cuda" # ZeroGPU will give us a CUDA device during @spaces.GPU calls
19
+
20
+ # Cache
21
+ _controlnet = None
22
+ _pipe = None
23
+
24
+
25
+ def _make_qr_image(
26
+ data: str,
27
+ box_size: int = 20, # aligns with qrcode_module_size default
28
+ border: int = 4, # typical QR quiet zone in modules
29
+ ) -> Image.Image:
30
+ qr = qrcode.QRCode(
31
+ version=None,
32
+ error_correction=qrcode.constants.ERROR_CORRECT_H,
33
+ box_size=box_size,
34
+ border=border,
35
+ )
36
+ qr.add_data(data)
37
+ qr.make(fit=True)
38
+ img = qr.make_image(fill_color="black", back_color="white").convert("RGB")
39
+ return img
40
+
41
+
42
+ def load_pipeline():
43
+ """
44
+ Lazily load ControlNet + DiffQRCoderPipeline.
45
+ Mirrors run_diffqrcoder.py, but only once.
46
+ """
47
+ global _controlnet, _pipe
48
+
49
+ if _pipe is not None:
50
+ return _pipe
51
+
52
+ # 1. ControlNet
53
+ if _controlnet is None:
54
+ _controlnet = ControlNetModel.from_pretrained(
55
+ CONTROLNET_CKPT,
56
+ torch_dtype=torch.float16,
57
+ )
58
+
59
+ # 2. DiffQRCoderPipeline (from single safetensors file)
60
+ pipe = DiffQRCoderPipeline.from_single_file(
61
+ PIPE_CKPT,
62
+ controlnet=_controlnet,
63
+ torch_dtype=torch.float16,
64
+ use_auth_token=True, # uses the Space's HF token
65
+ )
66
+
67
+ # 3. Use DDIM scheduler as in original script
68
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
69
+
70
+ # Don't call .to("cuda") yet; do it inside the @spaces.GPU function
71
+ _pipe = pipe
72
+ return _pipe
73
+
74
+
75
+ def generate_qr_art(
76
+ url_or_text: str,
77
+ prompt: str,
78
+ neg_prompt: str = "easynegative",
79
+ num_inference_steps: int = 40,
80
+ qrcode_module_size: int = 20,
81
+ qrcode_padding: int = 78,
82
+ controlnet_conditioning_scale: float = 1.35,
83
+ scanning_robust_guidance_scale: float = 500.0,
84
+ perceptual_guidance_scale: float = 2.0,
85
+ srmpgd_num_iteration: int | None = None,
86
+ srmpgd_lr: float = 0.1,
87
+ seed: int = 1,
88
+ ) -> Image.Image:
89
+ """
90
+ Directly mirrors the call at the bottom of run_diffqrcoder.py,
91
+ but takes the QR content + prompt as arguments and returns a PIL image.
92
+ """
93
+ pipe = load_pipeline()
94
+
95
+ # ZeroGPU will ensure DEVICE exists as "cuda" when we call this
96
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
97
+
98
+ # Create QR image in-memory instead of loading from disk
99
+ qrcode_img = _make_qr_image(
100
+ data=url_or_text,
101
+ box_size=qrcode_module_size, # roughly aligned
102
+ border=4, # module-based border; padding param handles extra
103
+ )
104
+
105
+ pipe = pipe.to(DEVICE)
106
+
107
+ result = pipe(
108
+ prompt=prompt,
109
+ qrcode=qrcode_img,
110
+ qrcode_module_size=qrcode_module_size,
111
+ qrcode_padding=qrcode_padding,
112
+ negative_prompt=neg_prompt,
113
+ num_inference_steps=num_inference_steps,
114
+ generator=generator,
115
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
116
+ scanning_robust_guidance_scale=scanning_robust_guidance_scale,
117
+ perceptual_guidance_scale=perceptual_guidance_scale,
118
+ srmpgd_num_iteration=srmpgd_num_iteration,
119
+ srmpgd_lr=srmpgd_lr,
120
+ )
121
+ return result.images[0]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ spaces
3
+
4
+ torch==2.1.2
5
+ diffusers>=0.26.0
6
+ transformers
7
+ accelerate
8
+ safetensors
9
+ opencv-python
10
+ numpy
11
+ Pillow
12
+ qrcode