qpqpqpqpqpqp commited on
Commit
1df0aba
·
verified ·
1 Parent(s): 51c7f4a

Delete hm/.ipynb_checkpoints/processing-checkpoint.py

Browse files
hm/.ipynb_checkpoints/processing-checkpoint.py DELETED
@@ -1,1838 +0,0 @@
1
- from __future__ import annotations
2
- import json
3
- import logging
4
- import math
5
- import os
6
- import sys
7
- import hashlib
8
- from dataclasses import dataclass, field
9
-
10
- import torch
11
- import numpy as np
12
- from PIL import Image, ImageOps
13
- import random
14
- import cv2
15
- from skimage import exposure
16
- from typing import Any
17
-
18
- import modules.sd_hijack
19
- from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
20
- from modules.rng import slerp # noqa: F401
21
- from modules.sd_hijack import model_hijack
22
- from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
23
- from modules.shared import opts, cmd_opts, state
24
- import modules.shared as shared
25
- import modules.paths as paths
26
- import modules.face_restoration
27
- import modules.images as images
28
- import modules.styles
29
- import modules.sd_models as sd_models
30
- import modules.sd_vae as sd_vae
31
- from ldm.data.util import AddMiDaS
32
- from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
33
-
34
- from einops import repeat, rearrange
35
- from blendmodes.blend import blendLayers, BlendType
36
-
37
-
38
- # some of those options should not be changed at all because they would break the model, so I removed them from options.
39
- opt_C = 4
40
- opt_f = 8
41
-
42
-
43
- def setup_color_correction(image):
44
- logging.info("Calibrating color correction.")
45
- correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
46
- return correction_target
47
-
48
-
49
- def apply_color_correction(correction, original_image):
50
- logging.info("Applying color correction.")
51
- image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
52
- cv2.cvtColor(
53
- np.asarray(original_image),
54
- cv2.COLOR_RGB2LAB
55
- ),
56
- correction,
57
- channel_axis=2
58
- ), cv2.COLOR_LAB2RGB).astype("uint8"))
59
-
60
- image = blendLayers(image, original_image, BlendType.LUMINOSITY)
61
-
62
- return image.convert('RGB')
63
-
64
-
65
- def uncrop(image, dest_size, paste_loc):
66
- x, y, w, h = paste_loc
67
- base_image = Image.new('RGBA', dest_size)
68
- image = images.resize_image(1, image, w, h)
69
- base_image.paste(image, (x, y))
70
- image = base_image
71
-
72
- return image
73
-
74
-
75
- def apply_overlay(image, paste_loc, overlay):
76
- if overlay is None:
77
- return image, image.copy()
78
-
79
- if paste_loc is not None:
80
- image = uncrop(image, (overlay.width, overlay.height), paste_loc)
81
-
82
- original_denoised_image = image.copy()
83
-
84
- image = image.convert('RGBA')
85
- image.alpha_composite(overlay)
86
- image = image.convert('RGB')
87
-
88
- return image, original_denoised_image
89
-
90
- def create_binary_mask(image, round=True):
91
- if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
92
- if round:
93
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
94
- else:
95
- image = image.split()[-1].convert("L")
96
- else:
97
- image = image.convert('L')
98
- return image
99
-
100
- def txt2img_image_conditioning(sd_model, x, width, height):
101
- if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
102
-
103
- # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
104
- image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
105
- image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
106
-
107
- # Add the fake full 1s mask to the first dimension.
108
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
109
- image_conditioning = image_conditioning.to(x.dtype)
110
-
111
- return image_conditioning
112
-
113
- elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
114
-
115
- return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
116
-
117
- else:
118
- if sd_model.is_sdxl_inpaint:
119
- # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
120
- image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
121
- image_conditioning = images_tensor_to_samples(image_conditioning,
122
- approximation_indexes.get(opts.sd_vae_encode_method))
123
-
124
- # Add the fake full 1s mask to the first dimension.
125
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
126
- image_conditioning = image_conditioning.to(x.dtype)
127
-
128
- return image_conditioning
129
-
130
- # Dummy zero conditioning if we're not using inpainting or unclip models.
131
- # Still takes up a bit of memory, but no encoder call.
132
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
133
- return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
134
-
135
-
136
- @dataclass(repr=False)
137
- class StableDiffusionProcessing:
138
- sd_model: object = None
139
- outpath_samples: str = None
140
- outpath_grids: str = None
141
- prompt: str = ""
142
- prompt_for_display: str = None
143
- negative_prompt: str = ""
144
- styles: list[str] = None
145
- seed: int = -1
146
- subseed: int = -1
147
- subseed_strength: float = 0
148
- seed_resize_from_h: int = -1
149
- seed_resize_from_w: int = -1
150
- seed_enable_extras: bool = True
151
- sampler_name: str = None
152
- scheduler: str = None
153
- batch_size: int = 1
154
- n_iter: int = 1
155
- steps: int = 50
156
- cfg_scale: float = 7.0
157
- width: int = 512
158
- height: int = 512
159
- restore_faces: bool = None
160
- tiling: bool = None
161
- do_not_save_samples: bool = False
162
- do_not_save_grid: bool = False
163
- extra_generation_params: dict[str, Any] = None
164
- overlay_images: list = None
165
- eta: float = None
166
- do_not_reload_embeddings: bool = False
167
- denoising_strength: float = None
168
- ddim_discretize: str = None
169
- s_min_uncond: float = None
170
- s_churn: float = None
171
- s_tmax: float = None
172
- s_tmin: float = None
173
- s_noise: float = None
174
- override_settings: dict[str, Any] = None
175
- override_settings_restore_afterwards: bool = True
176
- sampler_index: int = None
177
- refiner_checkpoint: str = None
178
- refiner_switch_at: float = None
179
- token_merging_ratio = 0
180
- token_merging_ratio_hr = 0
181
- disable_extra_networks: bool = False
182
- firstpass_image: Image = None
183
-
184
- scripts_value: scripts.ScriptRunner = field(default=None, init=False)
185
- script_args_value: list = field(default=None, init=False)
186
- scripts_setup_complete: bool = field(default=False, init=False)
187
-
188
- cached_uc = [None, None]
189
- cached_c = [None, None]
190
-
191
- comments: dict = None
192
- sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
193
- is_using_inpainting_conditioning: bool = field(default=False, init=False)
194
- paste_to: tuple | None = field(default=None, init=False)
195
-
196
- is_hr_pass: bool = field(default=False, init=False)
197
-
198
- c: tuple = field(default=None, init=False)
199
- uc: tuple = field(default=None, init=False)
200
-
201
- rng: rng.ImageRNG | None = field(default=None, init=False)
202
- step_multiplier: int = field(default=1, init=False)
203
- color_corrections: list = field(default=None, init=False)
204
-
205
- all_prompts: list = field(default=None, init=False)
206
- all_negative_prompts: list = field(default=None, init=False)
207
- all_seeds: list = field(default=None, init=False)
208
- all_subseeds: list = field(default=None, init=False)
209
- iteration: int = field(default=0, init=False)
210
- main_prompt: str = field(default=None, init=False)
211
- main_negative_prompt: str = field(default=None, init=False)
212
-
213
- prompts: list = field(default=None, init=False)
214
- negative_prompts: list = field(default=None, init=False)
215
- seeds: list = field(default=None, init=False)
216
- subseeds: list = field(default=None, init=False)
217
- extra_network_data: dict = field(default=None, init=False)
218
-
219
- user: str = field(default=None, init=False)
220
-
221
- sd_model_name: str = field(default=None, init=False)
222
- sd_model_hash: str = field(default=None, init=False)
223
- sd_vae_name: str = field(default=None, init=False)
224
- sd_vae_hash: str = field(default=None, init=False)
225
-
226
- is_api: bool = field(default=False, init=False)
227
-
228
- def __post_init__(self):
229
- if self.sampler_index is not None:
230
- print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
231
-
232
- self.comments = {}
233
-
234
- if self.styles is None:
235
- self.styles = []
236
-
237
- self.sampler_noise_scheduler_override = None
238
-
239
- self.extra_generation_params = self.extra_generation_params or {}
240
- self.override_settings = self.override_settings or {}
241
- self.script_args = self.script_args or {}
242
-
243
- self.refiner_checkpoint_info = None
244
-
245
- if not self.seed_enable_extras:
246
- self.subseed = -1
247
- self.subseed_strength = 0
248
- self.seed_resize_from_h = 0
249
- self.seed_resize_from_w = 0
250
-
251
- self.cached_uc = StableDiffusionProcessing.cached_uc
252
- self.cached_c = StableDiffusionProcessing.cached_c
253
-
254
- def fill_fields_from_opts(self):
255
- self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
256
- self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
257
- self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
258
- self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
259
- self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
260
-
261
- @property
262
- def sd_model(self):
263
- return shared.sd_model
264
-
265
- @sd_model.setter
266
- def sd_model(self, value):
267
- pass
268
-
269
- @property
270
- def scripts(self):
271
- return self.scripts_value
272
-
273
- @scripts.setter
274
- def scripts(self, value):
275
- self.scripts_value = value
276
-
277
- if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
278
- self.setup_scripts()
279
-
280
- @property
281
- def script_args(self):
282
- return self.script_args_value
283
-
284
- @script_args.setter
285
- def script_args(self, value):
286
- self.script_args_value = value
287
-
288
- if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
289
- self.setup_scripts()
290
-
291
- def setup_scripts(self):
292
- self.scripts_setup_complete = True
293
-
294
- self.scripts.setup_scrips(self, is_ui=not self.is_api)
295
-
296
- def comment(self, text):
297
- self.comments[text] = 1
298
-
299
- def txt2img_image_conditioning(self, x, width=None, height=None):
300
- self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
301
-
302
- return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
303
-
304
- def depth2img_image_conditioning(self, source_image):
305
- # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
306
- transformer = AddMiDaS(model_type="dpt_hybrid")
307
- transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
308
- midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
309
- midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
310
-
311
- conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
312
- conditioning = torch.nn.functional.interpolate(
313
- self.sd_model.depth_model(midas_in),
314
- size=conditioning_image.shape[2:],
315
- mode="bicubic",
316
- align_corners=False,
317
- )
318
-
319
- (depth_min, depth_max) = torch.aminmax(conditioning)
320
- conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
321
- return conditioning
322
-
323
- def edit_image_conditioning(self, source_image):
324
- conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
325
-
326
- return conditioning_image
327
-
328
- def unclip_image_conditioning(self, source_image):
329
- c_adm = self.sd_model.embedder(source_image)
330
- if self.sd_model.noise_augmentor is not None:
331
- noise_level = 0 # TODO: Allow other noise levels?
332
- c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
333
- c_adm = torch.cat((c_adm, noise_level_emb), 1)
334
- return c_adm
335
-
336
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
337
- self.is_using_inpainting_conditioning = True
338
-
339
- # Handle the different mask inputs
340
- if image_mask is not None:
341
- if torch.is_tensor(image_mask):
342
- conditioning_mask = image_mask
343
- else:
344
- conditioning_mask = np.array(image_mask.convert("L"))
345
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
346
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
347
-
348
- if round_image_mask:
349
- # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
350
- conditioning_mask = torch.round(conditioning_mask)
351
-
352
- else:
353
- conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
354
-
355
- # Create another latent image, this time with a masked version of the original input.
356
- # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
357
- conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
358
- conditioning_image = torch.lerp(
359
- source_image,
360
- source_image * (1.0 - conditioning_mask),
361
- getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
362
- )
363
-
364
- # Encode the new masked image using first stage of network.
365
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
366
-
367
- # Create the concatenated conditioning tensor to be fed to `c_concat`
368
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
369
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
370
- image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
371
- image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
372
-
373
- return image_conditioning
374
-
375
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
376
- source_image = devices.cond_cast_float(source_image)
377
-
378
- # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
379
- # identify itself with a field common to all models. The conditioning_key is also hybrid.
380
- if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
381
- return self.depth2img_image_conditioning(source_image)
382
-
383
- if self.sd_model.cond_stage_key == "edit":
384
- return self.edit_image_conditioning(source_image)
385
-
386
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
387
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
388
-
389
- if self.sampler.conditioning_key == "crossattn-adm":
390
- return self.unclip_image_conditioning(source_image)
391
-
392
- if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
393
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
394
-
395
- # Dummy zero conditioning if we're not using inpainting or depth model.
396
- return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
397
-
398
- def init(self, all_prompts, all_seeds, all_subseeds):
399
- pass
400
-
401
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
402
- raise NotImplementedError()
403
-
404
- def close(self):
405
- self.sampler = None
406
- self.c = None
407
- self.uc = None
408
- if not opts.persistent_cond_cache:
409
- StableDiffusionProcessing.cached_c = [None, None]
410
- StableDiffusionProcessing.cached_uc = [None, None]
411
-
412
- def get_token_merging_ratio(self, for_hr=False):
413
- if for_hr:
414
- return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
415
-
416
- return self.token_merging_ratio or opts.token_merging_ratio
417
-
418
- def setup_prompts(self):
419
- if isinstance(self.prompt,list):
420
- self.all_prompts = self.prompt
421
- elif isinstance(self.negative_prompt, list):
422
- self.all_prompts = [self.prompt] * len(self.negative_prompt)
423
- else:
424
- self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
425
-
426
- if isinstance(self.negative_prompt, list):
427
- self.all_negative_prompts = self.negative_prompt
428
- else:
429
- self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)
430
-
431
- if len(self.all_prompts) != len(self.all_negative_prompts):
432
- raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")
433
-
434
- self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
435
- self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
436
-
437
- self.main_prompt = self.all_prompts[0]
438
- self.main_negative_prompt = self.all_negative_prompts[0]
439
-
440
- def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
441
- """Returns parameters that invalidate the cond cache if changed"""
442
-
443
- return (
444
- required_prompts,
445
- steps,
446
- hires_steps,
447
- use_old_scheduling,
448
- opts.CLIP_stop_at_last_layers,
449
- shared.sd_model.sd_checkpoint_info,
450
- extra_network_data,
451
- opts.sdxl_crop_left,
452
- opts.sdxl_crop_top,
453
- self.width,
454
- self.height,
455
- opts.fp8_storage,
456
- opts.cache_fp16_weight,
457
- opts.emphasis,
458
- )
459
-
460
- def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
461
- """
462
- Returns the result of calling function(shared.sd_model, required_prompts, steps)
463
- using a cache to store the result if the same arguments have been used before.
464
-
465
- cache is an array containing two elements. The first element is a tuple
466
- representing the previously used arguments, or None if no arguments
467
- have been used before. The second element is where the previously
468
- computed result is stored.
469
-
470
- caches is a list with items described above.
471
- """
472
-
473
- if shared.opts.use_old_scheduling:
474
- old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
475
- new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
476
- if old_schedules != new_schedules:
477
- self.extra_generation_params["Old prompt editing timelines"] = True
478
-
479
- cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
480
-
481
- for cache in caches:
482
- if cache[0] is not None and cached_params == cache[0]:
483
- return cache[1]
484
-
485
- cache = caches[0]
486
-
487
- with devices.autocast():
488
- cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
489
-
490
- cache[0] = cached_params
491
- return cache[1]
492
-
493
- def setup_conds(self):
494
- prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
495
- negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
496
-
497
- sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
498
- total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
499
- self.step_multiplier = total_steps // self.steps
500
- self.firstpass_steps = total_steps
501
-
502
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
503
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
504
-
505
- def get_conds(self):
506
- return self.c, self.uc
507
-
508
- def parse_extra_network_prompts(self):
509
- self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
510
-
511
- def save_samples(self) -> bool:
512
- """Returns whether generated images need to be written to disk"""
513
- return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
514
-
515
-
516
- class Processed:
517
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
518
- self.images = images_list
519
- self.prompt = p.prompt
520
- self.negative_prompt = p.negative_prompt
521
- self.seed = seed
522
- self.subseed = subseed
523
- self.subseed_strength = p.subseed_strength
524
- self.info = info
525
- self.comments = "".join(f"{comment}\n" for comment in p.comments)
526
- self.width = p.width
527
- self.height = p.height
528
- self.sampler_name = p.sampler_name
529
- self.cfg_scale = p.cfg_scale
530
- self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
531
- self.steps = p.steps
532
- self.batch_size = p.batch_size
533
- self.restore_faces = p.restore_faces
534
- self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
535
- self.sd_model_name = p.sd_model_name
536
- self.sd_model_hash = p.sd_model_hash
537
- self.sd_vae_name = p.sd_vae_name
538
- self.sd_vae_hash = p.sd_vae_hash
539
- self.seed_resize_from_w = p.seed_resize_from_w
540
- self.seed_resize_from_h = p.seed_resize_from_h
541
- self.denoising_strength = getattr(p, 'denoising_strength', None)
542
- self.extra_generation_params = p.extra_generation_params
543
- self.index_of_first_image = index_of_first_image
544
- self.styles = p.styles
545
- self.job_timestamp = state.job_timestamp
546
- self.clip_skip = opts.CLIP_stop_at_last_layers
547
- self.token_merging_ratio = p.token_merging_ratio
548
- self.token_merging_ratio_hr = p.token_merging_ratio_hr
549
-
550
- self.eta = p.eta
551
- self.ddim_discretize = p.ddim_discretize
552
- self.s_churn = p.s_churn
553
- self.s_tmin = p.s_tmin
554
- self.s_tmax = p.s_tmax
555
- self.s_noise = p.s_noise
556
- self.s_min_uncond = p.s_min_uncond
557
- self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
558
- self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
559
- self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
560
- self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
561
- self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
562
- self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
563
-
564
- self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
565
- self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
566
- self.all_seeds = all_seeds or p.all_seeds or [self.seed]
567
- self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
568
- self.infotexts = infotexts or [info] * len(images_list)
569
- self.version = program_version()
570
-
571
- def js(self):
572
- obj = {
573
- "prompt": self.all_prompts[0],
574
- "all_prompts": self.all_prompts,
575
- "negative_prompt": self.all_negative_prompts[0],
576
- "all_negative_prompts": self.all_negative_prompts,
577
- "seed": self.seed,
578
- "all_seeds": self.all_seeds,
579
- "subseed": self.subseed,
580
- "all_subseeds": self.all_subseeds,
581
- "subseed_strength": self.subseed_strength,
582
- "width": self.width,
583
- "height": self.height,
584
- "sampler_name": self.sampler_name,
585
- "cfg_scale": self.cfg_scale,
586
- "steps": self.steps,
587
- "batch_size": self.batch_size,
588
- "restore_faces": self.restore_faces,
589
- "face_restoration_model": self.face_restoration_model,
590
- "sd_model_name": self.sd_model_name,
591
- "sd_model_hash": self.sd_model_hash,
592
- "sd_vae_name": self.sd_vae_name,
593
- "sd_vae_hash": self.sd_vae_hash,
594
- "seed_resize_from_w": self.seed_resize_from_w,
595
- "seed_resize_from_h": self.seed_resize_from_h,
596
- "denoising_strength": self.denoising_strength,
597
- "extra_generation_params": self.extra_generation_params,
598
- "index_of_first_image": self.index_of_first_image,
599
- "infotexts": self.infotexts,
600
- "styles": self.styles,
601
- "job_timestamp": self.job_timestamp,
602
- "clip_skip": self.clip_skip,
603
- "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
604
- "version": self.version,
605
- }
606
-
607
- return json.dumps(obj, default=lambda o: None)
608
-
609
- def infotext(self, p: StableDiffusionProcessing, index):
610
- return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
611
-
612
- def get_token_merging_ratio(self, for_hr=False):
613
- return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
614
-
615
-
616
- def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
617
- g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
618
- return g.next()
619
-
620
-
621
- class DecodedSamples(list):
622
- already_decoded = True
623
-
624
-
625
- def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
626
- samples = DecodedSamples()
627
-
628
- if check_for_nans:
629
- devices.test_for_nans(batch, "unet")
630
-
631
- for i in range(batch.shape[0]):
632
- sample = decode_first_stage(model, batch[i:i + 1])[0]
633
-
634
- if check_for_nans:
635
-
636
- try:
637
- devices.test_for_nans(sample, "vae")
638
- except devices.NansException as e:
639
- if shared.opts.auto_vae_precision_bfloat16:
640
- autofix_dtype = torch.bfloat16
641
- autofix_dtype_text = "bfloat16"
642
- autofix_dtype_setting = "Automatically convert VAE to bfloat16"
643
- autofix_dtype_comment = ""
644
- elif shared.opts.auto_vae_precision:
645
- autofix_dtype = torch.float32
646
- autofix_dtype_text = "32-bit float"
647
- autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
648
- autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
649
- else:
650
- raise e
651
-
652
- if devices.dtype_vae == autofix_dtype:
653
- raise e
654
-
655
- errors.print_error_explanation(
656
- "A tensor with all NaNs was produced in VAE.\n"
657
- f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
658
- f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
659
- )
660
-
661
- devices.dtype_vae = autofix_dtype
662
- model.first_stage_model.to(devices.dtype_vae)
663
- batch = batch.to(devices.dtype_vae)
664
-
665
- sample = decode_first_stage(model, batch[i:i + 1])[0]
666
-
667
- if target_device is not None:
668
- sample = sample.to(target_device)
669
-
670
- samples.append(sample)
671
-
672
- return samples
673
-
674
-
675
- def get_fixed_seed(seed):
676
- if seed == '' or seed is None:
677
- seed = -1
678
- elif isinstance(seed, str):
679
- try:
680
- seed = int(seed)
681
- except Exception:
682
- seed = -1
683
-
684
- if seed == -1:
685
- return int(random.randrange(4294967294))
686
-
687
- return seed
688
-
689
-
690
- def fix_seed(p):
691
- p.seed = get_fixed_seed(p.seed)
692
- p.subseed = get_fixed_seed(p.subseed)
693
-
694
-
695
- def program_version():
696
- import launch
697
-
698
- res = launch.git_tag()
699
- if res == "<none>":
700
- res = None
701
-
702
- return res
703
-
704
-
705
- def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
706
- """
707
- this function is used to generate the infotext that is stored in the generated images, it's contains the parameters that are required to generate the imagee
708
- Args:
709
- p: StableDiffusionProcessing
710
- all_prompts: list[str]
711
- all_seeds: list[int]
712
- all_subseeds: list[int]
713
- comments: list[str]
714
- iteration: int
715
- position_in_batch: int
716
- use_main_prompt: bool
717
- index: int
718
- all_negative_prompts: list[str]
719
-
720
- Returns: str
721
-
722
- Extra generation params
723
- p.extra_generation_params dictionary allows for additional parameters to be added to the infotext
724
- this can be use by the base webui or extensions.
725
- To add a new entry, add a new key value pair, the dictionary key will be used as the key of the parameter in the infotext
726
- the value generation_params can be defined as:
727
- - str | None
728
- - List[str|None]
729
- - callable func(**kwargs) -> str | None
730
-
731
- When defined as a string, it will be used as without extra processing; this is this most common use case.
732
-
733
- Defining as a list allows for parameter that changes across images in the job, for example, the 'Seed' parameter.
734
- The list should have the same length as the total number of images in the entire job.
735
-
736
- Defining as a callable function allows parameter cannot be generated earlier or when extra logic is required.
737
- For example 'Hires prompt', due to reasons the hr_prompt might be changed by process in the pipeline or extensions
738
- and may vary across different images, defining as a static string or list would not work.
739
-
740
- The function takes locals() as **kwargs, as such will have access to variables like 'p' and 'index'.
741
- the base signature of the function should be:
742
- func(**kwargs) -> str | None
743
- optionally it can have additional arguments that will be used in the function:
744
- func(p, index, **kwargs) -> str | None
745
- note: for better future compatibility even though this function will have access to all variables in the locals(),
746
- it is recommended to only use the arguments present in the function signature of create_infotext.
747
- For actual implementation examples, see StableDiffusionProcessingTxt2Img.init > get_hr_prompt.
748
- """
749
-
750
- if use_main_prompt:
751
- index = 0
752
- elif index is None:
753
- index = position_in_batch + iteration * p.batch_size
754
-
755
- if all_negative_prompts is None:
756
- all_negative_prompts = p.all_negative_prompts
757
-
758
- clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
759
- enable_hr = getattr(p, 'enable_hr', False)
760
- token_merging_ratio = p.get_token_merging_ratio()
761
- token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
762
-
763
- prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
764
- negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]
765
-
766
- uses_ensd = opts.eta_noise_seed_delta != 0
767
- if uses_ensd:
768
- uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
769
-
770
- generation_params = {
771
- "Steps": p.steps,
772
- "Sampler": p.sampler_name,
773
- "Schedule type": p.scheduler,
774
- "CFG scale": p.cfg_scale,
775
- "Image CFG scale": getattr(p, 'image_cfg_scale', None),
776
- "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
777
- "Face restoration": opts.face_restoration_model if p.restore_faces else None,
778
- "Size": f"{p.width}x{p.height}",
779
- "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
780
- "Model": p.sd_model_name if opts.add_model_name_to_info else None,
781
- "FP8 weight": opts.fp8_storage if devices.fp8 else None,
782
- "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
783
- "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
784
- "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
785
- "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
786
- "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
787
- "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
788
- "Denoising strength": p.extra_generation_params.get("Denoising strength"),
789
- "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
790
- "Clip skip": None if clip_skip <= 1 else clip_skip,
791
- "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
792
- "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
793
- "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
794
- "Init image hash": getattr(p, 'init_img_hash', None),
795
- "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
796
- "Tiling": "True" if p.tiling else None,
797
- **p.extra_generation_params,
798
- "Version": program_version() if opts.add_version_to_infotext else None,
799
- "User": p.user if opts.add_user_name_to_info else None,
800
- }
801
-
802
- for key, value in generation_params.items():
803
- try:
804
- if isinstance(value, list):
805
- generation_params[key] = value[index]
806
- elif callable(value):
807
- generation_params[key] = value(**locals())
808
- except Exception:
809
- errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
810
- generation_params[key] = None
811
-
812
- generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
813
-
814
- negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""
815
-
816
- return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
817
-
818
-
819
- def process_images(p: StableDiffusionProcessing) -> Processed:
820
- if p.scripts is not None:
821
- p.scripts.before_process(p)
822
-
823
- stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
824
-
825
- try:
826
- # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
827
- # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards
828
- if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
829
- p.override_settings.pop('sd_model_checkpoint', None)
830
- sd_models.reload_model_weights()
831
-
832
- for k, v in p.override_settings.items():
833
- opts.set(k, v, is_api=True, run_callbacks=False)
834
-
835
- if k == 'sd_model_checkpoint':
836
- sd_models.reload_model_weights()
837
-
838
- if k == 'sd_vae':
839
- sd_vae.reload_vae_weights()
840
-
841
- sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
842
-
843
- # backwards compatibility, fix sampler and scheduler if invalid
844
- sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
845
-
846
- with profiling.Profiler():
847
- res = process_images_inner(p)
848
-
849
- finally:
850
- sd_models.apply_token_merging(p.sd_model, 0)
851
-
852
- # restore opts to original state
853
- if p.override_settings_restore_afterwards:
854
- for k, v in stored_opts.items():
855
- setattr(opts, k, v)
856
-
857
- if k == 'sd_vae':
858
- sd_vae.reload_vae_weights()
859
-
860
- return res
861
-
862
-
863
- def process_images_inner(p: StableDiffusionProcessing) -> Processed:
864
- """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
865
-
866
- if isinstance(p.prompt, list):
867
- assert(len(p.prompt) > 0)
868
- else:
869
- assert p.prompt is not None
870
-
871
- devices.torch_gc()
872
-
873
- seed = get_fixed_seed(p.seed)
874
- subseed = get_fixed_seed(p.subseed)
875
-
876
- if p.restore_faces is None:
877
- p.restore_faces = opts.face_restoration
878
-
879
- if p.tiling is None:
880
- p.tiling = opts.tiling
881
-
882
- if p.refiner_checkpoint not in (None, "", "None", "none"):
883
- p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
884
- if p.refiner_checkpoint_info is None:
885
- raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
886
-
887
- if hasattr(shared.sd_model, 'fix_dimensions'):
888
- p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)
889
-
890
- p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
891
- p.sd_model_hash = shared.sd_model.sd_model_hash
892
- p.sd_vae_name = sd_vae.get_loaded_vae_name()
893
- p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
894
-
895
- modules.sd_hijack.model_hijack.apply_circular(p.tiling)
896
- modules.sd_hijack.model_hijack.clear_comments()
897
-
898
- p.fill_fields_from_opts()
899
- p.setup_prompts()
900
-
901
- if isinstance(seed, list):
902
- p.all_seeds = seed
903
- else:
904
- p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
905
-
906
- if isinstance(subseed, list):
907
- p.all_subseeds = subseed
908
- else:
909
- p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
910
-
911
- if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
912
- model_hijack.embedding_db.load_textual_inversion_embeddings()
913
-
914
- if p.scripts is not None:
915
- p.scripts.process(p)
916
-
917
- infotexts = []
918
- output_images = []
919
- with torch.no_grad(), p.sd_model.ema_scope():
920
- with devices.autocast():
921
- p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
922
-
923
- # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
924
- if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
925
- sd_vae_approx.model()
926
-
927
- sd_unet.apply_unet()
928
-
929
- if state.job_count == -1:
930
- state.job_count = p.n_iter
931
-
932
- for n in range(p.n_iter):
933
- p.iteration = n
934
-
935
- if state.skipped:
936
- state.skipped = False
937
-
938
- if state.interrupted or state.stopping_generation:
939
- break
940
-
941
- sd_models.reload_model_weights() # model can be changed for example by refiner
942
-
943
- p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
944
- p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
945
- p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
946
- p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
947
-
948
- latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
949
- p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
950
-
951
- if p.scripts is not None:
952
- p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
953
-
954
- if len(p.prompts) == 0:
955
- break
956
-
957
- p.parse_extra_network_prompts()
958
-
959
- if not p.disable_extra_networks:
960
- with devices.autocast():
961
- extra_networks.activate(p, p.extra_network_data)
962
-
963
- if p.scripts is not None:
964
- p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
965
-
966
- p.setup_conds()
967
-
968
- p.extra_generation_params.update(model_hijack.extra_generation_params)
969
-
970
- # params.txt should be saved after scripts.process_batch, since the
971
- # infotext could be modified by that callback
972
- # Example: a wildcard processed by process_batch sets an extra model
973
- # strength, which is saved as "Model Strength: 1.0" in the infotext
974
- if n == 0 and not cmd_opts.no_prompt_history:
975
- with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
976
- processed = Processed(p, [])
977
- file.write(processed.infotext(p, 0))
978
-
979
- for comment in model_hijack.comments:
980
- p.comment(comment)
981
-
982
- if p.n_iter > 1:
983
- shared.state.job = f"Batch {n+1} out of {p.n_iter}"
984
-
985
- sd_models.apply_alpha_schedule_override(p.sd_model, p)
986
-
987
- with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
988
- samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
989
-
990
- if p.scripts is not None:
991
- ps = scripts.PostSampleArgs(samples_ddim)
992
- p.scripts.post_sample(p, ps)
993
- samples_ddim = ps.samples
994
-
995
- if getattr(samples_ddim, 'already_decoded', False):
996
- x_samples_ddim = samples_ddim
997
- else:
998
- devices.test_for_nans(samples_ddim, "unet")
999
-
1000
- if opts.sd_vae_decode_method != 'Full':
1001
- p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
1002
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
1003
-
1004
- x_samples_ddim = torch.stack(x_samples_ddim).float()
1005
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
1006
-
1007
- del samples_ddim
1008
-
1009
- if lowvram.is_enabled(shared.sd_model):
1010
- lowvram.send_everything_to_cpu()
1011
-
1012
- devices.torch_gc()
1013
-
1014
- state.nextjob()
1015
-
1016
- if p.scripts is not None:
1017
- p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
1018
-
1019
- p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
1020
- p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
1021
-
1022
- batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
1023
- p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
1024
- x_samples_ddim = batch_params.images
1025
-
1026
- def infotext(index=0, use_main_prompt=False):
1027
- return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
1028
-
1029
- save_samples = p.save_samples()
1030
-
1031
- for i, x_sample in enumerate(x_samples_ddim):
1032
- p.batch_index = i
1033
-
1034
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
1035
- x_sample = x_sample.astype(np.uint8)
1036
-
1037
- if p.restore_faces:
1038
- if save_samples and opts.save_images_before_face_restoration:
1039
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
1040
-
1041
- devices.torch_gc()
1042
-
1043
- x_sample = modules.face_restoration.restore_faces(x_sample)
1044
- devices.torch_gc()
1045
-
1046
- image = Image.fromarray(x_sample)
1047
-
1048
- if p.scripts is not None:
1049
- pp = scripts.PostprocessImageArgs(image)
1050
- p.scripts.postprocess_image(p, pp)
1051
- image = pp.image
1052
-
1053
- mask_for_overlay = getattr(p, "mask_for_overlay", None)
1054
-
1055
- if not shared.opts.overlay_inpaint:
1056
- overlay_image = None
1057
- elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
1058
- overlay_image = p.overlay_images[i]
1059
- else:
1060
- overlay_image = None
1061
-
1062
- if p.scripts is not None:
1063
- ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
1064
- p.scripts.postprocess_maskoverlay(p, ppmo)
1065
- mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
1066
-
1067
- if p.color_corrections is not None and i < len(p.color_corrections):
1068
- if save_samples and opts.save_images_before_color_correction:
1069
- image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
1070
- images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
1071
- image = apply_color_correction(p.color_corrections[i], image)
1072
-
1073
- # If the intention is to show the output from the model
1074
- # that is being composited over the original image,
1075
- # we need to keep the original image around
1076
- # and use it in the composite step.
1077
- image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)
1078
-
1079
- if p.scripts is not None:
1080
- pp = scripts.PostprocessImageArgs(image)
1081
- p.scripts.postprocess_image_after_composite(p, pp)
1082
- image = pp.image
1083
-
1084
- if save_samples:
1085
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
1086
-
1087
- text = infotext(i)
1088
- infotexts.append(text)
1089
- if opts.enable_pnginfo:
1090
- image.info["parameters"] = text
1091
- output_images.append(image)
1092
-
1093
- if mask_for_overlay is not None:
1094
- if opts.return_mask or opts.save_mask:
1095
- image_mask = mask_for_overlay.convert('RGB')
1096
- if save_samples and opts.save_mask:
1097
- images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
1098
- if opts.return_mask:
1099
- output_images.append(image_mask)
1100
-
1101
- if opts.return_mask_composite or opts.save_mask_composite:
1102
- image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
1103
- if save_samples and opts.save_mask_composite:
1104
- images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
1105
- if opts.return_mask_composite:
1106
- output_images.append(image_mask_composite)
1107
-
1108
- del x_samples_ddim
1109
-
1110
- devices.torch_gc()
1111
-
1112
- if not infotexts:
1113
- infotexts.append(Processed(p, []).infotext(p, 0))
1114
-
1115
- p.color_corrections = None
1116
-
1117
- index_of_first_image = 0
1118
- unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
1119
- if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
1120
- grid = images.image_grid(output_images, p.batch_size)
1121
-
1122
- if opts.return_grid:
1123
- text = infotext(use_main_prompt=True)
1124
- infotexts.insert(0, text)
1125
- if opts.enable_pnginfo:
1126
- grid.info["parameters"] = text
1127
- output_images.insert(0, grid)
1128
- index_of_first_image = 1
1129
- if opts.grid_save:
1130
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
1131
-
1132
- if not p.disable_extra_networks and p.extra_network_data:
1133
- extra_networks.deactivate(p, p.extra_network_data)
1134
-
1135
- devices.torch_gc()
1136
-
1137
- res = Processed(
1138
- p,
1139
- images_list=output_images,
1140
- seed=p.all_seeds[0],
1141
- info=infotexts[0],
1142
- subseed=p.all_subseeds[0],
1143
- index_of_first_image=index_of_first_image,
1144
- infotexts=infotexts,
1145
- )
1146
-
1147
- if p.scripts is not None:
1148
- p.scripts.postprocess(p, res)
1149
-
1150
- return res
1151
-
1152
-
1153
- def old_hires_fix_first_pass_dimensions(width, height):
1154
- """old algorithm for auto-calculating first pass size"""
1155
-
1156
- desired_pixel_count = 512 * 512
1157
- actual_pixel_count = width * height
1158
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
1159
- width = math.ceil(scale * width / 64) * 64
1160
- height = math.ceil(scale * height / 64) * 64
1161
-
1162
- return width, height
1163
-
1164
-
1165
- @dataclass(repr=False)
1166
- class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
1167
- enable_hr: bool = False
1168
- denoising_strength: float = 0.75
1169
- firstphase_width: int = 0
1170
- firstphase_height: int = 0
1171
- hr_scale: float = 2.0
1172
- hr_upscaler: str = None
1173
- hr_second_pass_steps: int = 0
1174
- hr_resize_x: int = 0
1175
- hr_resize_y: int = 0
1176
- hr_checkpoint_name: str = None
1177
- hr_sampler_name: str = None
1178
- hr_scheduler: str = None
1179
- hr_prompt: str = ''
1180
- hr_negative_prompt: str = ''
1181
- force_task_id: str = None
1182
-
1183
- cached_hr_uc = [None, None]
1184
- cached_hr_c = [None, None]
1185
-
1186
- hr_checkpoint_info: dict = field(default=None, init=False)
1187
- hr_upscale_to_x: int = field(default=0, init=False)
1188
- hr_upscale_to_y: int = field(default=0, init=False)
1189
- truncate_x: int = field(default=0, init=False)
1190
- truncate_y: int = field(default=0, init=False)
1191
- applied_old_hires_behavior_to: tuple = field(default=None, init=False)
1192
- latent_scale_mode: dict = field(default=None, init=False)
1193
- hr_c: tuple | None = field(default=None, init=False)
1194
- hr_uc: tuple | None = field(default=None, init=False)
1195
- all_hr_prompts: list = field(default=None, init=False)
1196
- all_hr_negative_prompts: list = field(default=None, init=False)
1197
- hr_prompts: list = field(default=None, init=False)
1198
- hr_negative_prompts: list = field(default=None, init=False)
1199
- hr_extra_network_data: list = field(default=None, init=False)
1200
-
1201
- def __post_init__(self):
1202
- super().__post_init__()
1203
-
1204
- if self.firstphase_width != 0 or self.firstphase_height != 0:
1205
- self.hr_upscale_to_x = self.width
1206
- self.hr_upscale_to_y = self.height
1207
- self.width = self.firstphase_width
1208
- self.height = self.firstphase_height
1209
-
1210
- self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
1211
- self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
1212
-
1213
- def calculate_target_resolution(self):
1214
- if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
1215
- self.hr_resize_x = self.width
1216
- self.hr_resize_y = self.height
1217
- self.hr_upscale_to_x = self.width
1218
- self.hr_upscale_to_y = self.height
1219
-
1220
- self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
1221
- self.applied_old_hires_behavior_to = (self.width, self.height)
1222
-
1223
- if self.hr_resize_x == 0 and self.hr_resize_y == 0:
1224
- self.extra_generation_params["Hires upscale"] = self.hr_scale
1225
- self.hr_upscale_to_x = int(self.width * self.hr_scale)
1226
- self.hr_upscale_to_y = int(self.height * self.hr_scale)
1227
- else:
1228
- self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
1229
-
1230
- if self.hr_resize_y == 0:
1231
- self.hr_upscale_to_x = self.hr_resize_x
1232
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
1233
- elif self.hr_resize_x == 0:
1234
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
1235
- self.hr_upscale_to_y = self.hr_resize_y
1236
- else:
1237
- target_w = self.hr_resize_x
1238
- target_h = self.hr_resize_y
1239
- src_ratio = self.width / self.height
1240
- dst_ratio = self.hr_resize_x / self.hr_resize_y
1241
-
1242
- if src_ratio < dst_ratio:
1243
- self.hr_upscale_to_x = self.hr_resize_x
1244
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
1245
- else:
1246
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
1247
- self.hr_upscale_to_y = self.hr_resize_y
1248
-
1249
- self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
1250
- self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
1251
-
1252
- def init(self, all_prompts, all_seeds, all_subseeds):
1253
- if self.enable_hr:
1254
- self.extra_generation_params["Denoising strength"] = self.denoising_strength
1255
-
1256
- if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
1257
- self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
1258
-
1259
- if self.hr_checkpoint_info is None:
1260
- raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
1261
-
1262
- self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
1263
-
1264
- if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
1265
- self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
1266
-
1267
- def get_hr_prompt(p, index, prompt_text, **kwargs):
1268
- hr_prompt = p.all_hr_prompts[index]
1269
- return hr_prompt if hr_prompt != prompt_text else None
1270
-
1271
- def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):
1272
- hr_negative_prompt = p.all_hr_negative_prompts[index]
1273
- return hr_negative_prompt if hr_negative_prompt != negative_prompt else None
1274
-
1275
- self.extra_generation_params["Hires prompt"] = get_hr_prompt
1276
- self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt
1277
-
1278
- self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py
1279
-
1280
- if self.hr_scheduler is None:
1281
- self.hr_scheduler = self.scheduler
1282
-
1283
- self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
1284
- if self.enable_hr and self.latent_scale_mode is None:
1285
- if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
1286
- raise Exception(f"could not find upscaler named {self.hr_upscaler}")
1287
-
1288
- self.calculate_target_resolution()
1289
-
1290
- if not state.processing_has_refined_job_count:
1291
- if state.job_count == -1:
1292
- state.job_count = self.n_iter
1293
- if getattr(self, 'txt2img_upscale', False):
1294
- total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
1295
- else:
1296
- total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
1297
- shared.total_tqdm.updateTotal(total_steps)
1298
- state.job_count = state.job_count * 2
1299
- state.processing_has_refined_job_count = True
1300
-
1301
- if self.hr_second_pass_steps:
1302
- self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
1303
-
1304
- if self.hr_upscaler is not None:
1305
- self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
1306
-
1307
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
1308
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
1309
-
1310
- if self.firstpass_image is not None and self.enable_hr:
1311
- # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
1312
-
1313
- if self.latent_scale_mode is None:
1314
- image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
1315
- image = np.moveaxis(image, 2, 0)
1316
-
1317
- samples = None
1318
- decoded_samples = torch.asarray(np.expand_dims(image, 0))
1319
-
1320
- else:
1321
- image = np.array(self.firstpass_image).astype(np.float32) / 255.0
1322
- image = np.moveaxis(image, 2, 0)
1323
- image = torch.from_numpy(np.expand_dims(image, axis=0))
1324
- image = image.to(shared.device, dtype=devices.dtype_vae)
1325
-
1326
- if opts.sd_vae_encode_method != 'Full':
1327
- self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
1328
-
1329
- samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
1330
- decoded_samples = None
1331
- devices.torch_gc()
1332
-
1333
- else:
1334
- # here we generate an image normally
1335
-
1336
- x = self.rng.next()
1337
- if self.scripts is not None:
1338
- self.scripts.process_before_every_sampling(
1339
- p=self,
1340
- x=x,
1341
- noise=x,
1342
- c=conditioning,
1343
- uc=unconditional_conditioning
1344
- )
1345
-
1346
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
1347
- del x
1348
-
1349
- if not self.enable_hr:
1350
- return samples
1351
-
1352
- devices.torch_gc()
1353
-
1354
- if self.latent_scale_mode is None:
1355
- decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
1356
- else:
1357
- decoded_samples = None
1358
-
1359
- with sd_models.SkipWritingToConfig():
1360
- sd_models.reload_model_weights(info=self.hr_checkpoint_info)
1361
-
1362
- return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
1363
-
1364
- def sample_progressive(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
1365
- is_sdxl = getattr(self.sd_model, 'is_sdxl', False)
1366
-
1367
- if is_sdxl:
1368
- min_scale = max(0.5, self.progressive_growing_min_scale)
1369
- else:
1370
- min_scale = self.progressive_growing_min_scale
1371
-
1372
- resolution_steps = np.linspace(min_scale, self.progressive_growing_max_scale, self.progressive_growing_steps)
1373
-
1374
- initial_width = max(512 if is_sdxl else 64, int(self.width * resolution_steps[0]))
1375
- initial_height = max(512 if is_sdxl else 64, int(self.height * resolution_steps[0]))
1376
-
1377
- x = create_random_tensors((opt_C, initial_height // opt_f, initial_width // opt_f), seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
1378
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
1379
-
1380
- for i in range(1, len(resolution_steps)):
1381
- target_width = int(self.width * resolution_steps[i])
1382
- target_height = int(self.height * resolution_steps[i])
1383
-
1384
- if is_sdxl:
1385
- target_width = max(512, min(1536, target_width))
1386
- target_height = max(512, min(1536, target_height))
1387
-
1388
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode='bicubic', align_corners=False)
1389
-
1390
- if self.progressive_growing_refinement:
1391
- steps_for_refinement = self.steps // len(resolution_steps)
1392
- noise = create_random_tensors(samples.shape[1:], seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
1393
- decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
1394
- decoded_samples = torch.stack(decoded_samples).float()
1395
- decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
1396
- self.image_conditioning = self.img2img_image_conditioning(decoded_samples * 2 - 1, samples)
1397
-
1398
- samples = self.sampler.sample_img2img(
1399
- self,
1400
- samples,
1401
- noise,
1402
- conditioning,
1403
- unconditional_conditioning,
1404
- steps=steps_for_refinement,
1405
- image_conditioning=self.image_conditioning
1406
- )
1407
-
1408
- return samples
1409
-
1410
- def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
1411
- if shared.state.interrupted:
1412
- return samples
1413
-
1414
- self.is_hr_pass = True
1415
- target_width = self.hr_upscale_to_x
1416
- target_height = self.hr_upscale_to_y
1417
-
1418
- def save_intermediate(image, index):
1419
- """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
1420
-
1421
- if not self.save_samples() or not opts.save_images_before_highres_fix:
1422
- return
1423
-
1424
- if not isinstance(image, Image.Image):
1425
- image = sd_samplers.sample_to_image(image, index, approximation=0)
1426
-
1427
- info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
1428
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
1429
-
1430
- img2img_sampler_name = self.hr_sampler_name or self.sampler_name
1431
-
1432
- self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
1433
-
1434
- if self.latent_scale_mode is not None:
1435
- for i in range(samples.shape[0]):
1436
- save_intermediate(samples, i)
1437
-
1438
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
1439
-
1440
- # Avoid making the inpainting conditioning unless necessary as
1441
- # this does need some extra compute to decode / encode the image again.
1442
- if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
1443
- image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
1444
- else:
1445
- image_conditioning = self.txt2img_image_conditioning(samples)
1446
- else:
1447
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
1448
-
1449
- batch_images = []
1450
- for i, x_sample in enumerate(lowres_samples):
1451
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
1452
- x_sample = x_sample.astype(np.uint8)
1453
- image = Image.fromarray(x_sample)
1454
-
1455
- save_intermediate(image, i)
1456
-
1457
- image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
1458
- image = np.array(image).astype(np.float32) / 255.0
1459
- image = np.moveaxis(image, 2, 0)
1460
- batch_images.append(image)
1461
-
1462
- decoded_samples = torch.from_numpy(np.array(batch_images))
1463
- decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
1464
-
1465
- if opts.sd_vae_encode_method != 'Full':
1466
- self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
1467
- samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
1468
-
1469
- image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
1470
-
1471
- shared.state.nextjob()
1472
-
1473
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
1474
-
1475
- self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
1476
- noise = self.rng.next()
1477
-
1478
- # GC now before running the next img2img to prevent running out of memory
1479
- devices.torch_gc()
1480
-
1481
- if not self.disable_extra_networks:
1482
- with devices.autocast():
1483
- extra_networks.activate(self, self.hr_extra_network_data)
1484
-
1485
- with devices.autocast():
1486
- self.calculate_hr_conds()
1487
-
1488
- sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
1489
-
1490
- if self.scripts is not None:
1491
- self.scripts.before_hr(self)
1492
- self.scripts.process_before_every_sampling(
1493
- p=self,
1494
- x=samples,
1495
- noise=noise,
1496
- c=self.hr_c,
1497
- uc=self.hr_uc,
1498
- )
1499
-
1500
- samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
1501
-
1502
- sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
1503
-
1504
- self.sampler = None
1505
- devices.torch_gc()
1506
-
1507
- decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
1508
-
1509
- self.is_hr_pass = False
1510
- return decoded_samples
1511
-
1512
- def close(self):
1513
- super().close()
1514
- self.hr_c = None
1515
- self.hr_uc = None
1516
- if not opts.persistent_cond_cache:
1517
- StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
1518
- StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
1519
-
1520
- def setup_prompts(self):
1521
- super().setup_prompts()
1522
-
1523
- if not self.enable_hr:
1524
- return
1525
-
1526
- if self.hr_prompt == '':
1527
- self.hr_prompt = self.prompt
1528
-
1529
- if self.hr_negative_prompt == '':
1530
- self.hr_negative_prompt = self.negative_prompt
1531
-
1532
- if isinstance(self.hr_prompt, list):
1533
- self.all_hr_prompts = self.hr_prompt
1534
- else:
1535
- self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
1536
-
1537
- if isinstance(self.hr_negative_prompt, list):
1538
- self.all_hr_negative_prompts = self.hr_negative_prompt
1539
- else:
1540
- self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
1541
-
1542
- self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
1543
- self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
1544
-
1545
- def calculate_hr_conds(self):
1546
- if self.hr_c is not None:
1547
- return
1548
-
1549
- hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
1550
- hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
1551
-
1552
- sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
1553
- steps = self.hr_second_pass_steps or self.steps
1554
- total_steps = sampler_config.total_steps(steps) if sampler_config else steps
1555
-
1556
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
1557
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
1558
-
1559
- def setup_conds(self):
1560
- if self.is_hr_pass:
1561
- # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
1562
- self.hr_c = None
1563
- self.calculate_hr_conds()
1564
- return
1565
-
1566
- super().setup_conds()
1567
-
1568
- self.hr_uc = None
1569
- self.hr_c = None
1570
-
1571
- if self.enable_hr and self.hr_checkpoint_info is None:
1572
- if shared.opts.hires_fix_use_firstpass_conds:
1573
- self.calculate_hr_conds()
1574
-
1575
- elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
1576
- with devices.autocast():
1577
- extra_networks.activate(self, self.hr_extra_network_data)
1578
-
1579
- self.calculate_hr_conds()
1580
-
1581
- with devices.autocast():
1582
- extra_networks.activate(self, self.extra_network_data)
1583
-
1584
- def get_conds(self):
1585
- if self.is_hr_pass:
1586
- return self.hr_c, self.hr_uc
1587
-
1588
- return super().get_conds()
1589
-
1590
- def parse_extra_network_prompts(self):
1591
- res = super().parse_extra_network_prompts()
1592
-
1593
- if self.enable_hr:
1594
- self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
1595
- self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
1596
-
1597
- self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
1598
-
1599
- return res
1600
-
1601
-
1602
- @dataclass(repr=False)
1603
- class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
1604
- init_images: list = None
1605
- resize_mode: int = 0
1606
- denoising_strength: float = 0.75
1607
- image_cfg_scale: float = None
1608
- mask: Any = None
1609
- mask_blur_x: int = 4
1610
- mask_blur_y: int = 4
1611
- mask_blur: int = None
1612
- mask_round: bool = True
1613
- inpainting_fill: int = 0
1614
- inpaint_full_res: bool = True
1615
- inpaint_full_res_padding: int = 0
1616
- inpainting_mask_invert: int = 0
1617
- initial_noise_multiplier: float = None
1618
- latent_mask: Image = None
1619
- force_task_id: str = None
1620
-
1621
- image_mask: Any = field(default=None, init=False)
1622
-
1623
- nmask: torch.Tensor = field(default=None, init=False)
1624
- image_conditioning: torch.Tensor = field(default=None, init=False)
1625
- init_img_hash: str = field(default=None, init=False)
1626
- mask_for_overlay: Image = field(default=None, init=False)
1627
- init_latent: torch.Tensor = field(default=None, init=False)
1628
-
1629
- def __post_init__(self):
1630
- super().__post_init__()
1631
-
1632
- self.image_mask = self.mask
1633
- self.mask = None
1634
- self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
1635
-
1636
- @property
1637
- def mask_blur(self):
1638
- if self.mask_blur_x == self.mask_blur_y:
1639
- return self.mask_blur_x
1640
- return None
1641
-
1642
- @mask_blur.setter
1643
- def mask_blur(self, value):
1644
- if isinstance(value, int):
1645
- self.mask_blur_x = value
1646
- self.mask_blur_y = value
1647
-
1648
- def init(self, all_prompts, all_seeds, all_subseeds):
1649
- self.extra_generation_params["Denoising strength"] = self.denoising_strength
1650
-
1651
- self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
1652
-
1653
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
1654
- crop_region = None
1655
-
1656
- image_mask = self.image_mask
1657
-
1658
- if image_mask is not None:
1659
- # image_mask is passed in as RGBA by Gradio to support alpha masks,
1660
- # but we still want to support binary masks.
1661
- image_mask = create_binary_mask(image_mask, round=self.mask_round)
1662
-
1663
- if self.inpainting_mask_invert:
1664
- image_mask = ImageOps.invert(image_mask)
1665
- self.extra_generation_params["Mask mode"] = "Inpaint not masked"
1666
-
1667
- if self.mask_blur_x > 0:
1668
- np_mask = np.array(image_mask)
1669
- kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
1670
- np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
1671
- image_mask = Image.fromarray(np_mask)
1672
-
1673
- if self.mask_blur_y > 0:
1674
- np_mask = np.array(image_mask)
1675
- kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
1676
- np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
1677
- image_mask = Image.fromarray(np_mask)
1678
-
1679
- if self.mask_blur_x > 0 or self.mask_blur_y > 0:
1680
- self.extra_generation_params["Mask blur"] = self.mask_blur
1681
-
1682
- if self.inpaint_full_res:
1683
- self.mask_for_overlay = image_mask
1684
- mask = image_mask.convert('L')
1685
- crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding)
1686
- if crop_region:
1687
- crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
1688
- x1, y1, x2, y2 = crop_region
1689
- mask = mask.crop(crop_region)
1690
- image_mask = images.resize_image(2, mask, self.width, self.height)
1691
- self.paste_to = (x1, y1, x2-x1, y2-y1)
1692
- self.extra_generation_params["Inpaint area"] = "Only masked"
1693
- self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
1694
- else:
1695
- crop_region = None
1696
- image_mask = None
1697
- self.mask_for_overlay = None
1698
- self.inpaint_full_res = False
1699
- massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
1700
- model_hijack.comments.append(massage)
1701
- logging.info(massage)
1702
- else:
1703
- image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
1704
- np_mask = np.array(image_mask)
1705
- np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
1706
- self.mask_for_overlay = Image.fromarray(np_mask)
1707
-
1708
- self.overlay_images = []
1709
-
1710
- latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
1711
-
1712
- add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
1713
- if add_color_corrections:
1714
- self.color_corrections = []
1715
- imgs = []
1716
- for img in self.init_images:
1717
-
1718
- # Save init image
1719
- if opts.save_init_img:
1720
- self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
1721
- images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
1722
-
1723
- image = images.flatten(img, opts.img2img_background_color)
1724
-
1725
- if crop_region is None and self.resize_mode != 3:
1726
- image = images.resize_image(self.resize_mode, image, self.width, self.height)
1727
-
1728
- if image_mask is not None:
1729
- if self.mask_for_overlay.size != (image.width, image.height):
1730
- self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height)
1731
- image_masked = Image.new('RGBa', (image.width, image.height))
1732
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
1733
-
1734
- self.overlay_images.append(image_masked.convert('RGBA'))
1735
-
1736
- # crop_region is not None if we are doing inpaint full res
1737
- if crop_region is not None:
1738
- image = image.crop(crop_region)
1739
- image = images.resize_image(2, image, self.width, self.height)
1740
-
1741
- if image_mask is not None:
1742
- if self.inpainting_fill != 1:
1743
- image = masking.fill(image, latent_mask)
1744
-
1745
- if self.inpainting_fill == 0:
1746
- self.extra_generation_params["Masked content"] = 'fill'
1747
-
1748
- if add_color_corrections:
1749
- self.color_corrections.append(setup_color_correction(image))
1750
-
1751
- image = np.array(image).astype(np.float32) / 255.0
1752
- image = np.moveaxis(image, 2, 0)
1753
-
1754
- imgs.append(image)
1755
-
1756
- if len(imgs) == 1:
1757
- batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
1758
- if self.overlay_images is not None:
1759
- self.overlay_images = self.overlay_images * self.batch_size
1760
-
1761
- if self.color_corrections is not None and len(self.color_corrections) == 1:
1762
- self.color_corrections = self.color_corrections * self.batch_size
1763
-
1764
- elif len(imgs) <= self.batch_size:
1765
- self.batch_size = len(imgs)
1766
- batch_images = np.array(imgs)
1767
- else:
1768
- raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
1769
-
1770
- image = torch.from_numpy(batch_images)
1771
- image = image.to(shared.device, dtype=devices.dtype_vae)
1772
-
1773
- if opts.sd_vae_encode_method != 'Full':
1774
- self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
1775
-
1776
- self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
1777
- devices.torch_gc()
1778
-
1779
- if self.resize_mode == 3:
1780
- self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
1781
-
1782
- if image_mask is not None:
1783
- init_mask = latent_mask
1784
- latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
1785
- latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
1786
- latmask = latmask[0]
1787
- if self.mask_round:
1788
- latmask = np.around(latmask)
1789
- latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))
1790
-
1791
- self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
1792
- self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)
1793
-
1794
- # this needs to be fixed to be done in sample() using actual seeds for batches
1795
- if self.inpainting_fill == 2:
1796
- self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
1797
- self.extra_generation_params["Masked content"] = 'latent noise'
1798
-
1799
- elif self.inpainting_fill == 3:
1800
- self.init_latent = self.init_latent * self.mask
1801
- self.extra_generation_params["Masked content"] = 'latent nothing'
1802
-
1803
- self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
1804
-
1805
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
1806
- x = self.rng.next()
1807
-
1808
- if self.initial_noise_multiplier != 1.0:
1809
- self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
1810
- x *= self.initial_noise_multiplier
1811
-
1812
- if self.scripts is not None:
1813
- self.scripts.process_before_every_sampling(
1814
- p=self,
1815
- x=self.init_latent,
1816
- noise=x,
1817
- c=conditioning,
1818
- uc=unconditional_conditioning
1819
- )
1820
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
1821
-
1822
- if self.mask is not None:
1823
- blended_samples = samples * self.nmask + self.init_latent * self.mask
1824
-
1825
- if self.scripts is not None:
1826
- mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
1827
- self.scripts.on_mask_blend(self, mba)
1828
- blended_samples = mba.blended_latent
1829
-
1830
- samples = blended_samples
1831
-
1832
- del x
1833
- devices.torch_gc()
1834
-
1835
- return samples
1836
-
1837
- def get_token_merging_ratio(self, for_hr=False):
1838
- return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio