dikdimon committed on
Commit f8680ab · verified · 1 Parent(s): 6a81e1f

Upload hm using SD-Hub extension

hm/.ipynb_checkpoints/processing-checkpoint.py ADDED
@@ -0,0 +1,1838 @@
+ from __future__ import annotations
+ import json
+ import logging
+ import math
+ import os
+ import sys
+ import hashlib
+ from dataclasses import dataclass, field
+
+ import torch
+ import numpy as np
+ from PIL import Image, ImageOps
+ import random
+ import cv2
+ from skimage import exposure
+ from typing import Any
+
+ import modules.sd_hijack
+ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
+ from modules.rng import slerp  # noqa: F401
+ from modules.sd_hijack import model_hijack
+ from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
+ from modules.shared import opts, cmd_opts, state
+ import modules.shared as shared
+ import modules.paths as paths
+ import modules.face_restoration
+ import modules.images as images
+ import modules.styles
+ import modules.sd_models as sd_models
+ import modules.sd_vae as sd_vae
+ from ldm.data.util import AddMiDaS
+ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+
+ from einops import repeat, rearrange
+ from blendmodes.blend import blendLayers, BlendType
+
+
+ # some of those options should not be changed at all because they would break the model, so I removed them from options.
+ opt_C = 4
+ opt_f = 8
+
+
+ def setup_color_correction(image):
+     logging.info("Calibrating color correction.")
+     correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
+     return correction_target
+
+
+ def apply_color_correction(correction, original_image):
+     logging.info("Applying color correction.")
+     image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
+         cv2.cvtColor(
+             np.asarray(original_image),
+             cv2.COLOR_RGB2LAB
+         ),
+         correction,
+         channel_axis=2
+     ), cv2.COLOR_LAB2RGB).astype("uint8"))
+
+     image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+
+     return image.convert('RGB')
+
+
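An aside on usage (editorial, not part of the uploaded file): this pair is typically used in img2img by capturing a LAB-space reference from the init image before sampling, then histogram-matching each generated image against it. A minimal sketch, assuming only the two functions defined above:

    from PIL import Image

    source = Image.new('RGB', (64, 64), (200, 120, 80))     # stand-in init image
    generated = Image.new('RGB', (64, 64), (90, 140, 200))  # stand-in model output

    correction = setup_color_correction(source)             # LAB reference target
    fixed = apply_color_correction(correction, generated)   # re-toned toward source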
+
+ def uncrop(image, dest_size, paste_loc):
+     x, y, w, h = paste_loc
+     base_image = Image.new('RGBA', dest_size)
+     image = images.resize_image(1, image, w, h)
+     base_image.paste(image, (x, y))
+     image = base_image
+
+     return image
+
+
+ def apply_overlay(image, paste_loc, overlay):
+     if overlay is None:
+         return image, image.copy()
+
+     if paste_loc is not None:
+         image = uncrop(image, (overlay.width, overlay.height), paste_loc)
+
+     original_denoised_image = image.copy()
+
+     image = image.convert('RGBA')
+     image.alpha_composite(overlay)
+     image = image.convert('RGB')
+
+     return image, original_denoised_image
+
+
+ def create_binary_mask(image, round=True):
+     if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
+         if round:
+             image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+         else:
+             image = image.split()[-1].convert("L")
+     else:
+         image = image.convert('L')
+     return image
+
+
+ def txt2img_image_conditioning(sd_model, x, width, height):
+     if sd_model.model.conditioning_key in {'hybrid', 'concat'}:  # Inpainting models
+
+         # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+         image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+         image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
+
+         # Add the fake full 1s mask to the first dimension.
+         image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+         image_conditioning = image_conditioning.to(x.dtype)
+
+         return image_conditioning
+
+     elif sd_model.model.conditioning_key == "crossattn-adm":  # UnCLIP models
+
+         return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
+
+     else:
+         if sd_model.is_sdxl_inpaint:
+             # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+             image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+             image_conditioning = images_tensor_to_samples(image_conditioning,
+                                                           approximation_indexes.get(opts.sd_vae_encode_method))
+
+             # Add the fake full 1s mask to the first dimension.
+             image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+             image_conditioning = image_conditioning.to(x.dtype)
+
+             return image_conditioning
+
+         # Dummy zero conditioning if we're not using inpainting or unclip models.
+         # Still takes up a bit of memory, but no encoder call.
+         # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
+         return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+
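A quick shape check (editorial aside, not part of the file): the inpainting branches above pad a constant all-ones mask channel onto the front of the 4-channel latent, which is why the dummy fallback also returns 5 channels.

    import torch

    latent = torch.zeros(1, 4, 64, 64)                      # fake VAE latent
    padded = torch.nn.functional.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
    assert padded.shape == (1, 5, 64, 64)                   # mask channel prepended
    assert padded[:, 0].min() == 1.0                        # the fake full-ones mask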
+ @dataclass(repr=False)
+ class StableDiffusionProcessing:
+     sd_model: object = None
+     outpath_samples: str = None
+     outpath_grids: str = None
+     prompt: str = ""
+     prompt_for_display: str = None
+     negative_prompt: str = ""
+     styles: list[str] = None
+     seed: int = -1
+     subseed: int = -1
+     subseed_strength: float = 0
+     seed_resize_from_h: int = -1
+     seed_resize_from_w: int = -1
+     seed_enable_extras: bool = True
+     sampler_name: str = None
+     scheduler: str = None
+     batch_size: int = 1
+     n_iter: int = 1
+     steps: int = 50
+     cfg_scale: float = 7.0
+     width: int = 512
+     height: int = 512
+     restore_faces: bool = None
+     tiling: bool = None
+     do_not_save_samples: bool = False
+     do_not_save_grid: bool = False
+     extra_generation_params: dict[str, Any] = None
+     overlay_images: list = None
+     eta: float = None
+     do_not_reload_embeddings: bool = False
+     denoising_strength: float = None
+     ddim_discretize: str = None
+     s_min_uncond: float = None
+     s_churn: float = None
+     s_tmax: float = None
+     s_tmin: float = None
+     s_noise: float = None
+     override_settings: dict[str, Any] = None
+     override_settings_restore_afterwards: bool = True
+     sampler_index: int = None
+     refiner_checkpoint: str = None
+     refiner_switch_at: float = None
+     token_merging_ratio = 0
+     token_merging_ratio_hr = 0
+     disable_extra_networks: bool = False
+     firstpass_image: Image = None
+
+     scripts_value: scripts.ScriptRunner = field(default=None, init=False)
+     script_args_value: list = field(default=None, init=False)
+     scripts_setup_complete: bool = field(default=False, init=False)
+
+     cached_uc = [None, None]
+     cached_c = [None, None]
+
+     comments: dict = None
+     sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
+     is_using_inpainting_conditioning: bool = field(default=False, init=False)
+     paste_to: tuple | None = field(default=None, init=False)
+
+     is_hr_pass: bool = field(default=False, init=False)
+
+     c: tuple = field(default=None, init=False)
+     uc: tuple = field(default=None, init=False)
+
+     rng: rng.ImageRNG | None = field(default=None, init=False)
+     step_multiplier: int = field(default=1, init=False)
+     color_corrections: list = field(default=None, init=False)
+
+     all_prompts: list = field(default=None, init=False)
+     all_negative_prompts: list = field(default=None, init=False)
+     all_seeds: list = field(default=None, init=False)
+     all_subseeds: list = field(default=None, init=False)
+     iteration: int = field(default=0, init=False)
+     main_prompt: str = field(default=None, init=False)
+     main_negative_prompt: str = field(default=None, init=False)
+
+     prompts: list = field(default=None, init=False)
+     negative_prompts: list = field(default=None, init=False)
+     seeds: list = field(default=None, init=False)
+     subseeds: list = field(default=None, init=False)
+     extra_network_data: dict = field(default=None, init=False)
+
+     user: str = field(default=None, init=False)
+
+     sd_model_name: str = field(default=None, init=False)
+     sd_model_hash: str = field(default=None, init=False)
+     sd_vae_name: str = field(default=None, init=False)
+     sd_vae_hash: str = field(default=None, init=False)
+
+     is_api: bool = field(default=False, init=False)
+
+     def __post_init__(self):
+         if self.sampler_index is not None:
+             print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
+
+         self.comments = {}
+
+         if self.styles is None:
+             self.styles = []
+
+         self.sampler_noise_scheduler_override = None
+
+         self.extra_generation_params = self.extra_generation_params or {}
+         self.override_settings = self.override_settings or {}
+         self.script_args = self.script_args or {}
+
+         self.refiner_checkpoint_info = None
+
+         if not self.seed_enable_extras:
+             self.subseed = -1
+             self.subseed_strength = 0
+             self.seed_resize_from_h = 0
+             self.seed_resize_from_w = 0
+
+         self.cached_uc = StableDiffusionProcessing.cached_uc
+         self.cached_c = StableDiffusionProcessing.cached_c
+
+     def fill_fields_from_opts(self):
+         self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
+         self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
+         self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
+         self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
+         self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
+
+     @property
+     def sd_model(self):
+         return shared.sd_model
+
+     @sd_model.setter
+     def sd_model(self, value):
+         pass
+
+     @property
+     def scripts(self):
+         return self.scripts_value
+
+     @scripts.setter
+     def scripts(self, value):
+         self.scripts_value = value
+
+         if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+             self.setup_scripts()
+
+     @property
+     def script_args(self):
+         return self.script_args_value
+
+     @script_args.setter
+     def script_args(self, value):
+         self.script_args_value = value
+
+         if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+             self.setup_scripts()
+
+     def setup_scripts(self):
+         self.scripts_setup_complete = True
+
+         self.scripts.setup_scrips(self, is_ui=not self.is_api)
+
+     def comment(self, text):
+         self.comments[text] = 1
+
+     def txt2img_image_conditioning(self, x, width=None, height=None):
+         self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+
+         return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
+
+     def depth2img_image_conditioning(self, source_image):
+         # Use the AddMiDaS helper to format our source image to suit the MiDaS model
+         transformer = AddMiDaS(model_type="dpt_hybrid")
+         transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+
+         conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
+         conditioning = torch.nn.functional.interpolate(
+             self.sd_model.depth_model(midas_in),
+             size=conditioning_image.shape[2:],
+             mode="bicubic",
+             align_corners=False,
+         )
+
+         (depth_min, depth_max) = torch.aminmax(conditioning)
+         conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+         return conditioning
+
+     def edit_image_conditioning(self, source_image):
+         conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
+
+         return conditioning_image
+
+     def unclip_image_conditioning(self, source_image):
+         c_adm = self.sd_model.embedder(source_image)
+         if self.sd_model.noise_augmentor is not None:
+             noise_level = 0  # TODO: Allow other noise levels?
+             c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
+             c_adm = torch.cat((c_adm, noise_level_emb), 1)
+         return c_adm
+
+     def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
+         self.is_using_inpainting_conditioning = True
+
+         # Handle the different mask inputs
+         if image_mask is not None:
+             if torch.is_tensor(image_mask):
+                 conditioning_mask = image_mask
+             else:
+                 conditioning_mask = np.array(image_mask.convert("L"))
+                 conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+                 conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+                 if round_image_mask:
+                     # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+                     conditioning_mask = torch.round(conditioning_mask)
+
+         else:
+             conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
+
+         # Create another latent image, this time with a masked version of the original input.
+         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+         conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
+         conditioning_image = torch.lerp(
+             source_image,
+             source_image * (1.0 - conditioning_mask),
+             getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+         )
+
+         # Encode the new masked image using first stage of network.
+         conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+         # Create the concatenated conditioning tensor to be fed to `c_concat`
+         conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+         conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+         image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+         image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+         return image_conditioning
+
+     def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
+         source_image = devices.cond_cast_float(source_image)
+
+         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+         # identify itself with a field common to all models. The conditioning_key is also hybrid.
+         if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+             return self.depth2img_image_conditioning(source_image)
+
+         if self.sd_model.cond_stage_key == "edit":
+             return self.edit_image_conditioning(source_image)
+
+         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
+
+         if self.sampler.conditioning_key == "crossattn-adm":
+             return self.unclip_image_conditioning(source_image)
+
+         if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
+             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
+         # Dummy zero conditioning if we're not using inpainting or depth model.
+         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
+     def init(self, all_prompts, all_seeds, all_subseeds):
+         pass
+
+     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+         raise NotImplementedError()
+
+     def close(self):
+         self.sampler = None
+         self.c = None
+         self.uc = None
+         if not opts.persistent_cond_cache:
+             StableDiffusionProcessing.cached_c = [None, None]
+             StableDiffusionProcessing.cached_uc = [None, None]
+
+     def get_token_merging_ratio(self, for_hr=False):
+         if for_hr:
+             return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
+
+         return self.token_merging_ratio or opts.token_merging_ratio
+
+     def setup_prompts(self):
+         if isinstance(self.prompt, list):
+             self.all_prompts = self.prompt
+         elif isinstance(self.negative_prompt, list):
+             self.all_prompts = [self.prompt] * len(self.negative_prompt)
+         else:
+             self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
+
+         if isinstance(self.negative_prompt, list):
+             self.all_negative_prompts = self.negative_prompt
+         else:
+             self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)
+
+         if len(self.all_prompts) != len(self.all_negative_prompts):
+             raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")
+
+         self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
+         self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+
+         self.main_prompt = self.all_prompts[0]
+         self.main_negative_prompt = self.all_negative_prompts[0]
+
+     def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
+         """Returns parameters that invalidate the cond cache if changed"""
+
+         return (
+             required_prompts,
+             steps,
+             hires_steps,
+             use_old_scheduling,
+             opts.CLIP_stop_at_last_layers,
+             shared.sd_model.sd_checkpoint_info,
+             extra_network_data,
+             opts.sdxl_crop_left,
+             opts.sdxl_crop_top,
+             self.width,
+             self.height,
+             opts.fp8_storage,
+             opts.cache_fp16_weight,
+             opts.emphasis,
+         )
+
+     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
+         """
+         Returns the result of calling function(shared.sd_model, required_prompts, steps)
+         using a cache to store the result if the same arguments have been used before.
+
+         cache is an array containing two elements. The first element is a tuple
+         representing the previously used arguments, or None if no arguments
+         have been used before. The second element is where the previously
+         computed result is stored.
+
+         caches is a list with items described above.
+         """
+
+         if shared.opts.use_old_scheduling:
+             old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
+             new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
+             if old_schedules != new_schedules:
+                 self.extra_generation_params["Old prompt editing timelines"] = True
+
+         cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
+
+         for cache in caches:
+             if cache[0] is not None and cached_params == cache[0]:
+                 return cache[1]
+
+         cache = caches[0]
+
+         with devices.autocast():
+             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
+
+         cache[0] = cached_params
+         return cache[1]
+
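The cache contract described in the docstring above, reduced to a standalone sketch (editorial aside, not part of the file): each cache is a two-slot list where slot 0 holds the invalidation key (the tuple from cached_params) and slot 1 holds the computed conds.

    def lookup_or_compute(caches, key, compute):
        for cache in caches:
            if cache[0] is not None and cache[0] == key:
                return cache[1]        # hit: key unchanged since last call
        caches[0][1] = compute()       # miss: recompute into the first cache
        caches[0][0] = key
        return caches[0][1]

    cache = [None, None]
    assert lookup_or_compute([cache], ("prompt", 20), lambda: "conds") == "conds"
    assert cache[0] == ("prompt", 20)  # identical follow-up calls now hit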
+     def setup_conds(self):
+         prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+         negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+
+         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
+         total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
+         self.step_multiplier = total_steps // self.steps
+         self.firstpass_steps = total_steps
+
+         self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
+         self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
+
+     def get_conds(self):
+         return self.c, self.uc
+
+     def parse_extra_network_prompts(self):
+         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
+
+     def save_samples(self) -> bool:
+         """Returns whether generated images need to be written to disk"""
+         return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
+
+
+ class Processed:
+     def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
+         self.images = images_list
+         self.prompt = p.prompt
+         self.negative_prompt = p.negative_prompt
+         self.seed = seed
+         self.subseed = subseed
+         self.subseed_strength = p.subseed_strength
+         self.info = info
+         self.comments = "".join(f"{comment}\n" for comment in p.comments)
+         self.width = p.width
+         self.height = p.height
+         self.sampler_name = p.sampler_name
+         self.cfg_scale = p.cfg_scale
+         self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+         self.steps = p.steps
+         self.batch_size = p.batch_size
+         self.restore_faces = p.restore_faces
+         self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
+         self.sd_model_name = p.sd_model_name
+         self.sd_model_hash = p.sd_model_hash
+         self.sd_vae_name = p.sd_vae_name
+         self.sd_vae_hash = p.sd_vae_hash
+         self.seed_resize_from_w = p.seed_resize_from_w
+         self.seed_resize_from_h = p.seed_resize_from_h
+         self.denoising_strength = getattr(p, 'denoising_strength', None)
+         self.extra_generation_params = p.extra_generation_params
+         self.index_of_first_image = index_of_first_image
+         self.styles = p.styles
+         self.job_timestamp = state.job_timestamp
+         self.clip_skip = opts.CLIP_stop_at_last_layers
+         self.token_merging_ratio = p.token_merging_ratio
+         self.token_merging_ratio_hr = p.token_merging_ratio_hr
+
+         self.eta = p.eta
+         self.ddim_discretize = p.ddim_discretize
+         self.s_churn = p.s_churn
+         self.s_tmin = p.s_tmin
+         self.s_tmax = p.s_tmax
+         self.s_noise = p.s_noise
+         self.s_min_uncond = p.s_min_uncond
+         self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
+         self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
+         self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
+         self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
+         self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
+         self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
+
+         self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
+         self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
+         self.all_seeds = all_seeds or p.all_seeds or [self.seed]
+         self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
+         self.infotexts = infotexts or [info] * len(images_list)
+         self.version = program_version()
+
+     def js(self):
+         obj = {
+             "prompt": self.all_prompts[0],
+             "all_prompts": self.all_prompts,
+             "negative_prompt": self.all_negative_prompts[0],
+             "all_negative_prompts": self.all_negative_prompts,
+             "seed": self.seed,
+             "all_seeds": self.all_seeds,
+             "subseed": self.subseed,
+             "all_subseeds": self.all_subseeds,
+             "subseed_strength": self.subseed_strength,
+             "width": self.width,
+             "height": self.height,
+             "sampler_name": self.sampler_name,
+             "cfg_scale": self.cfg_scale,
+             "steps": self.steps,
+             "batch_size": self.batch_size,
+             "restore_faces": self.restore_faces,
+             "face_restoration_model": self.face_restoration_model,
+             "sd_model_name": self.sd_model_name,
+             "sd_model_hash": self.sd_model_hash,
+             "sd_vae_name": self.sd_vae_name,
+             "sd_vae_hash": self.sd_vae_hash,
+             "seed_resize_from_w": self.seed_resize_from_w,
+             "seed_resize_from_h": self.seed_resize_from_h,
+             "denoising_strength": self.denoising_strength,
+             "extra_generation_params": self.extra_generation_params,
+             "index_of_first_image": self.index_of_first_image,
+             "infotexts": self.infotexts,
+             "styles": self.styles,
+             "job_timestamp": self.job_timestamp,
+             "clip_skip": self.clip_skip,
+             "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
+             "version": self.version,
+         }
+
+         return json.dumps(obj, default=lambda o: None)
+
+     def infotext(self, p: StableDiffusionProcessing, index):
+         return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+
+     def get_token_merging_ratio(self, for_hr=False):
+         return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
+
+ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+     g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
+     return g.next()
+
+
+ class DecodedSamples(list):
+     already_decoded = True
+
+
+ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+     samples = DecodedSamples()
+
+     if check_for_nans:
+         devices.test_for_nans(batch, "unet")
+
+     for i in range(batch.shape[0]):
+         sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+         if check_for_nans:
+
+             try:
+                 devices.test_for_nans(sample, "vae")
+             except devices.NansException as e:
+                 if shared.opts.auto_vae_precision_bfloat16:
+                     autofix_dtype = torch.bfloat16
+                     autofix_dtype_text = "bfloat16"
+                     autofix_dtype_setting = "Automatically convert VAE to bfloat16"
+                     autofix_dtype_comment = ""
+                 elif shared.opts.auto_vae_precision:
+                     autofix_dtype = torch.float32
+                     autofix_dtype_text = "32-bit float"
+                     autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
+                     autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
+                 else:
+                     raise e
+
+                 if devices.dtype_vae == autofix_dtype:
+                     raise e
+
+                 errors.print_error_explanation(
+                     "A tensor with all NaNs was produced in VAE.\n"
+                     f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
+                     f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
+                 )
+
+                 devices.dtype_vae = autofix_dtype
+                 model.first_stage_model.to(devices.dtype_vae)
+                 batch = batch.to(devices.dtype_vae)
+
+                 sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+         if target_device is not None:
+             sample = sample.to(target_device)
+
+         samples.append(sample)
+
+     return samples
+
+
+ def get_fixed_seed(seed):
+     if seed == '' or seed is None:
+         seed = -1
+     elif isinstance(seed, str):
+         try:
+             seed = int(seed)
+         except Exception:
+             seed = -1
+
+     if seed == -1:
+         return int(random.randrange(4294967294))
+
+     return seed
+
+
+ def fix_seed(p):
+     p.seed = get_fixed_seed(p.seed)
+     p.subseed = get_fixed_seed(p.subseed)
+
+
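Illustrative behavior of get_fixed_seed (editorial aside, not part of the file): empty, None, -1, and unparseable inputs all resolve to a fresh random seed, while numeric strings pass through as ints.

    assert get_fixed_seed('1234') == 1234          # numeric string passes through
    assert isinstance(get_fixed_seed(''), int)     # empty -> random seed drawn
    assert isinstance(get_fixed_seed(None), int)   # None -> random seed drawn
    assert isinstance(get_fixed_seed('abc'), int)  # unparseable -> random seed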
+ def program_version():
+     import launch
+
+     res = launch.git_tag()
+     if res == "<none>":
+         res = None
+
+     return res
+
+
+ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
+     """
+     This function generates the infotext that is stored in generated images; it contains the parameters required to reproduce the image.
+     Args:
+         p: StableDiffusionProcessing
+         all_prompts: list[str]
+         all_seeds: list[int]
+         all_subseeds: list[int]
+         comments: list[str]
+         iteration: int
+         position_in_batch: int
+         use_main_prompt: bool
+         index: int
+         all_negative_prompts: list[str]
+
+     Returns: str
+
+     Extra generation params
+     The p.extra_generation_params dictionary allows additional parameters to be added to the infotext;
+     this can be used by the base webui or by extensions.
+     To add a new entry, add a new key-value pair; the dictionary key will be used as the key of the parameter in the infotext.
+     The value can be defined as:
+         - str | None
+         - List[str|None]
+         - callable func(**kwargs) -> str | None
+
+     When defined as a string, it will be used as-is, without extra processing; this is the most common use case.
+
+     Defining it as a list allows for a parameter that changes across images in the job, for example the 'Seed' parameter.
+     The list should have the same length as the total number of images in the entire job.
+
+     Defining it as a callable function allows for parameters that cannot be generated earlier, or that require extra logic.
+     For example 'Hires prompt': the hr_prompt might be changed by processes in the pipeline or by extensions,
+     and may vary across different images, so defining it as a static string or list would not work.
+
+     The function takes locals() as **kwargs, and as such will have access to variables like 'p' and 'index'.
+     The base signature of the function should be:
+         func(**kwargs) -> str | None
+     Optionally it can have additional arguments that will be pulled from those kwargs:
+         func(p, index, **kwargs) -> str | None
+     Note: for better future compatibility, even though this function will have access to all variables in locals(),
+     it is recommended to only use the arguments present in the function signature of create_infotext.
+     For actual implementation examples, see StableDiffusionProcessingTxt2Img.init > get_hr_prompt.
+     """
+
+     if use_main_prompt:
+         index = 0
+     elif index is None:
+         index = position_in_batch + iteration * p.batch_size
+
+     if all_negative_prompts is None:
+         all_negative_prompts = p.all_negative_prompts
+
+     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+     enable_hr = getattr(p, 'enable_hr', False)
+     token_merging_ratio = p.get_token_merging_ratio()
+     token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
+
+     prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
+     negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]
+
+     uses_ensd = opts.eta_noise_seed_delta != 0
+     if uses_ensd:
+         uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
+
+     generation_params = {
+         "Steps": p.steps,
+         "Sampler": p.sampler_name,
+         "Schedule type": p.scheduler,
+         "CFG scale": p.cfg_scale,
+         "Image CFG scale": getattr(p, 'image_cfg_scale', None),
+         "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
+         "Face restoration": opts.face_restoration_model if p.restore_faces else None,
+         "Size": f"{p.width}x{p.height}",
+         "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
+         "Model": p.sd_model_name if opts.add_model_name_to_info else None,
+         "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+         "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
+         "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
+         "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
+         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
+         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
+         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+         "Denoising strength": p.extra_generation_params.get("Denoising strength"),
+         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
+         "Clip skip": None if clip_skip <= 1 else clip_skip,
+         "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
+         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
+         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
+         "Init image hash": getattr(p, 'init_img_hash', None),
+         "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+         "Tiling": "True" if p.tiling else None,
+         **p.extra_generation_params,
+         "Version": program_version() if opts.add_version_to_infotext else None,
+         "User": p.user if opts.add_user_name_to_info else None,
+     }
+
+     for key, value in generation_params.items():
+         try:
+             if isinstance(value, list):
+                 generation_params[key] = value[index]
+             elif callable(value):
+                 generation_params[key] = value(**locals())
+         except Exception:
+             errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
+             generation_params[key] = None
+
+     generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
+
+     negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""
+
+     return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
+
+
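A sketch of the three value kinds the docstring above allows for p.extra_generation_params (editorial aside; the key names here are made up for illustration):

    def register_example_infotext_params(p):
        # static string: written into the infotext as-is
        p.extra_generation_params["My extension"] = "v1"
        # list: one entry per image across the whole job
        p.extra_generation_params["Per-image note"] = ["first", "second"]
        # callable: resolved per image inside create_infotext via value(**locals())
        p.extra_generation_params["Lazy value"] = lambda p, index, **kwargs: str(index)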
+ def process_images(p: StableDiffusionProcessing) -> Processed:
+     if p.scripts is not None:
+         p.scripts.before_process(p)
+
+     stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
+
+     try:
+         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
+         # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here; if a model override is present it will be reloaded afterwards
+         if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+             p.override_settings.pop('sd_model_checkpoint', None)
+             sd_models.reload_model_weights()
+
+         for k, v in p.override_settings.items():
+             opts.set(k, v, is_api=True, run_callbacks=False)
+
+             if k == 'sd_model_checkpoint':
+                 sd_models.reload_model_weights()
+
+             if k == 'sd_vae':
+                 sd_vae.reload_vae_weights()
+
+         sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
+
+         # backwards compatibility, fix sampler and scheduler if invalid
+         sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
+
+         with profiling.Profiler():
+             res = process_images_inner(p)
+
+     finally:
+         sd_models.apply_token_merging(p.sd_model, 0)
+
+         # restore opts to original state
+         if p.override_settings_restore_afterwards:
+             for k, v in stored_opts.items():
+                 setattr(opts, k, v)
+
+                 if k == 'sd_vae':
+                     sd_vae.reload_vae_weights()
+
+     return res
+
+
+ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
+     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+
+     if isinstance(p.prompt, list):
+         assert len(p.prompt) > 0
+     else:
+         assert p.prompt is not None
+
+     devices.torch_gc()
+
+     seed = get_fixed_seed(p.seed)
+     subseed = get_fixed_seed(p.subseed)
+
+     if p.restore_faces is None:
+         p.restore_faces = opts.face_restoration
+
+     if p.tiling is None:
+         p.tiling = opts.tiling
+
+     if p.refiner_checkpoint not in (None, "", "None", "none"):
+         p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
+         if p.refiner_checkpoint_info is None:
+             raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
+
+     if hasattr(shared.sd_model, 'fix_dimensions'):
+         p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)
+
+     p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
+     p.sd_model_hash = shared.sd_model.sd_model_hash
+     p.sd_vae_name = sd_vae.get_loaded_vae_name()
+     p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
+
+     modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+     modules.sd_hijack.model_hijack.clear_comments()
+
+     p.fill_fields_from_opts()
+     p.setup_prompts()
+
+     if isinstance(seed, list):
+         p.all_seeds = seed
+     else:
+         p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
+
+     if isinstance(subseed, list):
+         p.all_subseeds = subseed
+     else:
+         p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
+
+     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
+         model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+     if p.scripts is not None:
+         p.scripts.process(p)
+
+     infotexts = []
+     output_images = []
+     with torch.no_grad(), p.sd_model.ema_scope():
+         with devices.autocast():
+             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+
+             # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
+             if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
+                 sd_vae_approx.model()
+
+             sd_unet.apply_unet()
+
+         if state.job_count == -1:
+             state.job_count = p.n_iter
+
+         for n in range(p.n_iter):
+             p.iteration = n
+
+             if state.skipped:
+                 state.skipped = False
+
+             if state.interrupted or state.stopping_generation:
+                 break
+
+             sd_models.reload_model_weights()  # model can be changed for example by refiner
+
+             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+             p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+
+             latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+             p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+
+             if p.scripts is not None:
+                 p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+
+             if len(p.prompts) == 0:
+                 break
+
+             p.parse_extra_network_prompts()
+
+             if not p.disable_extra_networks:
+                 with devices.autocast():
+                     extra_networks.activate(p, p.extra_network_data)
+
+             if p.scripts is not None:
+                 p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+
+             p.setup_conds()
+
+             p.extra_generation_params.update(model_hijack.extra_generation_params)
+
+             # params.txt should be saved after scripts.process_batch, since the
+             # infotext could be modified by that callback
+             # Example: a wildcard processed by process_batch sets an extra model
+             # strength, which is saved as "Model Strength: 1.0" in the infotext
+             if n == 0 and not cmd_opts.no_prompt_history:
+                 with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
+                     processed = Processed(p, [])
+                     file.write(processed.infotext(p, 0))
+
+             for comment in model_hijack.comments:
+                 p.comment(comment)
+
+             if p.n_iter > 1:
+                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+
+             sd_models.apply_alpha_schedule_override(p.sd_model, p)
+
+             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+
+             if p.scripts is not None:
+                 ps = scripts.PostSampleArgs(samples_ddim)
+                 p.scripts.post_sample(p, ps)
+                 samples_ddim = ps.samples
+
+             if getattr(samples_ddim, 'already_decoded', False):
+                 x_samples_ddim = samples_ddim
+             else:
+                 devices.test_for_nans(samples_ddim, "unet")
+
+                 if opts.sd_vae_decode_method != 'Full':
+                     p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+                 x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
+             x_samples_ddim = torch.stack(x_samples_ddim).float()
+             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+             del samples_ddim
+
+             if lowvram.is_enabled(shared.sd_model):
+                 lowvram.send_everything_to_cpu()
+
+             devices.torch_gc()
+
+             state.nextjob()
+
+             if p.scripts is not None:
+                 p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
+
+                 p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+                 p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+                 batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+                 p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+                 x_samples_ddim = batch_params.images
+
+             def infotext(index=0, use_main_prompt=False):
+                 return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+
+             save_samples = p.save_samples()
+
+             for i, x_sample in enumerate(x_samples_ddim):
+                 p.batch_index = i
+
+                 x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+                 x_sample = x_sample.astype(np.uint8)
+
+                 if p.restore_faces:
+                     if save_samples and opts.save_images_before_face_restoration:
+                         images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
+
+                     devices.torch_gc()
+
+                     x_sample = modules.face_restoration.restore_faces(x_sample)
+                     devices.torch_gc()
+
+                 image = Image.fromarray(x_sample)
+
+                 if p.scripts is not None:
+                     pp = scripts.PostprocessImageArgs(image)
+                     p.scripts.postprocess_image(p, pp)
+                     image = pp.image
+
+                 mask_for_overlay = getattr(p, "mask_for_overlay", None)
+
+                 if not shared.opts.overlay_inpaint:
+                     overlay_image = None
+                 elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
+                     overlay_image = p.overlay_images[i]
+                 else:
+                     overlay_image = None
+
+                 if p.scripts is not None:
+                     ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+                     p.scripts.postprocess_maskoverlay(p, ppmo)
+                     mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
+                 if p.color_corrections is not None and i < len(p.color_corrections):
+                     if save_samples and opts.save_images_before_color_correction:
+                         image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
+                         images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
+                     image = apply_color_correction(p.color_corrections[i], image)
+
+                 # If the intention is to show the output from the model
+                 # that is being composited over the original image,
+                 # we need to keep the original image around
+                 # and use it in the composite step.
+                 image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)
+
+                 if p.scripts is not None:
+                     pp = scripts.PostprocessImageArgs(image)
+                     p.scripts.postprocess_image_after_composite(p, pp)
+                     image = pp.image
+
+                 if save_samples:
+                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
+
+                 text = infotext(i)
+                 infotexts.append(text)
+                 if opts.enable_pnginfo:
+                     image.info["parameters"] = text
+                 output_images.append(image)
+
+                 if mask_for_overlay is not None:
+                     if opts.return_mask or opts.save_mask:
+                         image_mask = mask_for_overlay.convert('RGB')
+                         if save_samples and opts.save_mask:
+                             images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+                         if opts.return_mask:
+                             output_images.append(image_mask)
+
+                     if opts.return_mask_composite or opts.save_mask_composite:
+                         image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                         if save_samples and opts.save_mask_composite:
+                             images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+                         if opts.return_mask_composite:
+                             output_images.append(image_mask_composite)
+
+             del x_samples_ddim
+
+             devices.torch_gc()
+
+         if not infotexts:
+             infotexts.append(Processed(p, []).infotext(p, 0))
+
+         p.color_corrections = None
+
+         index_of_first_image = 0
+         unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
+         if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
+             grid = images.image_grid(output_images, p.batch_size)
+
+             if opts.return_grid:
+                 text = infotext(use_main_prompt=True)
+                 infotexts.insert(0, text)
+                 if opts.enable_pnginfo:
+                     grid.info["parameters"] = text
+                 output_images.insert(0, grid)
+                 index_of_first_image = 1
+             if opts.grid_save:
+                 images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+
+     if not p.disable_extra_networks and p.extra_network_data:
+         extra_networks.deactivate(p, p.extra_network_data)
+
+     devices.torch_gc()
+
+     res = Processed(
+         p,
+         images_list=output_images,
+         seed=p.all_seeds[0],
+         info=infotexts[0],
+         subseed=p.all_subseeds[0],
+         index_of_first_image=index_of_first_image,
+         infotexts=infotexts,
+     )
+
+     if p.scripts is not None:
+         p.scripts.postprocess(p, res)
+
+     return res
+
+
+ def old_hires_fix_first_pass_dimensions(width, height):
+     """old algorithm for auto-calculating first pass size"""
+
+     desired_pixel_count = 512 * 512
+     actual_pixel_count = width * height
+     scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+     width = math.ceil(scale * width / 64) * 64
+     height = math.ceil(scale * height / 64) * 64
+
+     return width, height
+
+
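A worked example of the sizing above (editorial aside): a 1024x1536 target is scaled toward the 512x512 pixel budget by sqrt(512*512 / (1024*1536)) ≈ 0.408, and each side is rounded up to a multiple of 64.

    assert old_hires_fix_first_pass_dimensions(1024, 1536) == (448, 640)
    assert old_hires_fix_first_pass_dimensions(512, 512) == (512, 512)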
+ @dataclass(repr=False)
+ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+     enable_hr: bool = False
+     denoising_strength: float = 0.75
+     firstphase_width: int = 0
+     firstphase_height: int = 0
+     hr_scale: float = 2.0
+     hr_upscaler: str = None
+     hr_second_pass_steps: int = 0
+     hr_resize_x: int = 0
+     hr_resize_y: int = 0
+     hr_checkpoint_name: str = None
+     hr_sampler_name: str = None
+     hr_scheduler: str = None
+     hr_prompt: str = ''
+     hr_negative_prompt: str = ''
+     force_task_id: str = None
+
+     cached_hr_uc = [None, None]
+     cached_hr_c = [None, None]
+
+     hr_checkpoint_info: dict = field(default=None, init=False)
+     hr_upscale_to_x: int = field(default=0, init=False)
+     hr_upscale_to_y: int = field(default=0, init=False)
+     truncate_x: int = field(default=0, init=False)
+     truncate_y: int = field(default=0, init=False)
+     applied_old_hires_behavior_to: tuple = field(default=None, init=False)
+     latent_scale_mode: dict = field(default=None, init=False)
+     hr_c: tuple | None = field(default=None, init=False)
+     hr_uc: tuple | None = field(default=None, init=False)
+     all_hr_prompts: list = field(default=None, init=False)
+     all_hr_negative_prompts: list = field(default=None, init=False)
+     hr_prompts: list = field(default=None, init=False)
+     hr_negative_prompts: list = field(default=None, init=False)
+     hr_extra_network_data: list = field(default=None, init=False)
+
+     def __post_init__(self):
+         super().__post_init__()
+
+         if self.firstphase_width != 0 or self.firstphase_height != 0:
+             self.hr_upscale_to_x = self.width
+             self.hr_upscale_to_y = self.height
+             self.width = self.firstphase_width
+             self.height = self.firstphase_height
+
+         self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
+         self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
+
+     def calculate_target_resolution(self):
+         if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
+             self.hr_resize_x = self.width
+             self.hr_resize_y = self.height
+             self.hr_upscale_to_x = self.width
+             self.hr_upscale_to_y = self.height
+
+             self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+             self.applied_old_hires_behavior_to = (self.width, self.height)
+
+         if self.hr_resize_x == 0 and self.hr_resize_y == 0:
+             self.extra_generation_params["Hires upscale"] = self.hr_scale
+             self.hr_upscale_to_x = int(self.width * self.hr_scale)
+             self.hr_upscale_to_y = int(self.height * self.hr_scale)
+         else:
+             self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
+
+             if self.hr_resize_y == 0:
+                 self.hr_upscale_to_x = self.hr_resize_x
+                 self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+             elif self.hr_resize_x == 0:
+                 self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+                 self.hr_upscale_to_y = self.hr_resize_y
+             else:
+                 target_w = self.hr_resize_x
+                 target_h = self.hr_resize_y
+                 src_ratio = self.width / self.height
+                 dst_ratio = self.hr_resize_x / self.hr_resize_y
+
+                 if src_ratio < dst_ratio:
+                     self.hr_upscale_to_x = self.hr_resize_x
+                     self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+                 else:
+                     self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+                     self.hr_upscale_to_y = self.hr_resize_y
+
+                 self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
+                 self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
+
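Worked numbers for the ratio-mismatch branch above (editorial aside, standalone arithmetic rather than a call into the class): a 512x512 first pass with hr_resize 1024x768 upscales to 1024x1024, and truncate_y records the 32 latent rows (256 px / 8) to crop afterwards.

    width, height, hr_resize_x, hr_resize_y = 512, 512, 1024, 768
    opt_f = 8  # latent downscale factor, as defined at the top of the file
    if width / height < hr_resize_x / hr_resize_y:           # width matches target
        up_x, up_y = hr_resize_x, hr_resize_x * height // width
    else:                                                    # height matches target
        up_x, up_y = hr_resize_y * width // height, hr_resize_y
    truncate_x = (up_x - hr_resize_x) // opt_f               # latent cols to crop
    truncate_y = (up_y - hr_resize_y) // opt_f               # latent rows to crop
    assert (up_x, up_y, truncate_x, truncate_y) == (1024, 1024, 0, 32)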
1252
+     def init(self, all_prompts, all_seeds, all_subseeds):
+         if self.enable_hr:
+             self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
+             if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
+                 self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+                 if self.hr_checkpoint_info is None:
+                     raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+                 self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
+             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
+                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
+
+             def get_hr_prompt(p, index, prompt_text, **kwargs):
+                 hr_prompt = p.all_hr_prompts[index]
+                 return hr_prompt if hr_prompt != prompt_text else None
+
+             def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):
+                 hr_negative_prompt = p.all_hr_negative_prompts[index]
+                 return hr_negative_prompt if hr_negative_prompt != negative_prompt else None
+
+             self.extra_generation_params["Hires prompt"] = get_hr_prompt
+             self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt
+
+             self.extra_generation_params["Hires schedule type"] = None  # to be set in sd_samplers_kdiffusion.py
+
+             if self.hr_scheduler is None:
+                 self.hr_scheduler = self.scheduler
+
+             self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+             if self.enable_hr and self.latent_scale_mode is None:
+                 if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+                     raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
+             self.calculate_target_resolution()
+
+             if not state.processing_has_refined_job_count:
+                 if state.job_count == -1:
+                     state.job_count = self.n_iter
+                 if getattr(self, 'txt2img_upscale', False):
+                     total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
+                 else:
+                     total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
+                 shared.total_tqdm.updateTotal(total_steps)
+                 state.job_count = state.job_count * 2
+                 state.processing_has_refined_job_count = True
+
+             if self.hr_second_pass_steps:
+                 self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
+
+             if self.hr_upscaler is not None:
+                 self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
+
+     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+
+         if self.firstpass_image is not None and self.enable_hr:
+             # here we don't need to generate an image; we just take self.firstpass_image and prepare it for hires fix
+
+             if self.latent_scale_mode is None:
+                 image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
+                 image = np.moveaxis(image, 2, 0)
+
+                 samples = None
+                 decoded_samples = torch.asarray(np.expand_dims(image, 0))
+
+             else:
+                 image = np.array(self.firstpass_image).astype(np.float32) / 255.0
+                 image = np.moveaxis(image, 2, 0)
+                 image = torch.from_numpy(np.expand_dims(image, axis=0))
+                 image = image.to(shared.device, dtype=devices.dtype_vae)
+
+                 if opts.sd_vae_encode_method != 'Full':
+                     self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+                 samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+                 decoded_samples = None
+                 devices.torch_gc()
+
+         else:
+             # here we generate an image normally
+
+             x = self.rng.next()
+             if self.scripts is not None:
+                 self.scripts.process_before_every_sampling(
+                     p=self,
+                     x=x,
+                     noise=x,
+                     c=conditioning,
+                     uc=unconditional_conditioning
+                 )
+
+             samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+             del x
+
+             if not self.enable_hr:
+                 return samples
+
+             devices.torch_gc()
+
+             if self.latent_scale_mode is None:
+                 decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+             else:
+                 decoded_samples = None
+
+         with sd_models.SkipWritingToConfig():
+             sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+
+         return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+
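+     # The progressive-growing pass below samples at a reduced resolution first, then
+     # repeatedly upscales the latent, optionally refining each step with a short
+     # img2img pass. Example: progressive_growing_min_scale=0.5, max_scale=1.0 and
+     # progressive_growing_steps=3 on a 512x512 target give scales [0.5, 0.75, 1.0],
+     # i.e. 256x256 -> 384x384 -> 512x512 in pixels (latent 32 -> 48 -> 64 with opt_f == 8).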
+     def sample_progressive(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+         is_sdxl = getattr(self.sd_model, 'is_sdxl', False)
+
+         if is_sdxl:
+             min_scale = max(0.5, self.progressive_growing_min_scale)
+         else:
+             min_scale = self.progressive_growing_min_scale
+
+         resolution_steps = np.linspace(min_scale, self.progressive_growing_max_scale, self.progressive_growing_steps)
+
+         initial_width = max(512 if is_sdxl else 64, int(self.width * resolution_steps[0]))
+         initial_height = max(512 if is_sdxl else 64, int(self.height * resolution_steps[0]))
+
+         x = create_random_tensors((opt_C, initial_height // opt_f, initial_width // opt_f), seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+
+         for i in range(1, len(resolution_steps)):
+             target_width = int(self.width * resolution_steps[i])
+             target_height = int(self.height * resolution_steps[i])
+
+             if is_sdxl:
+                 target_width = max(512, min(1536, target_width))
+                 target_height = max(512, min(1536, target_height))
+
+             samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode='bicubic', align_corners=False)
+
+             if self.progressive_growing_refinement:
+                 steps_for_refinement = self.steps // len(resolution_steps)
+                 noise = create_random_tensors(samples.shape[1:], seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+                 decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+                 decoded_samples = torch.stack(decoded_samples).float()
+                 decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                 self.image_conditioning = self.img2img_image_conditioning(decoded_samples * 2 - 1, samples)
+
+                 samples = self.sampler.sample_img2img(
+                     self,
+                     samples,
+                     noise,
+                     conditioning,
+                     unconditional_conditioning,
+                     steps=steps_for_refinement,
+                     image_conditioning=self.image_conditioning
+                 )
+
+         return samples
+
+     def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
+         if shared.state.interrupted:
+             return samples
+
+         self.is_hr_pass = True
+         target_width = self.hr_upscale_to_x
+         target_height = self.hr_upscale_to_y
+
+         def save_intermediate(image, index):
+             """saves the image before applying hires fix, if enabled in options; takes as an argument either an image or a batch of latent space images"""
+
+             if not self.save_samples() or not opts.save_images_before_highres_fix:
+                 return
+
+             if not isinstance(image, Image.Image):
+                 image = sd_samplers.sample_to_image(image, index, approximation=0)
+
+             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+             images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
+
+         img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+         self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+         if self.latent_scale_mode is not None:
+             for i in range(samples.shape[0]):
+                 save_intermediate(samples, i)
+
+             samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
+
+             # Avoid making the inpainting conditioning unless necessary as
+             # this does need some extra compute to decode / encode the image again.
+             if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+                 image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+             else:
+                 image_conditioning = self.txt2img_image_conditioning(samples)
+         else:
+             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+             batch_images = []
+             for i, x_sample in enumerate(lowres_samples):
+                 x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+                 x_sample = x_sample.astype(np.uint8)
+                 image = Image.fromarray(x_sample)
+
+                 save_intermediate(image, i)
+
+                 image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
+                 image = np.array(image).astype(np.float32) / 255.0
+                 image = np.moveaxis(image, 2, 0)
+                 batch_images.append(image)
+
+             decoded_samples = torch.from_numpy(np.array(batch_images))
+             decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
+
+             if opts.sd_vae_encode_method != 'Full':
+                 self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+             samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
+
+             image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
+         shared.state.nextjob()
+
+         # crop the upscaled latent symmetrically (centered) down to the exact target size
+         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
+
+         self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
+         noise = self.rng.next()
+
+         # GC now before running the next img2img to prevent running out of memory
+         devices.torch_gc()
+
+         if not self.disable_extra_networks:
+             with devices.autocast():
+                 extra_networks.activate(self, self.hr_extra_network_data)
+
+         with devices.autocast():
+             self.calculate_hr_conds()
+
+         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+
+         if self.scripts is not None:
+             self.scripts.before_hr(self)
+             self.scripts.process_before_every_sampling(
+                 p=self,
+                 x=samples,
+                 noise=noise,
+                 c=self.hr_c,
+                 uc=self.hr_uc,
+             )
+
+         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+
+         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+
+         self.sampler = None
+         devices.torch_gc()
+
+         decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
+         self.is_hr_pass = False
+         return decoded_samples
+
+     def close(self):
+         super().close()
+         self.hr_c = None
+         self.hr_uc = None
+         if not opts.persistent_cond_cache:
+             StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
+             StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
+
+     def setup_prompts(self):
+         super().setup_prompts()
+
+         if not self.enable_hr:
+             return
+
+         if self.hr_prompt == '':
+             self.hr_prompt = self.prompt
+
+         if self.hr_negative_prompt == '':
+             self.hr_negative_prompt = self.negative_prompt
+
+         if isinstance(self.hr_prompt, list):
+             self.all_hr_prompts = self.hr_prompt
+         else:
+             self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
+
+         if isinstance(self.hr_negative_prompt, list):
+             self.all_hr_negative_prompts = self.hr_negative_prompt
+         else:
+             self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
+
+         self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
+         self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+
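+     # calculate_hr_conds below hands get_conds_with_caching two caches per cond
+     # ([self.cached_hr_uc, self.cached_uc] and [self.cached_hr_c, self.cached_c]), so
+     # when the hires prompts and step schedule match the firstpass ones, the firstpass
+     # conds can be reused instead of being recomputed.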
+     def calculate_hr_conds(self):
+         if self.hr_c is not None:
+             return
+
+         hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+         hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+         sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
+         steps = self.hr_second_pass_steps or self.steps
+         total_steps = sampler_config.total_steps(steps) if sampler_config else steps
+
+         self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
+         self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
+
+     def setup_conds(self):
+         if self.is_hr_pass:
+             # if we are in the hr pass right now, the call is being made from the refiner, and we don't need to set up firstpass conds or switch the model
+             self.hr_c = None
+             self.calculate_hr_conds()
+             return
+
+         super().setup_conds()
+
+         self.hr_uc = None
+         self.hr_c = None
+
+         if self.enable_hr and self.hr_checkpoint_info is None:
+             if shared.opts.hires_fix_use_firstpass_conds:
+                 self.calculate_hr_conds()
+
+             elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+                 with devices.autocast():
+                     extra_networks.activate(self, self.hr_extra_network_data)
+
+                 self.calculate_hr_conds()
+
+                 with devices.autocast():
+                     extra_networks.activate(self, self.extra_network_data)
+
+     def get_conds(self):
+         if self.is_hr_pass:
+             return self.hr_c, self.hr_uc
+
+         return super().get_conds()
+
+     def parse_extra_network_prompts(self):
+         res = super().parse_extra_network_prompts()
+
+         if self.enable_hr:
+             self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+             self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+
+             self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
+
+         return res
+
+
+ @dataclass(repr=False)
+ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
+     init_images: list = None
+     resize_mode: int = 0
+     denoising_strength: float = 0.75
+     image_cfg_scale: float = None
+     mask: Any = None
+     mask_blur_x: int = 4
+     mask_blur_y: int = 4
+     mask_blur: int = None
+     mask_round: bool = True
+     inpainting_fill: int = 0
+     inpaint_full_res: bool = True
+     inpaint_full_res_padding: int = 0
+     inpainting_mask_invert: int = 0
+     initial_noise_multiplier: float = None
+     latent_mask: Image = None
+     force_task_id: str = None
+
+     image_mask: Any = field(default=None, init=False)
+
+     nmask: torch.Tensor = field(default=None, init=False)
+     image_conditioning: torch.Tensor = field(default=None, init=False)
+     init_img_hash: str = field(default=None, init=False)
+     mask_for_overlay: Image = field(default=None, init=False)
+     init_latent: torch.Tensor = field(default=None, init=False)
+
+     def __post_init__(self):
+         super().__post_init__()
+
+         self.image_mask = self.mask
+         self.mask = None
+         self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
+
+     # mask_blur is a convenience alias over the per-axis blur values: reading it
+     # returns the shared value (or None if the axes differ), and assigning an int
+     # sets both axes at once.
+     @property
+     def mask_blur(self):
+         if self.mask_blur_x == self.mask_blur_y:
+             return self.mask_blur_x
+         return None
+
+     @mask_blur.setter
+     def mask_blur(self, value):
+         if isinstance(value, int):
+             self.mask_blur_x = value
+             self.mask_blur_y = value
+
+     def init(self, all_prompts, all_seeds, all_subseeds):
+         self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
+         self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
+
+         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+         crop_region = None
+
+         image_mask = self.image_mask
+
+         if image_mask is not None:
+             # image_mask is passed in as RGBA by Gradio to support alpha masks,
+             # but we still want to support binary masks.
+             image_mask = create_binary_mask(image_mask, round=self.mask_round)
+
+             if self.inpainting_mask_invert:
+                 image_mask = ImageOps.invert(image_mask)
+                 self.extra_generation_params["Mask mode"] = "Inpaint not masked"
+
+             if self.mask_blur_x > 0:
+                 np_mask = np.array(image_mask)
+                 # kernel size is forced odd and spans roughly +/- 2.5 sigma of the Gaussian
+                 kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
+                 np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
+                 image_mask = Image.fromarray(np_mask)
+
+             if self.mask_blur_y > 0:
+                 np_mask = np.array(image_mask)
+                 kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
+                 np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
+                 image_mask = Image.fromarray(np_mask)
+
+             if self.mask_blur_x > 0 or self.mask_blur_y > 0:
+                 self.extra_generation_params["Mask blur"] = self.mask_blur
+
+             if self.inpaint_full_res:
+                 self.mask_for_overlay = image_mask
+                 mask = image_mask.convert('L')
+                 crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding)
+                 if crop_region:
+                     crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
+                     x1, y1, x2, y2 = crop_region
+                     mask = mask.crop(crop_region)
+                     image_mask = images.resize_image(2, mask, self.width, self.height)
+                     self.paste_to = (x1, y1, x2-x1, y2-y1)
+                     self.extra_generation_params["Inpaint area"] = "Only masked"
+                     self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
+                 else:
+                     crop_region = None
+                     image_mask = None
+                     self.mask_for_overlay = None
+                     self.inpaint_full_res = False
+                     message = 'Unable to perform "Inpaint Only masked" because the mask is blank; switching to img2img mode.'
+                     model_hijack.comments.append(message)
+                     logging.info(message)
+             else:
+                 image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+                 np_mask = np.array(image_mask)
+                 np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+                 self.mask_for_overlay = Image.fromarray(np_mask)
+
+             self.overlay_images = []
+
+         latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
+
+         add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
+         if add_color_corrections:
+             self.color_corrections = []
+         imgs = []
+         for img in self.init_images:
+
+             # Save init image
+             if opts.save_init_img:
+                 self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
+                 images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
+
+             image = images.flatten(img, opts.img2img_background_color)
+
+             if crop_region is None and self.resize_mode != 3:
+                 image = images.resize_image(self.resize_mode, image, self.width, self.height)
+
+             if image_mask is not None:
+                 if self.mask_for_overlay.size != (image.width, image.height):
+                     self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height)
+                 image_masked = Image.new('RGBa', (image.width, image.height))
+                 image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+
+                 self.overlay_images.append(image_masked.convert('RGBA'))
+
+             # crop_region is not None if we are doing inpaint full res
+             if crop_region is not None:
+                 image = image.crop(crop_region)
+                 image = images.resize_image(2, image, self.width, self.height)
+
+             if image_mask is not None:
+                 if self.inpainting_fill != 1:
+                     image = masking.fill(image, latent_mask)
+
+                     if self.inpainting_fill == 0:
+                         self.extra_generation_params["Masked content"] = 'fill'
+
+             if add_color_corrections:
+                 self.color_corrections.append(setup_color_correction(image))
+
+             image = np.array(image).astype(np.float32) / 255.0
+             image = np.moveaxis(image, 2, 0)
+
+             imgs.append(image)
+
+         if len(imgs) == 1:
+             batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+             if self.overlay_images is not None:
+                 self.overlay_images = self.overlay_images * self.batch_size
+
+             if self.color_corrections is not None and len(self.color_corrections) == 1:
+                 self.color_corrections = self.color_corrections * self.batch_size
+
+         elif len(imgs) <= self.batch_size:
+             self.batch_size = len(imgs)
+             batch_images = np.array(imgs)
+         else:
+             raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
+
+         image = torch.from_numpy(batch_images)
+         image = image.to(shared.device, dtype=devices.dtype_vae)
+
+         if opts.sd_vae_encode_method != 'Full':
+             self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+         self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+         devices.torch_gc()
+
+         if self.resize_mode == 3:
+             self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
+         if image_mask is not None:
+             init_mask = latent_mask
+             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
+             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
+             latmask = latmask[0]
+             if self.mask_round:
+                 latmask = np.around(latmask)
+             latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))
+
+             self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
+             self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)
+
+             # this needs to be fixed to be done in sample() using actual seeds for batches
+             if self.inpainting_fill == 2:
+                 self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+                 self.extra_generation_params["Masked content"] = 'latent noise'
+
+             elif self.inpainting_fill == 3:
+                 self.init_latent = self.init_latent * self.mask
+                 self.extra_generation_params["Masked content"] = 'latent nothing'
+
+         self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
+
+     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+         x = self.rng.next()
+
+         if self.initial_noise_multiplier != 1.0:
+             self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
+             x *= self.initial_noise_multiplier
+
+         if self.scripts is not None:
+             self.scripts.process_before_every_sampling(
+                 p=self,
+                 x=self.init_latent,
+                 noise=x,
+                 c=conditioning,
+                 uc=unconditional_conditioning
+             )
+         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
+
+         if self.mask is not None:
+             blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+             if self.scripts is not None:
+                 mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
+                 self.scripts.on_mask_blend(self, mba)
+                 blended_samples = mba.blended_latent
+
+             samples = blended_samples
+
+         del x
+         devices.torch_gc()
+
+         return samples
+
+     def get_token_merging_ratio(self, for_hr=False):
+         return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
hm/processing.py ADDED
@@ -0,0 +1,1838 @@
+ from __future__ import annotations
+ import json
+ import logging
+ import math
+ import os
+ import sys
+ import hashlib
+ from dataclasses import dataclass, field
+
+ import torch
+ import numpy as np
+ from PIL import Image, ImageOps
+ import random
+ import cv2
+ from skimage import exposure
+ from typing import Any
+
+ import modules.sd_hijack
+ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
+ from modules.rng import slerp  # noqa: F401
+ from modules.sd_hijack import model_hijack
+ from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
+ from modules.shared import opts, cmd_opts, state
+ import modules.shared as shared
+ import modules.paths as paths
+ import modules.face_restoration
+ import modules.images as images
+ import modules.styles
+ import modules.sd_models as sd_models
+ import modules.sd_vae as sd_vae
+ from ldm.data.util import AddMiDaS
+ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+
+ from einops import repeat, rearrange
+ from blendmodes.blend import blendLayers, BlendType
+
+
+ # some of those options should not be changed at all because they would break the model, so I removed them from options.
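+ # opt_C is the latent channel count and opt_f the VAE downscale factor: a 512x512
+ # RGB image corresponds to a 4 x 64 x 64 latent (512 // opt_f == 64).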
+ opt_C = 4
+ opt_f = 8
+
+
+ def setup_color_correction(image):
+     logging.info("Calibrating color correction.")
+     correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
+     return correction_target
+
+
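+ # apply_color_correction below matches the LAB histogram of the result against the
+ # target captured by setup_color_correction; the LUMINOSITY blend then re-applies the
+ # original image's luminance (per blendmodes' base-layer-first argument order), so the
+ # correction mainly shifts color rather than brightness.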
+ def apply_color_correction(correction, original_image):
+     logging.info("Applying color correction.")
+     image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
+         cv2.cvtColor(
+             np.asarray(original_image),
+             cv2.COLOR_RGB2LAB
+         ),
+         correction,
+         channel_axis=2
+     ), cv2.COLOR_LAB2RGB).astype("uint8"))
+
+     image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+
+     return image.convert('RGB')
+
+
+ def uncrop(image, dest_size, paste_loc):
+     x, y, w, h = paste_loc
+     base_image = Image.new('RGBA', dest_size)
+     image = images.resize_image(1, image, w, h)
+     base_image.paste(image, (x, y))
+     image = base_image
+
+     return image
+
+
+ def apply_overlay(image, paste_loc, overlay):
+     if overlay is None:
+         return image, image.copy()
+
+     if paste_loc is not None:
+         image = uncrop(image, (overlay.width, overlay.height), paste_loc)
+
+     original_denoised_image = image.copy()
+
+     image = image.convert('RGBA')
+     image.alpha_composite(overlay)
+     image = image.convert('RGB')
+
+     return image, original_denoised_image
+
+ def create_binary_mask(image, round=True):
+     # an RGBA image whose alpha channel is not fully opaque carries the actual mask
+     # in its alpha channel; otherwise fall back to a plain grayscale conversion
+     if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
+         if round:
+             image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+         else:
+             image = image.split()[-1].convert("L")
+     else:
+         image = image.convert('L')
+     return image
+
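+ # txt2img_image_conditioning below builds the extra image conditioning that some model
+ # types expect even in pure txt2img: inpainting-style models get a neutral all-0.5
+ # "masked image" plus an all-ones mask, unCLIP models get zeroed adm conditioning, and
+ # everything else gets a tiny dummy tensor.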
+ def txt2img_image_conditioning(sd_model, x, width, height):
+     if sd_model.model.conditioning_key in {'hybrid', 'concat'}:  # Inpainting models
+
+         # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+         image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+         image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
+
+         # Add the fake full 1s mask to the first dimension.
+         image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+         image_conditioning = image_conditioning.to(x.dtype)
+
+         return image_conditioning
+
+     elif sd_model.model.conditioning_key == "crossattn-adm":  # UnCLIP models
+
+         return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
+
+     else:
+         if sd_model.is_sdxl_inpaint:
+             # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+             image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+             image_conditioning = images_tensor_to_samples(image_conditioning,
+                                                           approximation_indexes.get(opts.sd_vae_encode_method))
+
+             # Add the fake full 1s mask to the first dimension.
+             image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+             image_conditioning = image_conditioning.to(x.dtype)
+
+             return image_conditioning
+
+         # Dummy zero conditioning if we're not using inpainting or unclip models.
+         # Still takes up a bit of memory, but no encoder call.
+         # Pretty sure we can just make this a 1x1 image since it's not going to be used for anything besides its batch size.
+         return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+
+ @dataclass(repr=False)
+ class StableDiffusionProcessing:
+     sd_model: object = None
+     outpath_samples: str = None
+     outpath_grids: str = None
+     prompt: str = ""
+     prompt_for_display: str = None
+     negative_prompt: str = ""
+     styles: list[str] = None
+     seed: int = -1
+     subseed: int = -1
+     subseed_strength: float = 0
+     seed_resize_from_h: int = -1
+     seed_resize_from_w: int = -1
+     seed_enable_extras: bool = True
+     sampler_name: str = None
+     scheduler: str = None
+     batch_size: int = 1
+     n_iter: int = 1
+     steps: int = 50
+     cfg_scale: float = 7.0
+     width: int = 512
+     height: int = 512
+     restore_faces: bool = None
+     tiling: bool = None
+     do_not_save_samples: bool = False
+     do_not_save_grid: bool = False
+     extra_generation_params: dict[str, Any] = None
+     overlay_images: list = None
+     eta: float = None
+     do_not_reload_embeddings: bool = False
+     denoising_strength: float = None
+     ddim_discretize: str = None
+     s_min_uncond: float = None
+     s_churn: float = None
+     s_tmax: float = None
+     s_tmin: float = None
+     s_noise: float = None
+     override_settings: dict[str, Any] = None
+     override_settings_restore_afterwards: bool = True
+     sampler_index: int = None
+     refiner_checkpoint: str = None
+     refiner_switch_at: float = None
+     token_merging_ratio = 0
+     token_merging_ratio_hr = 0
+     disable_extra_networks: bool = False
+     firstpass_image: Image = None
+
+     scripts_value: scripts.ScriptRunner = field(default=None, init=False)
+     script_args_value: list = field(default=None, init=False)
+     scripts_setup_complete: bool = field(default=False, init=False)
+
+     cached_uc = [None, None]
+     cached_c = [None, None]
+
+     comments: dict = None
+     sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
+     is_using_inpainting_conditioning: bool = field(default=False, init=False)
+     paste_to: tuple | None = field(default=None, init=False)
+
+     is_hr_pass: bool = field(default=False, init=False)
+
+     c: tuple = field(default=None, init=False)
+     uc: tuple = field(default=None, init=False)
+
+     rng: rng.ImageRNG | None = field(default=None, init=False)
+     step_multiplier: int = field(default=1, init=False)
+     color_corrections: list = field(default=None, init=False)
+
+     all_prompts: list = field(default=None, init=False)
+     all_negative_prompts: list = field(default=None, init=False)
+     all_seeds: list = field(default=None, init=False)
+     all_subseeds: list = field(default=None, init=False)
+     iteration: int = field(default=0, init=False)
+     main_prompt: str = field(default=None, init=False)
+     main_negative_prompt: str = field(default=None, init=False)
+
+     prompts: list = field(default=None, init=False)
+     negative_prompts: list = field(default=None, init=False)
+     seeds: list = field(default=None, init=False)
+     subseeds: list = field(default=None, init=False)
+     extra_network_data: dict = field(default=None, init=False)
+
+     user: str = field(default=None, init=False)
+
+     sd_model_name: str = field(default=None, init=False)
+     sd_model_hash: str = field(default=None, init=False)
+     sd_vae_name: str = field(default=None, init=False)
+     sd_vae_hash: str = field(default=None, init=False)
+
+     is_api: bool = field(default=False, init=False)
+
+     def __post_init__(self):
+         if self.sampler_index is not None:
+             print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
+
+         self.comments = {}
+
+         if self.styles is None:
+             self.styles = []
+
+         self.sampler_noise_scheduler_override = None
+
+         self.extra_generation_params = self.extra_generation_params or {}
+         self.override_settings = self.override_settings or {}
+         self.script_args = self.script_args or {}
+
+         self.refiner_checkpoint_info = None
+
+         if not self.seed_enable_extras:
+             self.subseed = -1
+             self.subseed_strength = 0
+             self.seed_resize_from_h = 0
+             self.seed_resize_from_w = 0
+
+         self.cached_uc = StableDiffusionProcessing.cached_uc
+         self.cached_c = StableDiffusionProcessing.cached_c
+
+     def fill_fields_from_opts(self):
+         self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
+         self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
+         self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
+         self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
+         self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
+
+     @property
+     def sd_model(self):
+         return shared.sd_model
+
+     @sd_model.setter
+     def sd_model(self, value):
+         pass
+
+     @property
+     def scripts(self):
+         return self.scripts_value
+
+     @scripts.setter
+     def scripts(self, value):
+         self.scripts_value = value
+
+         if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+             self.setup_scripts()
+
+     @property
+     def script_args(self):
+         return self.script_args_value
+
+     @script_args.setter
+     def script_args(self, value):
+         self.script_args_value = value
+
+         if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+             self.setup_scripts()
+
+     def setup_scripts(self):
+         self.scripts_setup_complete = True
+
+         self.scripts.setup_scrips(self, is_ui=not self.is_api)
+
+     def comment(self, text):
+         self.comments[text] = 1
+
+     def txt2img_image_conditioning(self, x, width=None, height=None):
+         self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+
+         return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
+
+     def depth2img_image_conditioning(self, source_image):
+         # Use the AddMiDaS helper to format our source image to suit the MiDaS model
+         transformer = AddMiDaS(model_type="dpt_hybrid")
+         transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+
+         conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
+         conditioning = torch.nn.functional.interpolate(
+             self.sd_model.depth_model(midas_in),
+             size=conditioning_image.shape[2:],
+             mode="bicubic",
+             align_corners=False,
+         )
+
+         (depth_min, depth_max) = torch.aminmax(conditioning)
+         conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+         return conditioning
+
+     def edit_image_conditioning(self, source_image):
+         conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
+
+         return conditioning_image
+
+     def unclip_image_conditioning(self, source_image):
+         c_adm = self.sd_model.embedder(source_image)
+         if self.sd_model.noise_augmentor is not None:
+             noise_level = 0  # TODO: Allow other noise levels?
+             c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
+             c_adm = torch.cat((c_adm, noise_level_emb), 1)
+         return c_adm
+
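+     # inpainting_image_conditioning below lerps between the unmasked source image and
+     # a fully-masked copy using opts.inpainting_mask_weight: at 1.0 the masked region
+     # of the conditioning image is blanked out entirely; at 0.0 the model is shown the
+     # original pixels under the mask.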
+     def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
+         self.is_using_inpainting_conditioning = True
+
+         # Handle the different mask inputs
+         if image_mask is not None:
+             if torch.is_tensor(image_mask):
+                 conditioning_mask = image_mask
+             else:
+                 conditioning_mask = np.array(image_mask.convert("L"))
+                 conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+                 conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+                 if round_image_mask:
+                     # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+                     conditioning_mask = torch.round(conditioning_mask)
+
+         else:
+             conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
+
+         # Create another latent image, this time with a masked version of the original input.
+         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+         conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
+         conditioning_image = torch.lerp(
+             source_image,
+             source_image * (1.0 - conditioning_mask),
+             getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+         )
+
+         # Encode the new masked image using first stage of network.
+         conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+         # Create the concatenated conditioning tensor to be fed to `c_concat`
+         conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+         conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+         image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+         image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+         return image_conditioning
+
+     def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
+         source_image = devices.cond_cast_float(source_image)
+
+         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+         # identify itself with a field common to all models. The conditioning_key is also hybrid.
+         if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+             return self.depth2img_image_conditioning(source_image)
+
+         if self.sd_model.cond_stage_key == "edit":
+             return self.edit_image_conditioning(source_image)
+
+         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
+
+         if self.sampler.conditioning_key == "crossattn-adm":
+             return self.unclip_image_conditioning(source_image)
+
+         if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
+             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
+         # Dummy zero conditioning if we're not using inpainting or depth model.
+         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
+     def init(self, all_prompts, all_seeds, all_subseeds):
+         pass
+
+     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+         raise NotImplementedError()
+
+     def close(self):
+         self.sampler = None
+         self.c = None
+         self.uc = None
+         if not opts.persistent_cond_cache:
+             StableDiffusionProcessing.cached_c = [None, None]
+             StableDiffusionProcessing.cached_uc = [None, None]
+
+     def get_token_merging_ratio(self, for_hr=False):
+         if for_hr:
+             return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
+
+         return self.token_merging_ratio or opts.token_merging_ratio
+
+     def setup_prompts(self):
+         if isinstance(self.prompt, list):
+             self.all_prompts = self.prompt
+         elif isinstance(self.negative_prompt, list):
+             self.all_prompts = [self.prompt] * len(self.negative_prompt)
+         else:
+             self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
+
+         if isinstance(self.negative_prompt, list):
+             self.all_negative_prompts = self.negative_prompt
+         else:
+             self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)
+
+         if len(self.all_prompts) != len(self.all_negative_prompts):
+             raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")
+
+         self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
+         self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+
+         self.main_prompt = self.all_prompts[0]
+         self.main_negative_prompt = self.all_negative_prompts[0]
+
+     def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
+         """Returns parameters that invalidate the cond cache if changed"""
+
+         return (
+             required_prompts,
+             steps,
+             hires_steps,
+             use_old_scheduling,
+             opts.CLIP_stop_at_last_layers,
+             shared.sd_model.sd_checkpoint_info,
+             extra_network_data,
+             opts.sdxl_crop_left,
+             opts.sdxl_crop_top,
+             self.width,
+             self.height,
+             opts.fp8_storage,
+             opts.cache_fp16_weight,
+             opts.emphasis,
+         )
+
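+     # Each cache handled below is a two-element list: cache[0] is the tuple produced
+     # by cached_params() (or None before first use) and cache[1] is the computed conds.
+     # Fresh results are always written into caches[0]; any later caches in the list act
+     # as read-only fallbacks.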
+     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
+         """
+         Returns the result of calling function(shared.sd_model, required_prompts, steps)
+         using a cache to store the result if the same arguments have been used before.
+
+         cache is an array containing two elements. The first element is a tuple
+         representing the previously used arguments, or None if no arguments
+         have been used before. The second element is where the previously
+         computed result is stored.
+
+         caches is a list with items described above.
+         """
+
+         if shared.opts.use_old_scheduling:
+             old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
+             new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
+             if old_schedules != new_schedules:
+                 self.extra_generation_params["Old prompt editing timelines"] = True
+
+         cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
+
+         for cache in caches:
+             if cache[0] is not None and cached_params == cache[0]:
+                 return cache[1]
+
+         cache = caches[0]
+
+         with devices.autocast():
+             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
+
+         cache[0] = cached_params
+         return cache[1]
+
+     def setup_conds(self):
+         prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+         negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+
+         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
+         total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
+         self.step_multiplier = total_steps // self.steps
+         self.firstpass_steps = total_steps
+
+         self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
+         self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
+
+     def get_conds(self):
+         return self.c, self.uc
+
+     def parse_extra_network_prompts(self):
+         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
+
+     def save_samples(self) -> bool:
+         """Returns whether generated images need to be written to disk"""
+         return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
+
+
+ class Processed:
+     def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
+         self.images = images_list
+         self.prompt = p.prompt
+         self.negative_prompt = p.negative_prompt
+         self.seed = seed
+         self.subseed = subseed
+         self.subseed_strength = p.subseed_strength
+         self.info = info
+         self.comments = "".join(f"{comment}\n" for comment in p.comments)
+         self.width = p.width
+         self.height = p.height
+         self.sampler_name = p.sampler_name
+         self.cfg_scale = p.cfg_scale
+         self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+         self.steps = p.steps
+         self.batch_size = p.batch_size
+         self.restore_faces = p.restore_faces
+         self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
+         self.sd_model_name = p.sd_model_name
+         self.sd_model_hash = p.sd_model_hash
+         self.sd_vae_name = p.sd_vae_name
+         self.sd_vae_hash = p.sd_vae_hash
+         self.seed_resize_from_w = p.seed_resize_from_w
+         self.seed_resize_from_h = p.seed_resize_from_h
+         self.denoising_strength = getattr(p, 'denoising_strength', None)
+         self.extra_generation_params = p.extra_generation_params
+         self.index_of_first_image = index_of_first_image
+         self.styles = p.styles
+         self.job_timestamp = state.job_timestamp
+         self.clip_skip = opts.CLIP_stop_at_last_layers
+         self.token_merging_ratio = p.token_merging_ratio
+         self.token_merging_ratio_hr = p.token_merging_ratio_hr
+
+         self.eta = p.eta
+         self.ddim_discretize = p.ddim_discretize
+         self.s_churn = p.s_churn
+         self.s_tmin = p.s_tmin
+         self.s_tmax = p.s_tmax
+         self.s_noise = p.s_noise
+         self.s_min_uncond = p.s_min_uncond
+         self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
+         self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
+         self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
+         self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
+         self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
+         self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
+
+         self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
+         self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
+         self.all_seeds = all_seeds or p.all_seeds or [self.seed]
+         self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
+         self.infotexts = infotexts or [info] * len(images_list)
+         self.version = program_version()
+
+     def js(self):
+         obj = {
+             "prompt": self.all_prompts[0],
+             "all_prompts": self.all_prompts,
+             "negative_prompt": self.all_negative_prompts[0],
+             "all_negative_prompts": self.all_negative_prompts,
+             "seed": self.seed,
+             "all_seeds": self.all_seeds,
+             "subseed": self.subseed,
+             "all_subseeds": self.all_subseeds,
+             "subseed_strength": self.subseed_strength,
+             "width": self.width,
+             "height": self.height,
+             "sampler_name": self.sampler_name,
+             "cfg_scale": self.cfg_scale,
+             "steps": self.steps,
+             "batch_size": self.batch_size,
+             "restore_faces": self.restore_faces,
+             "face_restoration_model": self.face_restoration_model,
+             "sd_model_name": self.sd_model_name,
+             "sd_model_hash": self.sd_model_hash,
+             "sd_vae_name": self.sd_vae_name,
+             "sd_vae_hash": self.sd_vae_hash,
+             "seed_resize_from_w": self.seed_resize_from_w,
+             "seed_resize_from_h": self.seed_resize_from_h,
+             "denoising_strength": self.denoising_strength,
+             "extra_generation_params": self.extra_generation_params,
+             "index_of_first_image": self.index_of_first_image,
+             "infotexts": self.infotexts,
+             "styles": self.styles,
+             "job_timestamp": self.job_timestamp,
+             "clip_skip": self.clip_skip,
+             "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
+             "version": self.version,
+         }
+
+         return json.dumps(obj, default=lambda o: None)
+
+     def infotext(self, p: StableDiffusionProcessing, index):
+         return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+
+     def get_token_merging_ratio(self, for_hr=False):
+         return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
+
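+ # create_random_tensors below is a thin wrapper over rng.ImageRNG; with
+ # subseed_strength > 0 the noise derived from `seeds` is blended with noise derived
+ # from `subseeds` (presumably via spherical interpolation -- cf. the slerp import above).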
+ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+     g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
+     return g.next()
+
+
+ class DecodedSamples(list):
622
+ already_decoded = True
623
+
624
+
625
+ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
626
+ samples = DecodedSamples()
627
+
628
+ if check_for_nans:
629
+ devices.test_for_nans(batch, "unet")
630
+
631
+ for i in range(batch.shape[0]):
632
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
633
+
634
+ if check_for_nans:
635
+
636
+ try:
637
+ devices.test_for_nans(sample, "vae")
638
+ except devices.NansException as e:
639
+ if shared.opts.auto_vae_precision_bfloat16:
640
+ autofix_dtype = torch.bfloat16
641
+ autofix_dtype_text = "bfloat16"
642
+ autofix_dtype_setting = "Automatically convert VAE to bfloat16"
643
+ autofix_dtype_comment = ""
644
+ elif shared.opts.auto_vae_precision:
645
+ autofix_dtype = torch.float32
646
+ autofix_dtype_text = "32-bit float"
647
+ autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
648
+ autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
649
+ else:
650
+ raise e
651
+
652
+ if devices.dtype_vae == autofix_dtype:
653
+ raise e
654
+
655
+ errors.print_error_explanation(
656
+ "A tensor with all NaNs was produced in VAE.\n"
657
+ f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
658
+ f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
659
+ )
660
+
661
+ devices.dtype_vae = autofix_dtype
662
+ model.first_stage_model.to(devices.dtype_vae)
663
+ batch = batch.to(devices.dtype_vae)
664
+
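+                # Retry the decode once with the VAE switched to the safer dtype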
+                sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+        if target_device is not None:
+            sample = sample.to(target_device)
+
+        samples.append(sample)
+
+    return samples
+
+
+def get_fixed_seed(seed):
+    if seed == '' or seed is None:
+        seed = -1
+    elif isinstance(seed, str):
+        try:
+            seed = int(seed)
+        except Exception:
+            seed = -1
+
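+    # -1 (or anything unparseable) means "randomize": draw a fresh seed below 2**32,
+    # presumably to stay within the 32-bit seed range used elsewhere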
+    if seed == -1:
+        return int(random.randrange(4294967294))
+
+    return seed
+
+
+def fix_seed(p):
+    p.seed = get_fixed_seed(p.seed)
+    p.subseed = get_fixed_seed(p.subseed)
+
+
+def program_version():
+    import launch
+
+    res = launch.git_tag()
+    if res == "<none>":
+        res = None
+
+    return res
+
+
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
+    """
+    This function generates the infotext that is stored in generated images; it contains the parameters required to reproduce the image.
+    Args:
+        p: StableDiffusionProcessing
+        all_prompts: list[str]
+        all_seeds: list[int]
+        all_subseeds: list[int]
+        comments: list[str]
+        iteration: int
+        position_in_batch: int
+        use_main_prompt: bool
+        index: int
+        all_negative_prompts: list[str]
+
+    Returns: str
+
+    Extra generation params
+    The p.extra_generation_params dictionary allows additional parameters to be added to the infotext;
+    this can be used by the base webui or by extensions.
+    To add a new entry, add a key-value pair; the dictionary key is used as the parameter's key in the infotext.
+    The value can be defined as:
+        - str | None
+        - List[str|None]
+        - callable func(**kwargs) -> str | None
+
+    When defined as a string, it is used as-is, without extra processing; this is the most common use case.
+
+    Defining it as a list allows for a parameter that changes across images in the job, for example the 'Seed' parameter.
+    The list should have the same length as the total number of images in the entire job.
+
+    Defining it as a callable function allows for parameters that cannot be computed in advance, or that need extra logic.
+    For example 'Hires prompt': hr_prompt might be changed mid-pipeline by processing or by extensions,
+    and may vary across images, so a static string or list would not work.
+
+    The function receives locals() as **kwargs, so it has access to variables such as 'p' and 'index'.
+    The base signature of the function should be:
+        func(**kwargs) -> str | None
+    Optionally it can declare additional arguments that will be filled in from those kwargs:
+        func(p, index, **kwargs) -> str | None
+    Note: for better future compatibility, even though the function has access to all variables in locals(),
+    it is recommended to rely only on the arguments present in the signature of create_infotext.
+    For actual implementation examples, see StableDiffusionProcessingTxt2Img.init > get_hr_prompt.
+    """
+
+    if use_main_prompt:
+        index = 0
+    elif index is None:
+        index = position_in_batch + iteration * p.batch_size
+
+    if all_negative_prompts is None:
+        all_negative_prompts = p.all_negative_prompts
+
+    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+    enable_hr = getattr(p, 'enable_hr', False)
+    token_merging_ratio = p.get_token_merging_ratio()
+    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
+
+    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
+    negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]
+
+    uses_ensd = opts.eta_noise_seed_delta != 0
+    if uses_ensd:
+        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
+
+    generation_params = {
+        "Steps": p.steps,
+        "Sampler": p.sampler_name,
+        "Schedule type": p.scheduler,
+        "CFG scale": p.cfg_scale,
+        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
+        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
+        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
+        "Size": f"{p.width}x{p.height}",
+        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
+        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
+        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
+        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
+        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
+        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
+        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
+        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+        "Denoising strength": p.extra_generation_params.get("Denoising strength"),
+        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
+        "Clip skip": None if clip_skip <= 1 else clip_skip,
+        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
+        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
+        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
+        "Init image hash": getattr(p, 'init_img_hash', None),
+        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+        "Tiling": "True" if p.tiling else None,
+        **p.extra_generation_params,
+        "Version": program_version() if opts.add_version_to_infotext else None,
+        "User": p.user if opts.add_user_name_to_info else None,
+    }
+
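+    # Resolve deferred values: list entries are indexed per image, callables are invoked
+    # with locals() so they can inspect p/index (see the docstring above)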
+    for key, value in generation_params.items():
+        try:
+            if isinstance(value, list):
+                generation_params[key] = value[index]
+            elif callable(value):
+                generation_params[key] = value(**locals())
+        except Exception:
+            errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
+            generation_params[key] = None
+
+    generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
+
+    negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""
+
+    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
+
+
+def process_images(p: StableDiffusionProcessing) -> Processed:
+    if p.scripts is not None:
+        p.scripts.before_process(p)
+
+    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
+
+    try:
+        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
+        # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here; if a model override is present it will be reloaded afterwards
+        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+            p.override_settings.pop('sd_model_checkpoint', None)
+            sd_models.reload_model_weights()
+
+        for k, v in p.override_settings.items():
+            opts.set(k, v, is_api=True, run_callbacks=False)
+
+            if k == 'sd_model_checkpoint':
+                sd_models.reload_model_weights()
+
+            if k == 'sd_vae':
+                sd_vae.reload_vae_weights()
+
+        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
+
+        # backwards compatibility, fix sampler and scheduler if invalid
+        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
+
+        with profiling.Profiler():
+            res = process_images_inner(p)
+
+    finally:
+        sd_models.apply_token_merging(p.sd_model, 0)
+
+        # restore opts to original state
+        if p.override_settings_restore_afterwards:
+            for k, v in stored_opts.items():
+                setattr(opts, k, v)
+
+                if k == 'sd_vae':
+                    sd_vae.reload_vae_weights()
+
+    return res
+
+
+def process_images_inner(p: StableDiffusionProcessing) -> Processed:
+    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+
+    if isinstance(p.prompt, list):
+        assert len(p.prompt) > 0
+    else:
+        assert p.prompt is not None
+
+    devices.torch_gc()
+
+    seed = get_fixed_seed(p.seed)
+    subseed = get_fixed_seed(p.subseed)
+
+    if p.restore_faces is None:
+        p.restore_faces = opts.face_restoration
+
+    if p.tiling is None:
+        p.tiling = opts.tiling
+
+    if p.refiner_checkpoint not in (None, "", "None", "none"):
+        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
+        if p.refiner_checkpoint_info is None:
+            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
+
+    if hasattr(shared.sd_model, 'fix_dimensions'):
+        p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)
+
+    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
+    p.sd_model_hash = shared.sd_model.sd_model_hash
+    p.sd_vae_name = sd_vae.get_loaded_vae_name()
+    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
+
+    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+    modules.sd_hijack.model_hijack.clear_comments()
+
+    p.fill_fields_from_opts()
+    p.setup_prompts()
+
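+    # One seed per prompt: consecutive seeds unless variation (subseed) strength is in
+    # use, in which case every image shares the base seed and varies via its subseed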
+    if isinstance(seed, list):
+        p.all_seeds = seed
+    else:
+        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
+
+    if isinstance(subseed, list):
+        p.all_subseeds = subseed
+    else:
+        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
+
+    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
+        model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+    if p.scripts is not None:
+        p.scripts.process(p)
+
+    infotexts = []
+    output_images = []
+    with torch.no_grad(), p.sd_model.ema_scope():
+        with devices.autocast():
+            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+
+            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
+            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
+                sd_vae_approx.model()
+
+            sd_unet.apply_unet()
+
+        if state.job_count == -1:
+            state.job_count = p.n_iter
+
+        for n in range(p.n_iter):
+            p.iteration = n
+
+            if state.skipped:
+                state.skipped = False
+
+            if state.interrupted or state.stopping_generation:
+                break
+
+            sd_models.reload_model_weights()  # model can be changed for example by refiner
+
+            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+
+            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+
+            if p.scripts is not None:
+                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+
+            if len(p.prompts) == 0:
+                break
+
+            p.parse_extra_network_prompts()
+
+            if not p.disable_extra_networks:
+                with devices.autocast():
+                    extra_networks.activate(p, p.extra_network_data)
+
+            if p.scripts is not None:
+                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
+
+            p.setup_conds()
+
+            p.extra_generation_params.update(model_hijack.extra_generation_params)
+
+            # params.txt should be saved after scripts.process_batch, since the
+            # infotext could be modified by that callback
+            # Example: a wildcard processed by process_batch sets an extra model
+            # strength, which is saved as "Model Strength: 1.0" in the infotext
+            if n == 0 and not cmd_opts.no_prompt_history:
+                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
+                    processed = Processed(p, [])
+                    file.write(processed.infotext(p, 0))
+
+            for comment in model_hijack.comments:
+                p.comment(comment)
+
+            if p.n_iter > 1:
+                shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+
+            sd_models.apply_alpha_schedule_override(p.sd_model, p)
+
+            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+
+            if p.scripts is not None:
+                ps = scripts.PostSampleArgs(samples_ddim)
+                p.scripts.post_sample(p, ps)
+                samples_ddim = ps.samples
+
+            if getattr(samples_ddim, 'already_decoded', False):
+                x_samples_ddim = samples_ddim
+            else:
+                devices.test_for_nans(samples_ddim, "unet")
+
+                if opts.sd_vae_decode_method != 'Full':
+                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
+            x_samples_ddim = torch.stack(x_samples_ddim).float()
+            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+            del samples_ddim
+
+            if lowvram.is_enabled(shared.sd_model):
+                lowvram.send_everything_to_cpu()
+
+            devices.torch_gc()
+
+            state.nextjob()
+
+            if p.scripts is not None:
+                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
+
+            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+            batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+            p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+            x_samples_ddim = batch_params.images
+
+            def infotext(index=0, use_main_prompt=False):
+                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+
+            save_samples = p.save_samples()
+
+            for i, x_sample in enumerate(x_samples_ddim):
+                p.batch_index = i
+
+                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+                x_sample = x_sample.astype(np.uint8)
+
+                if p.restore_faces:
+                    if save_samples and opts.save_images_before_face_restoration:
+                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
+
+                    devices.torch_gc()
+
+                    x_sample = modules.face_restoration.restore_faces(x_sample)
+                    devices.torch_gc()
+
+                image = Image.fromarray(x_sample)
+
+                if p.scripts is not None:
+                    pp = scripts.PostprocessImageArgs(image)
+                    p.scripts.postprocess_image(p, pp)
+                    image = pp.image
+
+                mask_for_overlay = getattr(p, "mask_for_overlay", None)
+
+                if not shared.opts.overlay_inpaint:
+                    overlay_image = None
+                elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
+                    overlay_image = p.overlay_images[i]
+                else:
+                    overlay_image = None
+
+                if p.scripts is not None:
+                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+                    p.scripts.postprocess_maskoverlay(p, ppmo)
+                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
+                if p.color_corrections is not None and i < len(p.color_corrections):
+                    if save_samples and opts.save_images_before_color_correction:
+                        image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
+                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
+                    image = apply_color_correction(p.color_corrections[i], image)
+
+                # If the intention is to show the output from the model
+                # that is being composited over the original image,
+                # we need to keep the original image around
+                # and use it in the composite step.
+                image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)
+
+                if p.scripts is not None:
+                    pp = scripts.PostprocessImageArgs(image)
+                    p.scripts.postprocess_image_after_composite(p, pp)
+                    image = pp.image
+
+                if save_samples:
+                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
+
+                text = infotext(i)
+                infotexts.append(text)
+                if opts.enable_pnginfo:
+                    image.info["parameters"] = text
+                output_images.append(image)
+
+                if mask_for_overlay is not None:
+                    if opts.return_mask or opts.save_mask:
+                        image_mask = mask_for_overlay.convert('RGB')
+                        if save_samples and opts.save_mask:
+                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+                        if opts.return_mask:
+                            output_images.append(image_mask)
+
+                    if opts.return_mask_composite or opts.save_mask_composite:
+                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                        if save_samples and opts.save_mask_composite:
+                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+                        if opts.return_mask_composite:
+                            output_images.append(image_mask_composite)
+
+            del x_samples_ddim
+
+            devices.torch_gc()
+
+        if not infotexts:
+            infotexts.append(Processed(p, []).infotext(p, 0))
+
+        p.color_corrections = None
+
+        index_of_first_image = 0
+        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
+        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
+            grid = images.image_grid(output_images, p.batch_size)
+
+            if opts.return_grid:
+                text = infotext(use_main_prompt=True)
+                infotexts.insert(0, text)
+                if opts.enable_pnginfo:
+                    grid.info["parameters"] = text
+                output_images.insert(0, grid)
+                index_of_first_image = 1
+            if opts.grid_save:
+                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+
+    if not p.disable_extra_networks and p.extra_network_data:
+        extra_networks.deactivate(p, p.extra_network_data)
+
+    devices.torch_gc()
+
+    res = Processed(
+        p,
+        images_list=output_images,
+        seed=p.all_seeds[0],
+        info=infotexts[0],
+        subseed=p.all_subseeds[0],
+        index_of_first_image=index_of_first_image,
+        infotexts=infotexts,
+    )
+
+    if p.scripts is not None:
+        p.scripts.postprocess(p, res)
+
+    return res
+
+
+def old_hires_fix_first_pass_dimensions(width, height):
+    """old algorithm for auto-calculating first pass size"""
+
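+    # Worked example (sketch): 1024x768 gives scale = sqrt(512*512 / (1024*768)) ≈ 0.577,
+    # and rounding each side up to the next multiple of 64 yields a 640x448 first pass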
+    desired_pixel_count = 512 * 512
+    actual_pixel_count = width * height
+    scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+    width = math.ceil(scale * width / 64) * 64
+    height = math.ceil(scale * height / 64) * 64
+
+    return width, height
+
+
+@dataclass(repr=False)
+class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+    enable_hr: bool = False
+    denoising_strength: float = 0.75
+    firstphase_width: int = 0
+    firstphase_height: int = 0
+    hr_scale: float = 2.0
+    hr_upscaler: str = None
+    hr_second_pass_steps: int = 0
+    hr_resize_x: int = 0
+    hr_resize_y: int = 0
+    hr_checkpoint_name: str = None
+    hr_sampler_name: str = None
+    hr_scheduler: str = None
+    hr_prompt: str = ''
+    hr_negative_prompt: str = ''
+    force_task_id: str = None
+
+    cached_hr_uc = [None, None]
+    cached_hr_c = [None, None]
+
+    hr_checkpoint_info: dict = field(default=None, init=False)
+    hr_upscale_to_x: int = field(default=0, init=False)
+    hr_upscale_to_y: int = field(default=0, init=False)
+    truncate_x: int = field(default=0, init=False)
+    truncate_y: int = field(default=0, init=False)
+    applied_old_hires_behavior_to: tuple = field(default=None, init=False)
+    latent_scale_mode: dict = field(default=None, init=False)
+    hr_c: tuple | None = field(default=None, init=False)
+    hr_uc: tuple | None = field(default=None, init=False)
+    all_hr_prompts: list = field(default=None, init=False)
+    all_hr_negative_prompts: list = field(default=None, init=False)
+    hr_prompts: list = field(default=None, init=False)
+    hr_negative_prompts: list = field(default=None, init=False)
+    hr_extra_network_data: list = field(default=None, init=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        if self.firstphase_width != 0 or self.firstphase_height != 0:
+            self.hr_upscale_to_x = self.width
+            self.hr_upscale_to_y = self.height
+            self.width = self.firstphase_width
+            self.height = self.firstphase_height
+
+        self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
+        self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
+
+    def calculate_target_resolution(self):
+        if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
+            self.hr_resize_x = self.width
+            self.hr_resize_y = self.height
+            self.hr_upscale_to_x = self.width
+            self.hr_upscale_to_y = self.height
+
+            self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+            self.applied_old_hires_behavior_to = (self.width, self.height)
+
+        if self.hr_resize_x == 0 and self.hr_resize_y == 0:
+            self.extra_generation_params["Hires upscale"] = self.hr_scale
+            self.hr_upscale_to_x = int(self.width * self.hr_scale)
+            self.hr_upscale_to_y = int(self.height * self.hr_scale)
+        else:
+            self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
+
+            if self.hr_resize_y == 0:
+                self.hr_upscale_to_x = self.hr_resize_x
+                self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+            elif self.hr_resize_x == 0:
+                self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+                self.hr_upscale_to_y = self.hr_resize_y
+            else:
+                target_w = self.hr_resize_x
+                target_h = self.hr_resize_y
+                src_ratio = self.width / self.height
+                dst_ratio = self.hr_resize_x / self.hr_resize_y
+
+                if src_ratio < dst_ratio:
+                    self.hr_upscale_to_x = self.hr_resize_x
+                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+                else:
+                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+                    self.hr_upscale_to_y = self.hr_resize_y
+
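+                # The upscale covers the requested size while preserving aspect ratio;
+                # record the overshoot to crop off later, in latent units (// opt_f)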
+                self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
+                self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
+
+    def init(self, all_prompts, all_seeds, all_subseeds):
+        if self.enable_hr:
+            self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
+            if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
+                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+                if self.hr_checkpoint_info is None:
+                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
+            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
+                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
+
+            def get_hr_prompt(p, index, prompt_text, **kwargs):
+                hr_prompt = p.all_hr_prompts[index]
+                return hr_prompt if hr_prompt != prompt_text else None
+
+            def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):
+                hr_negative_prompt = p.all_hr_negative_prompts[index]
+                return hr_negative_prompt if hr_negative_prompt != negative_prompt else None
+
+            self.extra_generation_params["Hires prompt"] = get_hr_prompt
+            self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt
+
+            self.extra_generation_params["Hires schedule type"] = None  # to be set in sd_samplers_kdiffusion.py
+
+            if self.hr_scheduler is None:
+                self.hr_scheduler = self.scheduler
+
+            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+            if self.enable_hr and self.latent_scale_mode is None:
+                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
+            self.calculate_target_resolution()
+
+            if not state.processing_has_refined_job_count:
+                if state.job_count == -1:
+                    state.job_count = self.n_iter
+                if getattr(self, 'txt2img_upscale', False):
+                    total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
+                else:
+                    total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
+                shared.total_tqdm.updateTotal(total_steps)
+                state.job_count = state.job_count * 2
+                state.processing_has_refined_job_count = True
+
+            if self.hr_second_pass_steps:
+                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
+
+            if self.hr_upscaler is not None:
+                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
+
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+
+        if self.firstpass_image is not None and self.enable_hr:
+            # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
+
+            if self.latent_scale_mode is None:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
+                image = np.moveaxis(image, 2, 0)
+
+                samples = None
+                decoded_samples = torch.asarray(np.expand_dims(image, 0))
+
+            else:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0
+                image = np.moveaxis(image, 2, 0)
+                image = torch.from_numpy(np.expand_dims(image, axis=0))
+                image = image.to(shared.device, dtype=devices.dtype_vae)
+
+                if opts.sd_vae_encode_method != 'Full':
+                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+                decoded_samples = None
+                devices.torch_gc()
+
+        else:
+            # here we generate an image normally
+
+            x = self.rng.next()
+            if self.scripts is not None:
+                self.scripts.process_before_every_sampling(
+                    p=self,
+                    x=x,
+                    noise=x,
+                    c=conditioning,
+                    uc=unconditional_conditioning
+                )
+
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+            del x
+
+            if not self.enable_hr:
+                return samples
+
+            devices.torch_gc()
+
+            if self.latent_scale_mode is None:
+                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+            else:
+                decoded_samples = None
+
+        with sd_models.SkipWritingToConfig():
+            sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+
+        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+
+    def sample_progressive(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
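+        # Progressive growing (sketch of this custom pass): sample once at a reduced
+        # resolution, then repeatedly interpolate the latent up toward the target size,
+        # optionally re-denoising (img2img-style) at each scale to refine detail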
+        is_sdxl = getattr(self.sd_model, 'is_sdxl', False)
+
+        if is_sdxl:
+            min_scale = max(0.5, self.progressive_growing_min_scale)
+        else:
+            min_scale = self.progressive_growing_min_scale
+
+        resolution_steps = np.linspace(min_scale, self.progressive_growing_max_scale, self.progressive_growing_steps)
+
+        initial_width = max(512 if is_sdxl else 64, int(self.width * resolution_steps[0]))
+        initial_height = max(512 if is_sdxl else 64, int(self.height * resolution_steps[0]))
+
+        x = create_random_tensors((opt_C, initial_height // opt_f, initial_width // opt_f), seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+
+        for i in range(1, len(resolution_steps)):
+            target_width = int(self.width * resolution_steps[i])
+            target_height = int(self.height * resolution_steps[i])
+
+            if is_sdxl:
+                target_width = max(512, min(1536, target_width))
+                target_height = max(512, min(1536, target_height))
+
+            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode='bicubic', align_corners=False)
+
+            if self.progressive_growing_refinement:
+                steps_for_refinement = self.steps // len(resolution_steps)
+                noise = create_random_tensors(samples.shape[1:], seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+                decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+                decoded_samples = torch.stack(decoded_samples).float()
+                decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                self.image_conditioning = self.img2img_image_conditioning(decoded_samples * 2 - 1, samples)
+
+                samples = self.sampler.sample_img2img(
+                    self,
+                    samples,
+                    noise,
+                    conditioning,
+                    unconditional_conditioning,
+                    steps=steps_for_refinement,
+                    image_conditioning=self.image_conditioning
+                )
+
+        return samples
+
+    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
+        if shared.state.interrupted:
+            return samples
+
+        self.is_hr_pass = True
+        target_width = self.hr_upscale_to_x
+        target_height = self.hr_upscale_to_y
+
+        def save_intermediate(image, index):
+            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
+
+            if not self.save_samples() or not opts.save_images_before_highres_fix:
+                return
+
+            if not isinstance(image, Image.Image):
+                image = sd_samplers.sample_to_image(image, index, approximation=0)
+
+            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
+
+        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+        if self.latent_scale_mode is not None:
+            for i in range(samples.shape[0]):
+                save_intermediate(samples, i)
+
+            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
+
+            # Avoid making the inpainting conditioning unless necessary as
+            # this does need some extra compute to decode / encode the image again.
+            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+            else:
+                image_conditioning = self.txt2img_image_conditioning(samples)
+        else:
+            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+            batch_images = []
+            for i, x_sample in enumerate(lowres_samples):
+                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+                x_sample = x_sample.astype(np.uint8)
+                image = Image.fromarray(x_sample)
+
+                save_intermediate(image, i)
+
+                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
+                image = np.array(image).astype(np.float32) / 255.0
+                image = np.moveaxis(image, 2, 0)
+                batch_images.append(image)
+
+            decoded_samples = torch.from_numpy(np.array(batch_images))
+            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
+
+            if opts.sd_vae_encode_method != 'Full':
+                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
+
+            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
+        shared.state.nextjob()
+
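+        # Center-crop the upscaled latent by truncate_x/y (set in calculate_target_resolution)
+        # so the hires output matches the exact requested size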
+        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
+
+        self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
+        noise = self.rng.next()
+
+        # GC now before running the next img2img to prevent running out of memory
+        devices.torch_gc()
+
+        if not self.disable_extra_networks:
+            with devices.autocast():
+                extra_networks.activate(self, self.hr_extra_network_data)
+
+        with devices.autocast():
+            self.calculate_hr_conds()
+
+        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+
+        if self.scripts is not None:
+            self.scripts.before_hr(self)
+            self.scripts.process_before_every_sampling(
+                p=self,
+                x=samples,
+                noise=noise,
+                c=self.hr_c,
+                uc=self.hr_uc,
+            )
+
+        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+
+        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+
+        self.sampler = None
+        devices.torch_gc()
+
+        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
+        self.is_hr_pass = False
+        return decoded_samples
+
+    def close(self):
+        super().close()
+        self.hr_c = None
+        self.hr_uc = None
+        if not opts.persistent_cond_cache:
+            StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
+            StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
+
+    def setup_prompts(self):
+        super().setup_prompts()
+
+        if not self.enable_hr:
+            return
+
+        if self.hr_prompt == '':
+            self.hr_prompt = self.prompt
+
+        if self.hr_negative_prompt == '':
+            self.hr_negative_prompt = self.negative_prompt
+
+        if isinstance(self.hr_prompt, list):
+            self.all_hr_prompts = self.hr_prompt
+        else:
+            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
+
+        if isinstance(self.hr_negative_prompt, list):
+            self.all_hr_negative_prompts = self.hr_negative_prompt
+        else:
+            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
+
+        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
+        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+
+    def calculate_hr_conds(self):
+        if self.hr_c is not None:
+            return
+
+        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+        sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
+        steps = self.hr_second_pass_steps or self.steps
+        total_steps = sampler_config.total_steps(steps) if sampler_config else steps
+
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
+
+    def setup_conds(self):
+        if self.is_hr_pass:
+            # if we are in the hr pass right now, the call is being made from the refiner, and we don't need to set up firstpass conds or switch the model
+            self.hr_c = None
+            self.calculate_hr_conds()
+            return
+
+        super().setup_conds()
+
+        self.hr_uc = None
+        self.hr_c = None
+
+        if self.enable_hr and self.hr_checkpoint_info is None:
+            if shared.opts.hires_fix_use_firstpass_conds:
+                self.calculate_hr_conds()
+
+            elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+                with devices.autocast():
+                    extra_networks.activate(self, self.hr_extra_network_data)
+
+                self.calculate_hr_conds()
+
+                with devices.autocast():
+                    extra_networks.activate(self, self.extra_network_data)
+
+    def get_conds(self):
+        if self.is_hr_pass:
+            return self.hr_c, self.hr_uc
+
+        return super().get_conds()
+
+    def parse_extra_network_prompts(self):
+        res = super().parse_extra_network_prompts()
+
+        if self.enable_hr:
+            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+
+            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
+
+        return res
+
+
+@dataclass(repr=False)
+class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
+    init_images: list = None
+    resize_mode: int = 0
+    denoising_strength: float = 0.75
+    image_cfg_scale: float = None
+    mask: Any = None
+    mask_blur_x: int = 4
+    mask_blur_y: int = 4
+    mask_blur: int = None
+    mask_round: bool = True
+    inpainting_fill: int = 0
+    inpaint_full_res: bool = True
+    inpaint_full_res_padding: int = 0
+    inpainting_mask_invert: int = 0
+    initial_noise_multiplier: float = None
+    latent_mask: Image = None
+    force_task_id: str = None
+
+    image_mask: Any = field(default=None, init=False)
+
+    nmask: torch.Tensor = field(default=None, init=False)
+    image_conditioning: torch.Tensor = field(default=None, init=False)
+    init_img_hash: str = field(default=None, init=False)
+    mask_for_overlay: Image = field(default=None, init=False)
+    init_latent: torch.Tensor = field(default=None, init=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        self.image_mask = self.mask
+        self.mask = None
+        self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
+
+    @property
+    def mask_blur(self):
+        if self.mask_blur_x == self.mask_blur_y:
+            return self.mask_blur_x
+        return None
+
+    @mask_blur.setter
+    def mask_blur(self, value):
+        if isinstance(value, int):
+            self.mask_blur_x = value
+            self.mask_blur_y = value
+
+    def init(self, all_prompts, all_seeds, all_subseeds):
+        self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
+        self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
+
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+        crop_region = None
+
+        image_mask = self.image_mask
+
+        if image_mask is not None:
+            # image_mask is passed in as RGBA by Gradio to support alpha masks,
+            # but we still want to support binary masks.
+            image_mask = create_binary_mask(image_mask, round=self.mask_round)
+
+            if self.inpainting_mask_invert:
+                image_mask = ImageOps.invert(image_mask)
+                self.extra_generation_params["Mask mode"] = "Inpaint not masked"
+
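+            # Blur each axis separately; the kernel radius int(2.5*sigma + 0.5) spans
+            # roughly ±2.5σ of the Gaussian, and the 2*r + 1 form keeps the width odd
+            # as cv2.GaussianBlur requires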
+            if self.mask_blur_x > 0:
+                np_mask = np.array(image_mask)
+                kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
+                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
+                image_mask = Image.fromarray(np_mask)
+
+            if self.mask_blur_y > 0:
+                np_mask = np.array(image_mask)
+                kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
+                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
+                image_mask = Image.fromarray(np_mask)
+
+            if self.mask_blur_x > 0 or self.mask_blur_y > 0:
+                self.extra_generation_params["Mask blur"] = self.mask_blur
+
+            if self.inpaint_full_res:
+                self.mask_for_overlay = image_mask
+                mask = image_mask.convert('L')
+                crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding)
+                if crop_region:
+                    crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
+                    x1, y1, x2, y2 = crop_region
+                    mask = mask.crop(crop_region)
+                    image_mask = images.resize_image(2, mask, self.width, self.height)
+                    self.paste_to = (x1, y1, x2-x1, y2-y1)
+                    self.extra_generation_params["Inpaint area"] = "Only masked"
+                    self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
+                else:
+                    crop_region = None
+                    image_mask = None
+                    self.mask_for_overlay = None
+                    self.inpaint_full_res = False
+                    message = 'Unable to perform "Inpaint Only masked" because the mask is blank; switching to img2img mode.'
+                    model_hijack.comments.append(message)
+                    logging.info(message)
+            else:
+                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+                np_mask = np.array(image_mask)
+                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+                self.mask_for_overlay = Image.fromarray(np_mask)
+
+            self.overlay_images = []
+
+        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
+
+        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
+        if add_color_corrections:
+            self.color_corrections = []
+        imgs = []
+        for img in self.init_images:
+
+            # Save init image
+            if opts.save_init_img:
+                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
+                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
+
+            image = images.flatten(img, opts.img2img_background_color)
+
+            if crop_region is None and self.resize_mode != 3:
+                image = images.resize_image(self.resize_mode, image, self.width, self.height)
+
+            if image_mask is not None:
+                if self.mask_for_overlay.size != (image.width, image.height):
+                    self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height)
+                image_masked = Image.new('RGBa', (image.width, image.height))
+                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+
+                self.overlay_images.append(image_masked.convert('RGBA'))
+
+            # crop_region is not None if we are doing inpaint full res
+            if crop_region is not None:
+                image = image.crop(crop_region)
+                image = images.resize_image(2, image, self.width, self.height)
+
+            if image_mask is not None:
+                if self.inpainting_fill != 1:
+                    image = masking.fill(image, latent_mask)
+
+                    if self.inpainting_fill == 0:
+                        self.extra_generation_params["Masked content"] = 'fill'
+
+            if add_color_corrections:
+                self.color_corrections.append(setup_color_correction(image))
+
+            image = np.array(image).astype(np.float32) / 255.0
+            image = np.moveaxis(image, 2, 0)
+
+            imgs.append(image)
+
+        if len(imgs) == 1:
+            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+            if self.overlay_images is not None:
+                self.overlay_images = self.overlay_images * self.batch_size
+
+            if self.color_corrections is not None and len(self.color_corrections) == 1:
+                self.color_corrections = self.color_corrections * self.batch_size
+
+        elif len(imgs) <= self.batch_size:
+            self.batch_size = len(imgs)
+            batch_images = np.array(imgs)
+        else:
+            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
+
+        image = torch.from_numpy(batch_images)
+        image = image.to(shared.device, dtype=devices.dtype_vae)
+
+        if opts.sd_vae_encode_method != 'Full':
+            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+        devices.torch_gc()
+
+        if self.resize_mode == 3:
+            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
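+        # Build latent-resolution masks: latmask marks the region to repaint, so
+        # self.mask (1 - latmask) preserves the original latent while self.nmask
+        # selects where new content is blended in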
+        if image_mask is not None:
+            init_mask = latent_mask
+            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
+            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
+            latmask = latmask[0]
+            if self.mask_round:
+                latmask = np.around(latmask)
+            latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))
+
+            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
+            self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)
+
+            # this needs to be fixed to be done in sample() using actual seeds for batches
+            if self.inpainting_fill == 2:
+                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+                self.extra_generation_params["Masked content"] = 'latent noise'
+
+            elif self.inpainting_fill == 3:
+                self.init_latent = self.init_latent * self.mask
+                self.extra_generation_params["Masked content"] = 'latent nothing'
+
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
+
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+        x = self.rng.next()
+
+        if self.initial_noise_multiplier != 1.0:
+            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
+            x *= self.initial_noise_multiplier
+
+        if self.scripts is not None:
+            self.scripts.process_before_every_sampling(
+                p=self,
+                x=self.init_latent,
+                noise=x,
+                c=conditioning,
+                uc=unconditional_conditioning
+            )
+        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
+
+        if self.mask is not None:
+            blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+            if self.scripts is not None:
+                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
+                self.scripts.on_mask_blend(self, mba)
+                blended_samples = mba.blended_latent
+
+            samples = blended_samples
+
+        del x
+        devices.torch_gc()
+
+        return samples
+
+    def get_token_merging_ratio(self, for_hr=False):
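+        # Precedence sketch: an explicit per-job ratio wins, then a value supplied via
+        # override_settings, then the img2img-specific option, then the global option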
+        return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
hm/txt2img.py ADDED
@@ -0,0 +1,136 @@
1
+ import json
2
+ from contextlib import closing
3
+
4
+ import modules.scripts
5
+ from modules import processing, infotext_utils
6
+ from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
7
+ from modules.shared import opts
8
+ import modules.shared as shared
9
+ from modules.ui import plaintext_to_html
10
+ from PIL import Image
11
+ import gradio as gr
12
+
13
+
14
+ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles,
15
+ n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool,
16
+ denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int,
17
+ hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str,
18
+ hr_prompt: str, hr_negative_prompt, override_settings_texts, enable_progressive_growing: bool,
19
+ progressive_growing_min_scale: float, progressive_growing_max_scale: float, progressive_growing_steps: int,
20
+ progressive_growing_refinement: bool, *args, force_enable_hr=False):
21
+ override_settings = create_override_settings_dict(override_settings_texts)
22
+
23
+ if force_enable_hr:
24
+ enable_hr = True
25
+
26
+
27
+ print(f"enable_progressive_growing: {enable_progressive_growing}")
28
+ print(f"progressive_growing_min_scale: {progressive_growing_min_scale}")
29
+
30
+ p = processing.StableDiffusionProcessingTxt2Img(
31
+ sd_model=shared.sd_model,
32
+ outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
33
+ outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
34
+ prompt=prompt,
35
+ styles=prompt_styles,
36
+ negative_prompt=negative_prompt,
37
+ batch_size=batch_size,
38
+ n_iter=n_iter,
39
+ cfg_scale=cfg_scale,
40
+ width=width,
41
+ height=height,
42
+ enable_hr=enable_hr,
43
+ denoising_strength=denoising_strength,
44
+ hr_scale=hr_scale,
45
+ hr_upscaler=hr_upscaler,
46
+ hr_second_pass_steps=hr_second_pass_steps,
47
+ hr_resize_x=hr_resize_x,
48
+ hr_resize_y=hr_resize_y,
49
+ hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
50
+ hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
51
+ hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
52
+ hr_prompt=hr_prompt,
53
+ hr_negative_prompt=hr_negative_prompt,
54
+ override_settings=override_settings,
55
+ )
56
+
57
+ p.id_task = id_task
+     p.enable_progressive_growing = enable_progressive_growing
+     p.progressive_growing_min_scale = progressive_growing_min_scale
+     p.progressive_growing_max_scale = progressive_growing_max_scale
+     p.progressive_growing_steps = progressive_growing_steps
+     p.progressive_growing_refinement = progressive_growing_refinement
+     p.scripts = modules.scripts.scripts_txt2img
+     p.script_args = args
+
+     p.user = request.username
+
+     if shared.opts.enable_console_prompts:
+         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
+
+     return p
+
+
+ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
+     assert len(gallery) > 0, 'No image to upscale'
+     assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
+
+     p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)
+     p.batch_size = 1
+     p.n_iter = 1
+     # txt2img_upscale attribute that signifies this is called by txt2img_upscale
+     p.txt2img_upscale = True
+
+     geninfo = json.loads(generation_info)
+
+     image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
+     p.firstpass_image = infotext_utils.image_from_url_text(image_info)
+
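+     # Reuse the seed and variation seed recorded in the selected image's
+     # infotext so the highres pass reproduces the same first-pass latent.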
+     parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])
+     p.seed = parameters.get('Seed', -1)
+     p.subseed = parameters.get('Variation seed', -1)
+
+     p.override_settings['save_images_before_highres_fix'] = False
+
+     with closing(p):
+         processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
+
+         if processed is None:
+             processed = processing.process_images(p)
+
+     shared.total_tqdm.clear()
+
+     new_gallery = []
+     for i, image in enumerate(gallery):
+         if i == gallery_index:
+             geninfo["infotexts"][gallery_index: gallery_index+1] = processed.infotexts
+             new_gallery.extend(processed.images)
+         else:
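+             # Keep untouched gallery slots as tiny placeholder images;
+             # `already_saved_as` points Gradio back at the file already on disk,
+             # so the original image is served again without being re-encoded.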
+             fake_image = Image.new(mode="RGB", size=(1, 1))
+             fake_image.already_saved_as = image["name"].rsplit('?', 1)[0]
+             new_gallery.append(fake_image)
+
+     geninfo["infotexts"][gallery_index] = processed.info
+
+     return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
+
+
+ def txt2img(id_task: str, request: gr.Request, *args):
+     p = txt2img_create_processing(id_task, request, *args)
+
+     with closing(p):
+         processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
+
+         if processed is None:
+             processed = processing.process_images(p)
+
+     shared.total_tqdm.clear()
+
+     generation_info_js = processed.js()
+     if opts.samples_log_stdout:
+         print(generation_info_js)
+
+     if opts.do_not_show_images:
+         processed.images = []
+
+     return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
hm/ui.py ADDED
@@ -0,0 +1,1249 @@
+ import datetime
+ import mimetypes
+ import os
+ import sys
+ from functools import reduce
+ import warnings
+ from contextlib import ExitStack
+
+ import gradio as gr
+ import gradio.utils
+ import numpy as np
+ from PIL import Image, PngImagePlugin  # noqa: F401
+ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
+
+ from modules import gradio_extensons, sd_schedulers  # noqa: F401
+ from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
+ from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
+ from modules.paths import script_path
+ from modules.ui_common import create_refresh_button
+ from modules.ui_gradio_extensions import reload_javascript
+
+ from modules.shared import opts, cmd_opts
+
+ import modules.infotext_utils as parameters_copypaste
+ import modules.hypernetworks.ui as hypernetworks_ui
+ import modules.textual_inversion.ui as textual_inversion_ui
+ import modules.textual_inversion.textual_inversion as textual_inversion
+ import modules.shared as shared
+ from modules import prompt_parser
+ from modules.sd_hijack import model_hijack
+ from modules.infotext_utils import image_from_url_text, PasteField
+
+ create_setting_component = ui_settings.create_setting_component
+
+ warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
+
+ # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
+ mimetypes.init()
+ mimetypes.add_type('application/javascript', '.js')
+ mimetypes.add_type('application/javascript', '.mjs')
+
+ # Likewise, add explicit content-type headers for certain missing image types
+ mimetypes.add_type('image/webp', '.webp')
+ mimetypes.add_type('image/avif', '.avif')
+
+ if not cmd_opts.share and not cmd_opts.listen:
+     # fix gradio phoning home
+     gradio.utils.version_check = lambda: None
+     gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+
+ if cmd_opts.ngrok is not None:
+     import modules.ngrok as ngrok
+     print('ngrok authtoken detected, trying to connect...')
+     ngrok.connect(
+         cmd_opts.ngrok,
+         cmd_opts.port if cmd_opts.port is not None else 7860,
+         cmd_opts.ngrok_options
+     )
+
+
+ def gr_show(visible=True):
+     return {"visible": visible, "__type__": "update"}
+
+
+ sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
+
+ # Using constants for these since the variation selector isn't visible.
+ # Important that they exactly match script.js for tooltips to work.
+ random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
+ reuse_symbol = '\u267b\ufe0f'  # ♻️
+ paste_symbol = '\u2199\ufe0f'  # ↙
+ refresh_symbol = '\U0001f504'  # 🔄
+ save_style_symbol = '\U0001f4be'  # 💾
+ apply_style_symbol = '\U0001f4cb'  # 📋
+ clear_prompt_symbol = '\U0001f5d1\ufe0f'  # 🗑️
+ extra_networks_symbol = '\U0001F3B4'  # 🎴
+ switch_values_symbol = '\U000021C5'  # ⇅
+ restore_progress_symbol = '\U0001F300'  # 🌀
+ detect_image_size_symbol = '\U0001F4D0'  # 📐
+
+
+ plaintext_to_html = ui_common.plaintext_to_html
+
+
+ def send_gradio_gallery_to_image(x):
+     if len(x) == 0:
+         return None
+     return image_from_url_text(x[0])
+
+
+ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+     if not enable:
+         return ""
+
+     p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+     p.calculate_target_resolution()
+
+     return f"from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
+
+
+ def resize_from_to_html(width, height, scale_by):
+     target_width = int(width * scale_by)
+     target_height = int(height * scale_by)
+
+     if not target_width or not target_height:
+         return "no image selected"
+
+     return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
+
+
+ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
+     if mode in {0, 1, 3, 4}:
+         return [interrogation_function(ii_singles[mode]), None]
+     elif mode == 2:
+         return [interrogation_function(ii_singles[mode]["image"]), None]
+     elif mode == 5:
+         assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+         images = shared.listfiles(ii_input_dir)
+         print(f"Will process {len(images)} images.")
+         if ii_output_dir != "":
+             os.makedirs(ii_output_dir, exist_ok=True)
+         else:
+             ii_output_dir = ii_input_dir
+
+         for image in images:
+             img = Image.open(image)
+             filename = os.path.basename(image)
+             left, _ = os.path.splitext(filename)
+             print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
+
+         return [gr.update(), None]
+
+
+ def interrogate(image):
+     prompt = shared.interrogator.interrogate(image.convert("RGB"))
+     return gr.update() if prompt is None else prompt
+
+
+ def interrogate_deepbooru(image):
+     prompt = deepbooru.model.tag(image)
+     return gr.update() if prompt is None else prompt
+
+
+ def connect_clear_prompt(button):
+     """Given clear button, prompt, and token_counter objects, set up the clear prompt button's click event"""
+     button.click(
+         _js="clear_prompt",
+         fn=None,
+         inputs=[],
+         outputs=[],
+     )
+
+
+ def update_token_counter(text, steps, styles, *, is_positive=True):
+     params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
+     script_callbacks.before_token_counter_callback(params)
+     text = params.prompt
+     steps = params.steps
+     styles = params.styles
+     is_positive = params.is_positive
+
+     if shared.opts.include_styles_into_token_counters:
+         apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
+         text = apply_styles(text, styles)
+
+     try:
+         text, _ = extra_networks.parse_prompt(text)
+
+         if is_positive:
+             _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+         else:
+             prompt_flat_list = [text]
+
+         prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+     except Exception:
+         # a parsing error can happen here during typing, and we don't want to bother the user with
+         # messages related to it in the console
+         prompt_schedules = [[[steps, text]]]
+
+     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+     prompts = [prompt_text for step, prompt_text in flat_prompts]
+     token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+     return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
+
+
+ def update_negative_prompt_token_counter(*args):
+     return update_token_counter(*args, is_positive=False)
+
+
+ def setup_progressbar(*args, **kwargs):
+     pass
+
+
+ def apply_setting(key, value):
+     if value is None:
+         return gr.update()
+
+     if shared.cmd_opts.freeze_settings:
+         return gr.update()
+
+     # don't allow the model to be swapped when a model hash exists in the prompt
+     if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+         return gr.update()
+
+     if key == "sd_model_checkpoint":
+         ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+         if ckpt_info is not None:
+             value = ckpt_info.title
+         else:
+             return gr.update()
+
+     comp_args = opts.data_labels[key].component_args
+     if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+         return
+
+     valtype = type(opts.data_labels[key].default)
+     oldval = opts.data.get(key, None)
+     opts.data[key] = valtype(value) if valtype != type(None) else value
+     if oldval != value and opts.data_labels[key].onchange is not None:
+         opts.data_labels[key].onchange()
+
+     opts.save(shared.config_filename)
+     return getattr(opts, key)
+
+
+ def create_output_panel(tabname, outdir, toprow=None):
+     return ui_common.create_output_panel(tabname, outdir, toprow)
+
+
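+ # Categories listed in opts.ui_reorder_list get odd sort keys (i * 2 + 1);
+ # unlisted categories keep even keys from their default positions, so
+ # user-ordered entries interleave with the defaults.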
+ def ordered_ui_categories():
+     user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
+
+     for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+         yield category
+
+
+ def create_override_settings_dropdown(tabname, row):
+     dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
+
+     dropdown.change(
+         fn=lambda x: gr.Dropdown.update(visible=bool(x)),
+         inputs=[dropdown],
+         outputs=[dropdown],
+     )
+
+     return dropdown
+
+
+ def create_ui():
+     import modules.img2img
+     import modules.txt2img
+
+     reload_javascript()
+
+     parameters_copypaste.reset()
+
+     settings = ui_settings.UiSettings()
+     settings.register_settings()
+
+     scripts.scripts_current = scripts.scripts_txt2img
+     scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+
+     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
+         toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
+
+         dummy_component = gr.Label(visible=False)
+
+         extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs", elem_classes=["extra-networks"])
+         extra_tabs.__enter__()
+
+         with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
+             with ExitStack() as stack:
+                 if shared.opts.txt2img_settings_accordion:
+                     stack.enter_context(gr.Accordion("Open for Settings", open=False))
+                 stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
+
+                 scripts.scripts_txt2img.prepare_ui()
+
+                 for category in ordered_ui_categories():
+                     if category == "prompt":
+                         toprow.create_inline_toprow_prompts()
+
+                     elif category == "dimensions":
+                         with FormRow():
+                             with gr.Column(elem_id="txt2img_column_size", scale=4):
+                                 width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+                                 height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+                             with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+                                 res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
+
+                             if opts.dimensions_and_batch_together:
+                                 with gr.Column(elem_id="txt2img_column_batch"):
+                                     batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+                                     batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+                     elif category == "cfg":
+                         with gr.Row():
+                             cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+                     elif category == "checkboxes":
+                         with FormRow(elem_classes="checkboxes-row", variant="compact"):
+                             pass
+
+                     elif category == "accordions":
+                         with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
+                             with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
+                                 with enable_hr.extra():
+                                     hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
+
+                                 with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
+                                     hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+                                     hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+                                     denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+                                 with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
+                                     hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+                                     hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+                                     hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+
+                                 with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+
+                                     hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+                                     create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
+
+                                     hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
+                                     hr_scheduler = gr.Dropdown(label='Hires schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
+
+                                 with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
+                                     with gr.Column(scale=80):
+                                         with gr.Row():
+                                             hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
+                                     with gr.Column(scale=80):
+                                         with gr.Row():
+                                             hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
+
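+                             # Custom accordion (appears to be an addition to the stock webui):
+                             # exposes the progressive-growing parameters that
+                             # txt2img_create_processing expects.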
+                             with InputAccordion(False, label="Progressive Growing", elem_id="txt2img_progressive_growing") as enable_progressive_growing:
+                                 with FormRow(elem_id="txt2img_progressive_growing_row1", variant="compact"):
+                                     progressive_growing_min_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Min Scale", value=0.25, elem_id="txt2img_progressive_growing_min_scale")
+                                     progressive_growing_max_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Max Scale", value=1.0, elem_id="txt2img_progressive_growing_max_scale")
+
+                                 with FormRow(elem_id="txt2img_progressive_growing_row2", variant="compact"):
+                                     progressive_growing_steps = gr.Slider(minimum=2, maximum=10, step=1, label="Steps", value=4, elem_id="txt2img_progressive_growing_steps")
+                                     progressive_growing_refinement = gr.Checkbox(label="Enable Refinement", value=True, elem_id="txt2img_progressive_growing_refinement")
+
+                             scripts.scripts_txt2img.setup_ui_for_section(category)
+
+                     elif category == "batch":
+                         if not opts.dimensions_and_batch_together:
+                             with FormRow(elem_id="txt2img_column_batch"):
+                                 batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+                                 batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+                     elif category == "override_settings":
+                         with FormRow(elem_id="txt2img_override_settings_row") as row:
+                             override_settings = create_override_settings_dropdown('txt2img', row)
+
+                     elif category == "scripts":
+                         with FormGroup(elem_id="txt2img_script_container"):
+                             custom_inputs = scripts.scripts_txt2img.setup_ui()
+
+                     if category not in {"accordions"}:
+                         scripts.scripts_txt2img.setup_ui_for_section(category)
+
+             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+
+             for component in hr_resolution_preview_inputs:
+                 event = component.release if isinstance(component, gr.Slider) else component.change
+
+                 event(
+                     fn=calc_resolution_hires,
+                     inputs=hr_resolution_preview_inputs,
+                     outputs=[hr_final_resolution],
+                     show_progress=False,
+                 )
+                 event(
+                     None,
+                     _js="onCalcResolutionHires",
+                     inputs=hr_resolution_preview_inputs,
+                     outputs=[],
+                     show_progress=False,
+                 )
+
+             output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
+
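+             # Order must mirror the positional parameters of
+             # modules.txt2img.txt2img_create_processing after (id_task, request);
+             # the progressive-growing controls sit right before the script inputs.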
+             txt2img_inputs = [
+                 dummy_component,
+                 toprow.prompt,
+                 toprow.negative_prompt,
+                 toprow.ui_styles.dropdown,
+                 batch_count,
+                 batch_size,
+                 cfg_scale,
+                 height,
+                 width,
+                 enable_hr,
+                 denoising_strength,
+                 hr_scale,
+                 hr_upscaler,
+                 hr_second_pass_steps,
+                 hr_resize_x,
+                 hr_resize_y,
+                 hr_checkpoint_name,
+                 hr_sampler_name,
+                 hr_scheduler,
+                 hr_prompt,
+                 hr_negative_prompt,
+                 override_settings,
+                 enable_progressive_growing,
+                 progressive_growing_min_scale,
+                 progressive_growing_max_scale,
+                 progressive_growing_steps,
+                 progressive_growing_refinement,
+             ] + custom_inputs
+
+             txt2img_outputs = [
+                 output_panel.gallery,
+                 output_panel.generation_info,
+                 output_panel.infotext,
+                 output_panel.html_log,
+             ]
+
+             txt2img_args = dict(
+                 fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
+                 _js="submit",
+                 inputs=txt2img_inputs,
+                 outputs=txt2img_outputs,
+                 show_progress=False,
+             )
+
+             toprow.prompt.submit(**txt2img_args)
+             toprow.submit.click(**txt2img_args)
+
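+             # txt2img_upscale takes (id_task, request, gallery, gallery_index,
+             # generation_info, *txt2img args), so the gallery components are
+             # spliced in between id_task and the remaining txt2img inputs.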
+             output_panel.button_upscale.click(
+                 fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
+                 _js="submit_txt2img_upscale",
+                 inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],
+                 outputs=txt2img_outputs,
+                 show_progress=False,
+             )
+
+             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
+
+             toprow.restore_progress_button.click(
+                 fn=progress.restore_progress,
+                 _js="restoreProgressTxt2img",
+                 inputs=[dummy_component],
+                 outputs=[
+                     output_panel.gallery,
+                     output_panel.generation_info,
+                     output_panel.infotext,
+                     output_panel.html_log,
+                 ],
+                 show_progress=False,
+             )
+
+             txt2img_paste_fields = [
+                 PasteField(toprow.prompt, "Prompt", api="prompt"),
+                 PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
+                 PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
+                 PasteField(width, "Size-1", api="width"),
+                 PasteField(height, "Size-2", api="height"),
+                 PasteField(batch_size, "Batch size", api="batch_size"),
+                 PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
+                 PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
+                 PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
+                 PasteField(hr_scale, "Hires upscale", api="hr_scale"),
+                 PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
+                 PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
+                 PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
+                 PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
+                 PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
+                 PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
+                 PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
+                 PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
+                 PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
+                 PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
+                 PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
+                 *scripts.scripts_txt2img.infotext_fields
+             ]
+             parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
+             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+                 paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
+             ))
+
+             steps = scripts.scripts_txt2img.script('Sampler').steps
+
+             txt2img_preview_params = [
+                 toprow.prompt,
+                 toprow.negative_prompt,
+                 steps,
+                 scripts.scripts_txt2img.script('Sampler').sampler_name,
+                 cfg_scale,
+                 scripts.scripts_txt2img.script('Seed').seed,
+                 width,
+                 height,
+             ]
+
+             toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+             toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
+             toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
+             toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
+
+         extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
+         ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
+
+         extra_tabs.__exit__()
+
+     scripts.scripts_current = scripts.scripts_img2img
+     scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+
+     with gr.Blocks(analytics_enabled=False) as img2img_interface:
+         toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
+
+         extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs", elem_classes=["extra-networks"])
+         extra_tabs.__enter__()
+
+         with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
+             with ExitStack() as stack:
+                 if shared.opts.img2img_settings_accordion:
+                     stack.enter_context(gr.Accordion("Open for Settings", open=False))
+                 stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
+
+                 copy_image_buttons = []
+                 copy_image_destinations = {}
+
+                 def add_copy_image_controls(tab_name, elem):
+                     with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
+                         gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
+
+                         for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+                             if name == tab_name:
+                                 gr.Button(title, interactive=False)
+                                 copy_image_destinations[name] = elem
+                                 continue
+
+                             button = gr.Button(title)
+                             copy_image_buttons.append((button, name, elem))
+
+                 scripts.scripts_img2img.prepare_ui()
+
+                 for category in ordered_ui_categories():
+                     if category == "prompt":
+                         toprow.create_inline_toprow_prompts()
+
+                     if category == "image":
+                         with gr.Tabs(elem_id="mode_img2img"):
+                             img2img_selected_tab = gr.Number(value=0, visible=False)
+
+                             with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
+                                 init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
+                                 add_copy_image_controls('img2img', init_img)
+
+                             with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
+                                 sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
+                                 add_copy_image_controls('sketch', sketch)
+
+                             with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
+                                 init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
+                                 add_copy_image_controls('inpaint', init_img_with_mask)
+
+                             with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
+                                 inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
+                                 inpaint_color_sketch_orig = gr.State(None)
+                                 add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
+
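+                                 # Heuristic: if the edited image matches the stored original in
+                                 # size and still shares at least one pixel with it, treat it as an
+                                 # edit and keep the stored original; otherwise adopt the new upload.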
+                                 def update_orig(image, state):
+                                     if image is not None:
+                                         same_size = state is not None and state.size == image.size
+                                         has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+                                         edited = same_size and has_exact_match
+                                         return image if not edited or state is None else state
+
+                                 inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
+
+                             with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
+                                 init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
+                                 init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
+
+                             with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
+                                 with gr.Tabs(elem_id="img2img_batch_source"):
+                                     img2img_batch_source_type = gr.Textbox(visible=False, value="upload")
+                                     with gr.TabItem('Upload', id='batch_upload', elem_id="img2img_batch_upload_tab") as tab_batch_upload:
+                                         img2img_batch_upload = gr.Files(label="Files", interactive=True, elem_id="img2img_batch_upload")
+                                     with gr.TabItem('From directory', id='batch_from_dir', elem_id="img2img_batch_from_dir_tab") as tab_batch_from_dir:
+                                         hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+                                         gr.HTML(
+                                             "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
+                                             "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
+                                             "<br>Add inpaint batch mask directory to enable inpaint batch processing."
+                                             f"{hidden}</p>"
+                                         )
+                                         img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+                                         img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+                                         img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+                                 tab_batch_upload.select(fn=lambda: "upload", inputs=[], outputs=[img2img_batch_source_type])
+                                 tab_batch_from_dir.select(fn=lambda: "from dir", inputs=[], outputs=[img2img_batch_source_type])
+                                 with gr.Accordion("PNG info", open=False):
+                                     img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", elem_id="img2img_batch_use_png_info")
+                                     img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+                                     img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
+
+                             img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
+
+                             for i, tab in enumerate(img2img_tabs):
+                                 tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
+
+                         def copy_image(img):
+                             if isinstance(img, dict) and 'image' in img:
+                                 return img['image']
+
+                             return img
+
+                         for button, name, elem in copy_image_buttons:
+                             button.click(
+                                 fn=copy_image,
+                                 inputs=[elem],
+                                 outputs=[copy_image_destinations[name]],
+                             )
+                             button.click(
+                                 fn=lambda: None,
+                                 _js=f"switch_to_{name.replace(' ', '_')}",
+                                 inputs=[],
+                                 outputs=[],
+                             )
+
+                         with FormRow():
+                             resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+
+                     elif category == "dimensions":
+                         with FormRow():
+                             with gr.Column(elem_id="img2img_column_size", scale=4):
+                                 selected_scale_tab = gr.Number(value=0, visible=False)
+
+                                 with gr.Tabs():
+                                     with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
+                                         with FormRow():
+                                             with gr.Column(elem_id="img2img_column_size", scale=4):
+                                                 width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+                                                 height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+                                             with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+                                                 res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
+                                                 detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
+
+                                     with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
+                                         scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
+
+                                         with FormRow():
+                                             scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
+                                             gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
+                                             button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
+
+                                         on_change_args = dict(
+                                             fn=resize_from_to_html,
+                                             _js="currentImg2imgSourceResolution",
+                                             inputs=[dummy_component, dummy_component, scale_by],
+                                             outputs=scale_by_html,
+                                             show_progress=False,
+                                         )
+
+                                         scale_by.release(**on_change_args)
+                                         button_update_resize_to.click(**on_change_args)
+
+                                 tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
+                                 tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
+
+                             if opts.dimensions_and_batch_together:
+                                 with gr.Column(elem_id="img2img_column_batch"):
+                                     batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+                                     batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+                     elif category == "denoising":
+                         denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
+
+                     elif category == "cfg":
+                         with gr.Row():
+                             cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+                             image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
+
+                     elif category == "checkboxes":
+                         with FormRow(elem_classes="checkboxes-row", variant="compact"):
+                             pass
+
+                     elif category == "accordions":
+                         with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
+                             scripts.scripts_img2img.setup_ui_for_section(category)
+
+                     elif category == "batch":
+                         if not opts.dimensions_and_batch_together:
+                             with FormRow(elem_id="img2img_column_batch"):
+                                 batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+                                 batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+                     elif category == "override_settings":
+                         with FormRow(elem_id="img2img_override_settings_row") as row:
+                             override_settings = create_override_settings_dropdown('img2img', row)
+
+                     elif category == "scripts":
+                         with FormGroup(elem_id="img2img_script_container"):
+                             custom_inputs = scripts.scripts_img2img.setup_ui()
+
+                     elif category == "inpaint":
+                         with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+                             with FormRow():
+                                 mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+                                 mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+                             with FormRow():
+                                 inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+                             with FormRow():
+                                 inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+                             with FormRow():
+                                 with gr.Column():
+                                     inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+                                 with gr.Column(scale=4):
+                                     inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+                     if category not in {"accordions"}:
+                         scripts.scripts_img2img.setup_ui_for_section(category)
+
+             # the code below is meant to update the resolution label after the image in the image selection UI has changed.
+             # as it is now, the event keeps firing continuously for inpaint edits, which floods the page with constant requests.
+             # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
+             for component in [init_img, sketch]:
+                 component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
+
+             def select_img2img_tab(tab):
+                 return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+             for i, elem in enumerate(img2img_tabs):
+                 elem.select(
+                     fn=lambda tab=i: select_img2img_tab(tab),
+                     inputs=[],
+                     outputs=[inpaint_controls, mask_alpha],
+                 )
+
+             output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
+
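+             # The two leading dummy components are placeholders for id_task and
+             # the active tab index, both filled in client-side by submit_img2img.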
746
+ img2img_args = dict(
747
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
748
+ _js="submit_img2img",
749
+ inputs=[
750
+ dummy_component,
751
+ dummy_component,
752
+ toprow.prompt,
753
+ toprow.negative_prompt,
754
+ toprow.ui_styles.dropdown,
755
+ init_img,
756
+ sketch,
757
+ init_img_with_mask,
758
+ inpaint_color_sketch,
759
+ inpaint_color_sketch_orig,
760
+ init_img_inpaint,
761
+ init_mask_inpaint,
762
+ mask_blur,
763
+ mask_alpha,
764
+ inpainting_fill,
765
+ batch_count,
766
+ batch_size,
767
+ cfg_scale,
768
+ image_cfg_scale,
769
+ denoising_strength,
770
+ selected_scale_tab,
771
+ height,
772
+ width,
773
+ scale_by,
774
+ resize_mode,
775
+ inpaint_full_res,
776
+ inpaint_full_res_padding,
777
+ inpainting_mask_invert,
778
+ img2img_batch_input_dir,
779
+ img2img_batch_output_dir,
780
+ img2img_batch_inpaint_mask_dir,
781
+ override_settings,
782
+ img2img_batch_use_png_info,
783
+ img2img_batch_png_info_props,
784
+ img2img_batch_png_info_dir,
785
+ img2img_batch_source_type,
786
+ img2img_batch_upload,
787
+ ] + custom_inputs,
788
+ outputs=[
789
+ output_panel.gallery,
790
+ output_panel.generation_info,
791
+ output_panel.infotext,
792
+ output_panel.html_log,
793
+ ],
794
+ show_progress=False,
795
+ )
796
+
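+             # process_interrogate receives the active tab index through the
+             # get_img2img_tab_index JS helper, which replaces the leading dummy.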
797
+ interrogate_args = dict(
798
+ _js="get_img2img_tab_index",
799
+ inputs=[
800
+ dummy_component,
801
+ img2img_batch_input_dir,
802
+ img2img_batch_output_dir,
803
+ init_img,
804
+ sketch,
805
+ init_img_with_mask,
806
+ inpaint_color_sketch,
807
+ init_img_inpaint,
808
+ ],
809
+ outputs=[toprow.prompt, dummy_component],
810
+ )
811
+
812
+ toprow.prompt.submit(**img2img_args)
813
+ toprow.submit.click(**img2img_args)
814
+
815
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
816
+
817
+ detect_image_size_btn.click(
818
+ fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
819
+ _js="currentImg2imgSourceResolution",
820
+ inputs=[dummy_component, dummy_component, dummy_component],
821
+ outputs=[width, height],
822
+ show_progress=False,
823
+ )
824
+
825
+ toprow.restore_progress_button.click(
826
+ fn=progress.restore_progress,
827
+ _js="restoreProgressImg2img",
828
+ inputs=[dummy_component],
829
+ outputs=[
830
+ output_panel.gallery,
831
+ output_panel.generation_info,
832
+ output_panel.infotext,
833
+ output_panel.html_log,
834
+ ],
835
+ show_progress=False,
836
+ )
837
+
838
+ toprow.button_interrogate.click(
839
+ fn=lambda *args: process_interrogate(interrogate, *args),
840
+ **interrogate_args,
841
+ )
842
+
843
+ toprow.button_deepbooru.click(
844
+ fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
845
+ **interrogate_args,
846
+ )
847
+
848
+ steps = scripts.scripts_img2img.script('Sampler').steps
849
+
850
+ toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
851
+ toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
852
+ toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
853
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
854
+
855
+ img2img_paste_fields = [
856
+ (toprow.prompt, "Prompt"),
857
+ (toprow.negative_prompt, "Negative prompt"),
858
+ (cfg_scale, "CFG scale"),
859
+ (image_cfg_scale, "Image CFG scale"),
860
+ (width, "Size-1"),
861
+ (height, "Size-2"),
862
+ (batch_size, "Batch size"),
863
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
864
+ (denoising_strength, "Denoising strength"),
865
+ (mask_blur, "Mask blur"),
866
+ (inpainting_mask_invert, 'Mask mode'),
867
+ (inpainting_fill, 'Masked content'),
868
+ (inpaint_full_res, 'Inpaint area'),
869
+ (inpaint_full_res_padding, 'Masked area padding'),
870
+ *scripts.scripts_img2img.infotext_fields
871
+ ]
872
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
873
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
874
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
875
+ paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
876
+ ))
877
+
878
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
879
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)
880
+
881
+ extra_tabs.__exit__()
882
+
883
+ scripts.scripts_current = None
884
+
885
+ with gr.Blocks(analytics_enabled=False) as extras_interface:
886
+ ui_postprocessing.create_ui()
887
+
888
+ with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
889
+ with ResizeHandleRow(equal_height=False):
890
+ with gr.Column(variant='panel'):
891
+ image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
892
+
893
+ with gr.Column(variant='panel'):
894
+ html = gr.HTML()
895
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
896
+ html2 = gr.HTML()
897
+ with gr.Row():
898
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
899
+
900
+ for tabname, button in buttons.items():
901
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
902
+ paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
903
+ ))
904
+
905
+ image.change(
906
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
907
+ inputs=[image],
908
+ outputs=[html, generation_info, html2],
909
+ )
910
+
911
+ modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
912
+
913
+ with gr.Blocks(analytics_enabled=False) as train_interface:
914
+ with gr.Row(equal_height=False):
915
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
916
+
917
+ with ResizeHandleRow(variant="compact", equal_height=False):
918
+ with gr.Tabs(elem_id="train_tabs"):
919
+
920
+ with gr.Tab(label="Create embedding", id="create_embedding"):
921
+ new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
922
+ initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
923
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
924
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
925
+
926
+ with gr.Row():
927
+ with gr.Column(scale=3):
928
+ gr.HTML(value="")
929
+
930
+ with gr.Column():
931
+ create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
932
+
933
+ with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
934
+ new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
935
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
936
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
937
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
938
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
939
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
940
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
941
+ new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
942
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
943
+
944
+ with gr.Row():
945
+ with gr.Column(scale=3):
946
+ gr.HTML(value="")
947
+
948
+ with gr.Column():
949
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
950
+
951
+ def get_textual_inversion_template_names():
952
+ return sorted(textual_inversion.textual_inversion_templates)
953
+
954
+ with gr.Tab(label="Train", id="train"):
955
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
956
+ with FormRow():
957
+ train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
958
+ create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
959
+
960
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
961
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
962
+
963
+ with FormRow():
964
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
965
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
966
+
967
+ with FormRow():
968
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
969
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
970
+
971
+ with FormRow():
972
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
973
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
974
+
975
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
976
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
977
+
978
+ with FormRow():
979
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
980
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
981
+
982
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
983
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
984
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
985
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
986
+
987
+ with FormRow():
988
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
989
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
990
+
991
+ use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
992
+
993
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
994
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
995
+
996
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
997
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
998
+
999
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
1000
+
1001
+ with gr.Row():
1002
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
1003
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
1004
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
1005
+
1006
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
1007
+
1008
+ script_callbacks.ui_train_tabs_callback(params)
1009
+
1010
+ with gr.Column(elem_id='ti_gallery_container'):
1011
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
1012
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
1013
+ gr.HTML(elem_id="ti_progress", value="")
1014
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
1015
+
1016
+         create_embedding.click(
+             fn=textual_inversion_ui.create_embedding,
+             inputs=[
+                 new_embedding_name,
+                 initialization_text,
+                 nvpt,
+                 overwrite_old_embedding,
+             ],
+             outputs=[
+                 train_embedding_name,
+                 ti_output,
+                 ti_outcome,
+             ]
+         )
+
+         create_hypernetwork.click(
+             fn=hypernetworks_ui.create_hypernetwork,
+             inputs=[
+                 new_hypernetwork_name,
+                 new_hypernetwork_sizes,
+                 overwrite_old_hypernetwork,
+                 new_hypernetwork_layer_structure,
+                 new_hypernetwork_activation_func,
+                 new_hypernetwork_initialization_option,
+                 new_hypernetwork_add_layer_norm,
+                 new_hypernetwork_use_dropout,
+                 new_hypernetwork_dropout_structure
+             ],
+             outputs=[
+                 train_hypernetwork_name,
+                 ti_output,
+                 ti_outcome,
+             ]
+         )
+
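+         # Training runs through wrap_gradio_gpu_call, which queues the job on the GPU worker and adds progress/error
+         # reporting; the _js hook ("start_training_textual_inversion") runs client-side first, presumably to switch
+         # the page into progress-polling mode.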
+         train_embedding.click(
+             fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
+             _js="start_training_textual_inversion",
+             inputs=[
+                 dummy_component,
+                 train_embedding_name,
+                 embedding_learn_rate,
+                 batch_size,
+                 gradient_step,
+                 dataset_directory,
+                 log_directory,
+                 training_width,
+                 training_height,
+                 varsize,
+                 steps,
+                 clip_grad_mode,
+                 clip_grad_value,
+                 shuffle_tags,
+                 tag_drop_out,
+                 latent_sampling_method,
+                 use_weight,
+                 create_image_every,
+                 save_embedding_every,
+                 template_file,
+                 save_image_with_stored_embedding,
+                 preview_from_txt2img,
+                 *txt2img_preview_params,
+             ],
+             outputs=[
+                 ti_output,
+                 ti_outcome,
+             ]
+         )
+
+         train_hypernetwork.click(
+             fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
+             _js="start_training_textual_inversion",
+             inputs=[
+                 dummy_component,
+                 train_hypernetwork_name,
+                 hypernetwork_learn_rate,
+                 batch_size,
+                 gradient_step,
+                 dataset_directory,
+                 log_directory,
+                 training_width,
+                 training_height,
+                 varsize,
+                 steps,
+                 clip_grad_mode,
+                 clip_grad_value,
+                 shuffle_tags,
+                 tag_drop_out,
+                 latent_sampling_method,
+                 use_weight,
+                 create_image_every,
+                 save_embedding_every,
+                 template_file,
+                 preview_from_txt2img,
+                 *txt2img_preview_params,
+             ],
+             outputs=[
+                 ti_output,
+                 ti_outcome,
+             ]
+         )
+
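+         # Interruption is cooperative: shared.state.interrupt() sets a flag that the training loop polls between steps.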
+         interrupt_training.click(
+             fn=lambda: shared.state.interrupt(),
+             inputs=[],
+             outputs=[],
+         )
+
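+     # UiLoadsave persists per-component UI defaults; the snapshot taken below is compared against the final state
+     # near the end of this function, and any drift is written back via dump_defaults().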
+     loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+     ui_settings_from_file = loadsave.ui_settings.copy()
+
+     settings.create_ui(loadsave, dummy_component)
+
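+     # Built-in tabs; extensions may register more via the ui_tabs callback, with Settings and Extensions appended last.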
+     interfaces = [
+         (txt2img_interface, "txt2img", "txt2img"),
+         (img2img_interface, "img2img", "img2img"),
+         (extras_interface, "Extras", "extras"),
+         (pnginfo_interface, "PNG Info", "pnginfo"),
+         (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
+         (train_interface, "Train", "train"),
+     ]
+
+     interfaces += script_callbacks.ui_tabs_callback()
+     interfaces += [(settings.interface, "Settings", "settings")]
+
+     extensions_interface = ui_extensions.create_ui()
+     interfaces += [(extensions_interface, "Extensions", "extensions")]
+
+     shared.tab_names = []
+     for _interface, label, _ifid in interfaces:
+         shared.tab_names.append(label)
+
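+     # Assemble the top-level app: quicksettings bar, the tab strip (ordered by opts.ui_tab_order, minus
+     # opts.hidden_tabs), notification audio, and the footer.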
+     with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
+         settings.add_quicksettings()
+
+         parameters_copypaste.connect_paste_params_buttons()
+
+         with gr.Tabs(elem_id="tabs") as tabs:
+             tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
+             sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))
+
+             for interface, label, ifid in sorted_interfaces:
+                 if label in shared.opts.hidden_tabs:
+                     continue
+                 with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
+                     interface.render()
+
+                 if ifid not in ["extensions", "settings"]:
+                     loadsave.add_block(interface, ifid)
+
+             loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+
+         loadsave.setup_ui()
+
+         if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
+             gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+
+         footer = shared.html("footer.html")
+         footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API")
+         gr.HTML(footer, elem_id="footer")
+
+         settings.add_functionality(demo)
+
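+         # image_cfg_scale is only meaningful for instruct-pix2pix style checkpoints (cond_stage_key == "edit"),
+         # so its visibility is re-evaluated whenever settings change and once on page load.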
+         update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+         settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+         demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+
+         modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
+
+     if ui_settings_from_file != loadsave.ui_settings:
+         loadsave.dump_defaults()
+     demo.ui_loadsave = loadsave
+
+     return demo
+
+
+ def versions_html():
+     import torch
+     import launch
+
+     python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+     commit = launch.commit_hash()
+     tag = launch.git_tag()
+
+     if shared.xformers_available:
+         import xformers
+         xformers_version = xformers.__version__
+     else:
+         xformers_version = "N/A"
+
+     return f"""
+ version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
+ &#x2000;•&#x2000;
+ python: <span title="{sys.version}">{python_version}</span>
+ &#x2000;•&#x2000;
+ torch: {getattr(torch, '__long_version__', torch.__version__)}
+ &#x2000;•&#x2000;
+ xformers: {xformers_version}
+ &#x2000;•&#x2000;
+ gradio: {gr.__version__}
+ &#x2000;•&#x2000;
+ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
+ """
+
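+ # Internal endpoints consumed by the frontend itself (not part of the public txt2img/img2img API): quicksettings
+ # hints, a liveness ping, startup profiling data, and a sysinfo dump served inline or as a download.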
+
+ def setup_ui_api(app):
+     from pydantic import BaseModel, Field
+
+     class QuicksettingsHint(BaseModel):
+         name: str = Field(title="Name of the quicksettings field")
+         label: str = Field(title="Label of the quicksettings field")
+
+     def quicksettings_hint():
+         return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
+
+     app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
+
+     app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
+
+     app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])
+
+     def download_sysinfo(attachment=False):
+         from fastapi.responses import PlainTextResponse
+
+         text = sysinfo.get()
+         filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
+
+         return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
+
+     app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
+     app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
+
+     import fastapi.staticfiles
+     app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets")
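+     # Example (hypothetical client usage; assumes a local server on the default port 7860):
+     #   import requests
+     #   requests.get("http://127.0.0.1:7860/internal/ping").json()      # -> {}
+     #   requests.get("http://127.0.0.1:7860/internal/sysinfo-download") # served with Content-Disposition: attachment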