tellurion Claude Sonnet 4.6 committed on
Commit
47ab351
·
1 Parent(s): 7b75d46

Add refnet/models and ldm/models source files previously excluded by .gitignore

Browse files
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
4
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
5
+
6
+
7
class AutoencoderKL(torch.nn.Module):
    """KL-regularized image autoencoder used as the latent-diffusion first stage.

    Wraps an ``Encoder``/``Decoder`` pair plus 1x1 convolutions that map
    between the encoder's moment output and the ``embed_dim``-channel latent
    space.
    """

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # The encoder must emit mean and log-variance channels together.
        assert ddconfig["double_z"]
        z_ch = ddconfig["z_channels"]
        self.quant_conv = torch.nn.Conv2d(2 * z_ch, 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_ch, 1)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Encode an image batch into a diagonal-Gaussian posterior over latents."""
        moments = self.quant_conv(self.encoder(x))
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode a latent batch back to image space."""
        return self.decoder(self.post_quant_conv(z))

    def get_last_layer(self):
        """Return the decoder's final conv weight (e.g. for adaptive loss weighting)."""
        return self.decoder.conv_out.weight

    @property
    def dtype(self):
        # Mirror the dtype of the decoder's output convolution.
        return self.decoder.conv_out.weight.dtype
refnet/models/basemodel.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from refnet.util import exists, fitting_weights, instantiate_from_config, load_weights, delete_states
4
+ from refnet.ldm import LatentDiffusion
5
+ from typing import Union
6
+ from refnet.sampling import (
7
+ UnetHook,
8
+ KDiffusionSampler,
9
+ DiffuserDenoiser,
10
+ )
11
+
12
+
13
+
14
class GuidanceFlag:
    """Constants describing which classifier-free-guidance branches are active.

    ``reference`` and ``sketch`` combine additively, so ``both`` equals
    ``reference + sketch``.
    """

    none = 0
    reference = 1
    sketch = 10
    both = 11
19
+
20
+
21
def reconstruct_cond(cond, uncond):
    """Concatenate unconditional batches onto each conditional entry in-place.

    Every key of ``cond`` (except ``"inpaint_bg"``, which passes through
    untouched) has the matching tensors from each dict in ``uncond`` appended
    along the batch dimension. ``uncond`` may be a single dict or a list of
    dicts; they are applied in order. The mutated ``cond`` is also returned.
    """
    uncond_list = uncond if isinstance(uncond, list) else [uncond]
    for key in cond:
        if key == "inpaint_bg":
            continue
        for uc in uncond_list:
            value = cond[key]
            if isinstance(value, list):
                cond[key] = [
                    torch.cat([item, uc[key][idx]])
                    for idx, item in enumerate(value)
                ]
            elif isinstance(value, torch.Tensor):
                cond[key] = torch.cat([value, uc[key]])
    return cond
33
+
34
+
35
class CustomizedLDM(LatentDiffusion):
    """LatentDiffusion variant with checkpoint loading/fitting utilities,
    k-diffusion / diffusers sampling, precision switching and low-VRAM
    device shuffling.

    Also indexes the UNet's ``BasicTransformerBlock`` modules into named
    groups ("high"/"low"/"bottom"/"encoder"/"decoder") so reference scales
    can later be adjusted per level.
    """

    def __init__(
        self,
        dtype = torch.float32,
        sigma_max = None,
        sigma_min = None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.dtype = dtype
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min

        # Named handles consumed by low_vram_shift(); subclasses extend this.
        self.model_list = {
            "first": self.first_stage_model,
            "cond": self.cond_stage_model,
            "unet": self.model,
        }
        self.switch_cond_modules = ["cond"]
        self.switch_main_modules = ["unet"]
        self.retrieve_attn_modules()
        self.retrieve_attn_layers()

    def init_from_ckpt(
        self,
        path,
        only_model = False,
        logging = False,
        make_it_fit = False,
        ignore_keys: list[str] = (),
    ):
        """Load weights from ``path``, optionally reshaping them to fit.

        Missing/unexpected keys belonging to the conditioning stack
        (``cond_stage_model``, ``img_embedder``, ``fg``) are filtered out of
        the report, since those weights are expected to come from elsewhere.
        """
        sd = delete_states(load_weights(path), ignore_keys)
        if make_it_fit:
            sd = fitting_weights(self, sd)

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model \
            else self.model.load_state_dict(sd, strict=False)

        filtered_missing = [
            k for k in missing
            if "cond_stage_model" not in k and "img_embedder" not in k and "fg" not in k
        ]
        filtered_unexpect = [
            k for k in unexpected
            if "cond_stage_model" not in k and "img_embedder" not in k
        ]

        print(
            f"Restored from {path} with {len(filtered_missing)} filtered missing and "
            f"{len(filtered_unexpect)} filtered unexpected keys")
        if logging:
            # Fix: previously gated on the unfiltered lists, which could print
            # an empty filtered list (or skip printing a non-empty one).
            if len(filtered_missing) > 0:
                print(f"Filtered missing Keys: {filtered_missing}")
            if len(filtered_unexpect) > 0:
                print(f"Filtered unexpected Keys: {filtered_unexpect}")

    def sample(
        self,
        cond: dict,
        uncond: Union[dict, list[dict]] = None,
        cfg_scale: Union[float, list[float]] = 1.,
        bs: int = 1,
        shape: Union[tuple, list] = None,
        step: int = 20,
        sampler = "DPM++ 3M SDE",
        scheduler = "Automatic",
        device = "cuda",
        x_T = None,
        seed = None,
        deterministic = False,
        **kwargs
    ):
        """Draw ``bs`` latent samples with either a diffusers-style or a
        k-diffusion sampler.

        ``uncond`` entries are concatenated onto ``cond`` for classifier-free
        guidance; ``x_T`` optionally supplies the starting noise.
        """
        shape = shape or (self.channels, self.image_size, self.image_size)
        # Fix: `x_T or torch.randn(...)` raises "Boolean value of Tensor is
        # ambiguous" for any multi-element tensor; test presence explicitly.
        x = x_T if exists(x_T) else torch.randn(bs, *shape, device=device)

        if exists(uncond):
            cond = reconstruct_cond(cond, uncond)

        if sampler.startswith("diffuser"):
            # Using huggingface diffuser noise sampler and scheduler
            sampler = DiffuserDenoiser(
                sampler,
                prediction_type = "v_prediction" if self.parameterization == "v" else "epsilon",
                use_karras = scheduler == "Karras"
            )

            samples = sampler(
                x,
                cond,
                cond_scale=cfg_scale,
                unet=self,
                timesteps=step,
                generator=torch.manual_seed(seed) if exists(seed) else None,
                device=device
            )

        else:
            # Using k-diffusion sampler and noise scheduler
            # Fix: `seed or torch.seed()` discarded an explicit seed of 0.
            seed = seed if exists(seed) else torch.seed()
            sampler = KDiffusionSampler(sampler, scheduler, self, device)

            sigmas = sampler.get_sigmas(step)
            extra_args = {
                "cond": cond,
                "cond_scale": cfg_scale,
            }
            # One seed per batch element keeps per-sample noise reproducible.
            seed = [seed for _ in range(bs)] if deterministic else seed
            samples = sampler(x, sigmas, extra_args, seed, deterministic, step)

        return samples

    def switch_to_fp16(self):
        """Cast the UNet's conv/attention trunk to half precision."""
        unet = self.model.diffusion_model
        unet.input_blocks = unet.input_blocks.to(self.half_precision_dtype)
        unet.middle_block = unet.middle_block.to(self.half_precision_dtype)
        unet.output_blocks = unet.output_blocks.to(self.half_precision_dtype)
        self.dtype = self.half_precision_dtype
        unet.dtype = self.half_precision_dtype

    def switch_to_fp32(self):
        """Cast the UNet's conv/attention trunk back to float32."""
        unet = self.model.diffusion_model
        unet.input_blocks = unet.input_blocks.float()
        unet.middle_block = unet.middle_block.float()
        unet.output_blocks = unet.output_blocks.float()
        self.dtype = torch.float32
        unet.dtype = torch.float32

    def switch_vae_to_fp16(self):
        """Cast the first-stage VAE to half precision."""
        self.first_stage_model = self.first_stage_model.to(self.half_precision_dtype)

    def switch_vae_to_fp32(self):
        """Cast the first-stage VAE back to float32."""
        self.first_stage_model = self.first_stage_model.float()

    def low_vram_shift(self, cuda_list: Union[str, list[str]]):
        """Move the named sub-models in ``cuda_list`` to GPU and all others to CPU."""
        if not isinstance(cuda_list, list):
            cuda_list = [cuda_list]

        cpu_list = self.model_list.keys() - cuda_list
        for model in cpu_list:
            self.model_list[model] = self.model_list[model].cpu()
        torch.cuda.empty_cache()

        for model in cuda_list:
            self.model_list[model] = self.model_list[model].cuda()

    def retrieve_attn_modules(self):
        """Collect the UNet's BasicTransformerBlocks and group them by level.

        The hard-coded index ranges partition the blocks into encoder/decoder
        halves and high/low/bottom resolution levels; each level gets a fixed
        attention ``scale_factor``. The ranges assume a specific UNet layout —
        TODO confirm against the configured diffusion model.
        """
        from refnet.modules.transformer import BasicTransformerBlock
        from refnet.sampling import torch_dfs

        scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}

        attn_modules = [
            module for module in torch_dfs(self.model.diffusion_model)
            if isinstance(module, BasicTransformerBlock)
        ]

        self.attn_modules = {
            "high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69],
            "low": [i for i in range(4, 24)] + [i for i in range(34, 64)],
            "bottom": [i for i in range(24, 34)],
            "encoder": [i for i in range(24)],
            "decoder": [i for i in range(34, len(attn_modules))]
        }
        self.attn_modules["modules"] = attn_modules

        for k in ["high", "low", "bottom"]:
            scale_factor = scale_factor_levels[k]
            for attn in self.attn_modules[k]:
                attn_modules[attn].scale_factor = scale_factor

    def retrieve_attn_layers(self):
        """Cache each block's cross-attention (``attn2``) layer for later scale tweaks."""
        self.attn_layers = []
        for module in self.attn_modules["modules"]:
            if hasattr(module, "attn2") and exists(getattr(module, "attn2")):
                self.attn_layers.append(module.attn2)
214
+
215
+
216
class CustomizedColorizer(CustomizedLDM):
    """CustomizedLDM extended with a sketch control encoder and a projection head."""

    def __init__(
        self,
        control_encoder_config,
        proj_config,
        token_type = "full",
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.control_encoder = instantiate_from_config(control_encoder_config)
        self.proj = instantiate_from_config(proj_config)
        self.token_type = token_type
        # Register the new sub-models so low-VRAM device switching covers them.
        self.model_list.update({
            "control_encoder": self.control_encoder,
            "proj": self.proj,
        })
        self.switch_cond_modules += ["control_encoder", "proj"]

    def switch_to_fp16(self):
        """Cast the control encoder to half precision along with the base model."""
        self.control_encoder = self.control_encoder.to(self.half_precision_dtype)
        super().switch_to_fp16()

    def switch_to_fp32(self):
        """Cast the control encoder back to float32 along with the base model."""
        self.control_encoder = self.control_encoder.float()
        super().switch_to_fp32()
241
+
242
+
243
+ from refnet.modules.unet import hack_inference_forward
244
class CustomizedWrapper:
    """Inference mixin: hooks the UNet forward, exposes per-level scale
    controls and the top-level ``generate`` entry point.

    Intended to be combined (via multiple inheritance) with a
    ``CustomizedLDM`` subclass that provides ``self.model``,
    ``self.attn_modules``, ``self.sample`` and friends.
    """

    def __init__(self):
        self.scaling_sample = False
        # Fraction-of-schedule window [start, end] inside which reference
        # guidance is applied; outside it the context is zeroed (see apply_model).
        self.guidance_steps = (0, 1)
        self.no_guidance_steps = (-0.05, 0.05)
        hack_inference_forward(self.model.diffusion_model)

    def adjust_reference_scale(self, scale_kwargs):
        """Set ``reference_scale`` on the attention blocks.

        ``scale_kwargs`` may be a dict with per-level scales (when
        ``level_control`` is truthy), a per-module list of scales, or a single
        scalar applied to every block.
        """
        if isinstance(scale_kwargs, dict):
            if scale_kwargs["level_control"]:
                for key in scale_kwargs["scales"]:
                    if key == "middle":
                        continue
                    for idx in self.attn_modules[key]:
                        self.attn_modules["modules"][idx].reference_scale = scale_kwargs["scales"][key]
            else:
                for idx, s in enumerate(scale_kwargs["scales"]):
                    self.attn_modules["modules"][idx].reference_scale = s
        else:
            for module in self.attn_modules["modules"]:
                module.reference_scale = scale_kwargs

    def adjust_fgbg_scale(self, fg_scale, bg_scale, merge_scale, mask_threshold):
        """Push foreground/background/merge scales onto all cross-attention layers."""
        for layer in self.attn_layers:
            layer.fg_scale = fg_scale
            layer.bg_scale = bg_scale
            layer.merge_scale = merge_scale
            layer.mask_threshold = mask_threshold

    def apply_model(self, x_noisy, t, cond):
        """UNet forward with time-windowed reference guidance.

        Outside the ``guidance_steps`` window, or inside the
        ``no_guidance_steps`` window, the cross-attention context is replaced
        by a single zero token so the reference has no effect at that step.
        """
        tr = 1 - t[0] / (self.num_timesteps - 1)
        crossattn = cond["context"][0]
        if ((tr < self.guidance_steps[0] or tr > self.guidance_steps[1]) or
                (self.no_guidance_steps[0] <= tr <= self.no_guidance_steps[1])):
            crossattn = torch.zeros_like(crossattn)[:, :1]
        cond["context"] = [crossattn]

        # "inpaint_bg" is consumed elsewhere; keep it out of the UNet kwargs.
        model_cond = {k: v for k, v in cond.items() if k != "inpaint_bg"}
        return self.model(x_noisy, t, **model_cond)

    def prepare_conditions(self, *args, **kwargs):
        """Subclass hook that turns raw user inputs into model conditions."""
        raise NotImplementedError("Inputs preprocessing function is not implemented.")

    def check_manipulate(self, scales):
        """Return True if any manipulation scale is positive."""
        if exists(scales) and len(scales) > 0:
            for scale in scales:
                if scale > 0:
                    return True
        return False

    @torch.inference_mode()
    def generate(
        self,
        # Conditional inputs
        cond: dict,
        ctl_scale: Union[float, list[float]],
        merge_scale: float,
        mask_scale: float,
        mask_thresh: float,
        mask_thresh_sketch: float,

        # Sampling settings
        sampler,
        scheduler,
        step: int,
        bs: int,
        gs: list[float],
        strength: Union[float, list[float]],
        fg_strength: float,
        bg_strength: float,
        seed: int,
        start_step: float = 0.0,
        end_step: float = 1.0,
        no_start_step: float = -0.05,
        no_end_step: float = -0.05,
        deterministic: bool = False,
        style_enhance: bool = False,
        bg_enhance: bool = False,
        fg_enhance: bool = False,
        latent_inpaint: bool = False,
        height: int = 512,
        width: int = 512,

        # Injection settings
        injection: bool = False,
        injection_cfg: float = 0.5,
        injection_control: float = 0,
        injection_start_step: float = 0,
        hook_xr: torch.Tensor = None,
        hook_xs: torch.Tensor = None,

        # Additional settings
        low_vram: bool = True,
        return_intermediate = False,
        manipulation_params = None,
        **kwargs,
    ):
        """User interface function.

        Prepares conditions, configures guidance windows and reference/FG/BG
        scales, optionally installs the reference-injection UNet hook, samples
        latents and decodes them back to image space.
        """
        hook_unet = UnetHook()

        self.guidance_steps = (start_step, end_step)
        self.no_guidance_steps = (no_start_step, no_end_step)
        self.adjust_reference_scale(strength)
        self.adjust_fgbg_scale(fg_strength, bg_strength, merge_scale, mask_thresh_sketch)

        if low_vram:
            self.low_vram_shift(self.switch_cond_modules)
        else:
            self.low_vram_shift(list(self.model_list.keys()))

        c, uc = self.prepare_conditions(
            bs = bs,
            control_scale = ctl_scale,
            merge_scale = merge_scale,
            mask_scale = mask_scale,
            mask_threshold_ref = mask_thresh,
            mask_threshold_sketch = mask_thresh_sketch,
            style_enhance = style_enhance,
            bg_enhance = bg_enhance,
            fg_enhance = fg_enhance,
            latent_inpaint = latent_inpaint,
            height = height,
            width = width,
            bg_strength = bg_strength,
            low_vram = low_vram,
            **cond,
            # Fix: the default ``manipulation_params=None`` previously crashed
            # here with "argument after ** must be a mapping".
            **(manipulation_params or {}),
            **kwargs
        )

        # Combine the two guidance flags: reference CFG and sketch CFG.
        cfg = int(gs[0] > 1) * GuidanceFlag.reference + int(gs[1] > 1) * GuidanceFlag.sketch
        gr_indice = [] if (cfg == GuidanceFlag.none or cfg == GuidanceFlag.sketch) else [i for i in range(bs, bs*2)]
        repeat = 1
        if cfg == GuidanceFlag.none:
            gs = 1
            uc = None
        if cfg == GuidanceFlag.reference:
            gs = gs[0]
            uc = uc[0]
            repeat = 2
        if cfg == GuidanceFlag.sketch:
            gs = gs[1]
            uc = uc[1]
            repeat = 2
        if cfg == GuidanceFlag.both:
            repeat = 3

        if low_vram:
            self.low_vram_shift("first")

        if injection:
            rx = self.get_first_stage_encoding(hook_xr.to(self.first_stage_model.dtype))
            hook_unet.enhance_reference(
                model = self.model,
                ldm = self,
                bs = bs * repeat,
                s = -hook_xr.to(self.dtype),
                r = rx,
                style_cfg = injection_cfg,
                control_cfg = injection_control,
                gr_indice = gr_indice,
                start_step = injection_start_step,
            )

        if low_vram:
            self.low_vram_shift(self.switch_main_modules)

        z = self.sample(
            cond = c,
            uncond = uc,
            bs = bs,
            shape = (self.channels, height // 8, width // 8),
            cfg_scale = gs,
            step = step,
            sampler = sampler,
            scheduler = scheduler,
            seed = seed,
            deterministic = deterministic,
            return_intermediate = return_intermediate,
        )

        if injection:
            hook_unet.restore(self.model)

        if low_vram:
            self.low_vram_shift("first")
        return self.decode_first_stage(z.to(self.first_stage_model.dtype))
refnet/models/colorizerXL.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from ..modules.reference_net import hack_inference_forward
6
+ from ..models.basemodel import CustomizedColorizer, CustomizedWrapper
7
+ from ..modules.lora import LoraModules
8
+ from ..util import exists, expand_to_batch_size, instantiate_from_config, get_crop_scale, resize_and_crop
9
+
10
+
11
+
12
class InferenceWrapper(CustomizedWrapper, CustomizedColorizer):
    """SDXL colorizer inference wrapper combining the UNet-hook mixin with the
    colorizer LDM, plus scalar/image embedders and optional LoRA modules."""

    def __init__(
        self,
        scalar_embedder_config,
        img_embedder_config,
        lora_config = None,
        logits_embed = False,
        *args,
        **kwargs
    ):
        CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
        CustomizedWrapper.__init__(self)

        self.scalar_embedder = instantiate_from_config(scalar_embedder_config)
        self.img_embedder = instantiate_from_config(img_embedder_config)
        self.loras = LoraModules(self, **lora_config) if exists(lora_config) else None
        self.logits_embed = logits_embed

        # Register the new sub-models for low-VRAM device switching.
        new_model_list = {
            "scalar_embedder": self.scalar_embedder,
            "img_embedder": self.img_embedder,
            # "style_encoder": self.style_encoder,
        }
        self.switch_cond_modules += list(new_model_list.keys())
        self.model_list.update(new_model_list)

    def retrieve_attn_modules(self):
        """Collect BasicTransformerBlocks and group them by level (SDXL layout).

        NOTE(review): index ranges duplicate CustomizedLDM.retrieve_attn_modules;
        they assume a specific UNet layout — confirm against the model config.
        """
        scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}

        from refnet.modules.transformer import BasicTransformerBlock
        from refnet.sampling import torch_dfs

        attn_modules = []
        for module in torch_dfs(self.model.diffusion_model):
            if isinstance(module, BasicTransformerBlock):
                attn_modules.append(module)

        self.attn_modules = {
            "high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69],
            "low": [i for i in range(4, 24)] + [i for i in range(34, 64)],
            "bottom": [i for i in range(24, 34)],
            "encoder": [i for i in range(24)],
            "decoder": [i for i in range(34, len(attn_modules))]
        }
        self.attn_modules["modules"] = attn_modules

        for k in ["high", "low", "bottom"]:
            scale_factor = scale_factor_levels[k]
            for attn in self.attn_modules[k]:
                attn_modules[attn].scale_factor = scale_factor

    def adjust_reference_scale(self, scale_kwargs):
        # Unlike the base mixin, apply the single "encoder" scale everywhere.
        for module in self.attn_modules["modules"]:
            module.reference_scale = scale_kwargs["scales"]["encoder"]

    def adjust_masked_attn(self, scale, mask_threshold, merge_scale):
        """Push mask/merge parameters onto all cached cross-attention layers."""
        for layer in self.attn_layers:
            layer.mask_scale = scale
            layer.mask_threshold = mask_threshold
            layer.merge_scale = merge_scale

    def rescale_size(self, x: torch.Tensor, height, width):
        """Compute a target (ih, iw) that covers (height, width) while keeping
        x's aspect ratio; returns 1-element tensors."""
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            dh, dw = height - oh, width - ow
            if dh > dw:
                iw = ow + int(dh * ow/oh)
                ih = height
            else:
                ih = oh + int(dw * oh/ow)
                iw = width
        else:
            ih, iw = oh, ow
        return torch.Tensor([ih]), torch.Tensor([iw])

    def get_learned_embedding(self, c, bg=False, mapping=False, sketch=None, *args, **kwargs):
        """Build the projected reference embedding from CLIP + image-embedder features.

        Returns (projected embedding, CLIP class token).
        """
        clip_emb = self.cond_stage_model.encode(c, "full").detach()
        wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
        cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]

        if mapping:
            _, sketch_logits = self.img_embedder.encode(-sketch, pooled=False, return_logits=True)
            # NOTE(review): this mean() result is discarded — dead statement or
            # missing assignment? Verify intent.
            sketch_logits.mean(dim=1, keepdim=True)
            logits = self.img_embedder.geometry_update(logits, sketch_logits)
        emb = self.proj(clip_emb, logits if self.logits_embed else wd_emb, bg)
        return emb, cls_emb

    def prepare_conditions(
        self,
        bs,
        sketch,
        reference,
        height,
        width,
        control_scale = (1., 1., 1., 1.),
        merge_scale = 0,
        mask_scale = 1.,
        fg_scale = 1.,
        bg_scale = 1.,
        smask = None,
        rmask = None,
        mask_threshold_ref = 0.,
        mask_threshold_sketch = 0.,
        style_enhance = False,
        fg_enhance = False,
        bg_enhance = False,
        background = None,
        targets = None,
        anchors = None,
        controls = None,
        target_scales = None,
        enhances = None,
        thresholds_list = None,
        geometry_map = False,
        latent_inpaint = False,
        low_vram = False,
        *args,
        **kwargs
    ):
        """Turn sketch/reference/background inputs into conditional and
        unconditional dicts for sampling.

        Returns ``(c, [uc_reference, uc_sketch])`` where the two unconditional
        dicts zero out the context and the control respectively.
        """
        # prepare reference embedding
        # manipulate = self.check_manipulate(target_scales)
        c = {}
        uc = [{}, {}]

        if exists(reference):
            emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch, mapping=geometry_map)
        else:
            # No reference: use zeroed embeddings shaped like the sketch's.
            emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))

        # Scalar conditioning: sqrt(area) plus a fixed aesthetic-like score of 7.
        # NOTE(review): the constant 7. looks like a quality score — confirm.
        h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
        y = torch.cat(self.scalar_embedder(torch.cat([(h*w)**0.5, score])).cuda().chunk(2), 1)

        if bg_enhance:
            assert exists(rmask) and exists(smask)

        if low_vram:
            self.low_vram_shift(["first", "cond", "img_embedder", "proj"])

        if latent_inpaint and exists(background):
            # Encode the (cropped) background into latents for latent inpainting.
            bgh, bgw = background.shape[2:]
            ch, cw = get_crop_scale(torch.tensor([height]), torch.tensor([width]), bgh, bgw)
            hs_bg = self.get_first_stage_encoding(resize_and_crop(background, ch, cw, height, width).to(self.first_stage_model.dtype))
            bg_emb, _ = self.get_learned_embedding(background, bg=True)
            hs_bg = expand_to_batch_size(hs_bg, bs)
            c.update({"inpaint_bg": hs_bg})
        else:
            if exists(background):
                bg_emb, _ = self.get_learned_embedding(background, bg=True)
            else:
                # Derive a pseudo-background by masking the reference with rmask.
                bg_emb, _ = self.get_learned_embedding(
                    torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference)),
                    True
                )
        emb = torch.cat([emb, bg_emb], 1)

        if fg_enhance and exists(self.loras):
            self.loras.switch_lora(True, "foreground")
            if not bg_enhance:
                emb = emb.repeat(1, 2, 1)

        if fg_enhance or bg_enhance:
            # sketch mask for cross-attention
            smask = expand_to_batch_size(smask.to(self.dtype), bs)
            for d in [c] + uc:
                d.update({"mask": F.interpolate(smask, scale_factor=0.125)})
        elif exists(self.loras):
            self.loras.switch_lora(False)

        sketch = sketch.to(self.dtype)
        context = expand_to_batch_size(emb, bs).to(self.dtype)
        y = expand_to_batch_size(y, bs)
        uc_context = torch.zeros_like(context)

        control = []
        uc_control = []
        if low_vram:
            self.low_vram_shift(["control_encoder"])
        # Encode the sketch and an all -1 "blank" sketch in one batch; the two
        # halves become the conditional and unconditional controls.
        encoded_sketch = self.control_encoder(
            torch.cat([sketch, -torch.ones_like(sketch)], 0)
        )
        for idx, es in enumerate(encoded_sketch):
            es = es * control_scale[idx]
            ec, uec = es.chunk(2)
            control.append(expand_to_batch_size(ec, bs))
            uc_control.append(expand_to_batch_size(uec, bs))

        c.update({"control": control, "context": [context], "y": [y]})
        uc[0].update({"control": control, "context": [uc_context], "y": [y]})
        uc[1].update({"control": uc_control, "context": [context], "y": [y]})
        return c, uc
refnet/models/v2-colorizerXL.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from refnet.models.basemodel import CustomizedColorizer, CustomizedWrapper
2
+ from refnet.util import *
3
+ from refnet.modules.lora import LoraModules
4
+ from refnet.modules.reference_net import hack_unet_forward, hack_inference_forward
5
+ from refnet.sampling.hook import ReferenceAttentionControl
6
+
7
+
8
+ class InferenceWrapperXL(CustomizedWrapper, CustomizedColorizer):
9
+ def __init__(
10
+ self,
11
+ scalar_embedder_config,
12
+ img_embedder_config,
13
+ fg_encoder_config = None,
14
+ bg_encoder_config = None,
15
+ style_encoder_config = None,
16
+ lora_config = None,
17
+ logits_embed = False,
18
+ controller = False,
19
+ *args,
20
+ **kwargs
21
+ ):
22
+ CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
23
+ CustomizedWrapper.__init__(self)
24
+
25
+ self.logits_embed = logits_embed
26
+
27
+ (
28
+ self.scalar_embedder,
29
+ self.img_embedder,
30
+ self.fg_encoder,
31
+ self.bg_encoder,
32
+ self.style_encoder
33
+ ) = map(
34
+ lambda t: instantiate_from_config(t) if exists(t) else None,
35
+ (
36
+ scalar_embedder_config,
37
+ img_embedder_config,
38
+ fg_encoder_config,
39
+ bg_encoder_config,
40
+ style_encoder_config
41
+ )
42
+ )
43
+ self.loras = LoraModules(self, **lora_config)
44
+
45
+ if controller:
46
+ self.controller = ReferenceAttentionControl(
47
+ # time_embed_ch = self.model.diffusion_model.model_channels * 4,
48
+ reader_module = self.model.diffusion_model,
49
+ writer_module = self.bg_encoder,
50
+ # only_decoder = True
51
+ )
52
+ else:
53
+ self.controller = None
54
+
55
+ new_model_list = {
56
+ # "style_encoder": self.style_encoder,
57
+ "scalar_embedder": self.scalar_embedder,
58
+ "img_embedder": self.img_embedder,
59
+ # "controller": self.controller
60
+ }
61
+
62
+ hack_unet_forward(self.model.diffusion_model)
63
+ if exists(self.fg_encoder):
64
+ hack_inference_forward(self.fg_encoder)
65
+ new_model_list["fg_encoder"] = self.fg_encoder
66
+ if exists(self.bg_encoder):
67
+ hack_inference_forward(self.bg_encoder)
68
+ new_model_list["bg_encoder"] = self.bg_encoder
69
+ # hack_inference_forward(self.bg_encoder)
70
+ # hack_inference_forward(self.style_encoder)
71
+
72
+ self.switch_cond_modules += list(new_model_list.keys())
73
+ # self.switch_main_modules += ["controller"]
74
+ self.model_list.update(new_model_list)
75
+
76
+
77
+ def switch_to_fp16(self):
78
+ super().switch_to_fp16()
79
+ self.model.diffusion_model.map_modules.to(self.half_precision_dtype)
80
+ self.model.diffusion_model.warp_modules.to(self.half_precision_dtype)
81
+ self.model.diffusion_model.style_modules.to(self.half_precision_dtype)
82
+ self.model.diffusion_model.conv_fg.to(self.half_precision_dtype)
83
+
84
+ if exists(self.fg_encoder):
85
+ self.fg_encoder.to(self.half_precision_dtype)
86
+ self.fg_encoder.dtype = self.half_precision_dtype
87
+ self.fg_encoder.time_embed.float()
88
+ if exists(self.bg_encoder):
89
+ self.bg_encoder.to(self.half_precision_dtype)
90
+ self.bg_encoder.dtype = self.half_precision_dtype
91
+ self.bg_encoder.time_embed.float()
92
+ # self.style_encoder.to(self.half_precision_dtype)
93
+ # self.style_encoder.dtype = self.half_precision_dtype
94
+ # self.style_encoder.time_embed.float()
95
+
96
+ def switch_to_fp32(self):
97
+ super().switch_to_fp32()
98
+ self.model.diffusion_model.map_modules.float()
99
+ self.model.diffusion_model.warp_modules.float()
100
+ self.model.diffusion_model.style_modules.float()
101
+
102
+ self.fg_encoder.float()
103
+ self.bg_encoder.float()
104
+ # self.style_encoder.float()
105
+
106
+ self.fg_encoder.dtype = torch.float32
107
+ self.bg_encoder.dtype = torch.float32
108
+ # self.style_encoder.dtype = torch.float32
109
+
110
+ def rescale_size(self, x: torch.Tensor, height, width):
111
+ oh, ow = x.shape[2:]
112
+ if oh < height or ow < width:
113
+ dh, dw = height - oh, width - ow
114
+ if dh > dw:
115
+ iw = ow + int(dh * ow/oh)
116
+ ih = height
117
+ else:
118
+ ih = oh + int(dw * oh/ow)
119
+ iw = width
120
+ else:
121
+ ih, iw = oh, ow
122
+ return torch.tensor([ih]), torch.tensor([iw])
123
+
124
+ def rescale_background_size(self, x, height, width):
125
+ oh, ow = x.shape[2:]
126
+ if oh < height or ow < width:
127
+ # A simple bias to avoid deterioration caused by reference resolution
128
+ mind = max(height, width)
129
+ ih = oh + mind
130
+ iw = ow / oh * ih
131
+ else:
132
+ ih, iw = oh, ow
133
+ # rh, rw = ih / height, iw / width
134
+ return torch.tensor([ih]), torch.tensor([iw])
135
+
136
+ def get_learned_embedding(self, c, bg=False, sketch=None, mapping=False, *args, **kwargs):
137
+ clip_emb = self.cond_stage_model.encode(c, "full").detach()
138
+ wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
139
+ cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]
140
+
141
+ if self.logits_embed and exists(sketch) and mapping:
142
+ _, sketch_logits = self.img_embedder.encode(-sketch, pooled=True, return_logits=True)
143
+ logits = self.img_embedder.geometry_update(logits, sketch_logits)
144
+
145
+ if self.logits_embed:
146
+ emb = self.proj(clip_emb, logits, bg)[0]
147
+ else:
148
+ emb = self.proj(clip_emb, wd_emb, bg)
149
+ return emb.to(self.dtype), cls_emb.to(self.dtype)
150
+
151
+ def prepare_conditions(
152
+ self,
153
+ bs,
154
+ sketch,
155
+ reference,
156
+ height,
157
+ width,
158
+ control_scale = 1,
159
+ mask_scale = 1,
160
+ merge_scale = 0.,
161
+ cond_aug = 0.,
162
+ background = None,
163
+ smask = None,
164
+ rmask = None,
165
+ mask_threshold_ref = 0.,
166
+ mask_threshold_sketch = 0.,
167
+ style_enhance = False,
168
+ fg_enhance = False,
169
+ bg_enhance = False,
170
+ latent_inpaint = False,
171
+ fg_disentangle_scale = 1.,
172
+ targets = None,
173
+ anchors = None,
174
+ controls = None,
175
+ target_scales = None,
176
+ enhances = None,
177
+ thresholds_list = None,
178
+ low_vram = False,
179
+ *args,
180
+ **kwargs
181
+ ):
182
+ def prepare_style_modulations(y):
183
+ # Style enhancement part
184
+ z_ref = self.get_first_stage_encoding(warp_resize(reference, (height, width)))
185
+ if exists(background) and merge_scale > 0:
186
+ rh, rw = self.rescale_size(background, height, width)
187
+ z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
188
+ bg_emb, bg_cls_emb = self.get_learned_embedding(background)
189
+ scalar_embed = torch.cat(
190
+ self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1
191
+ ).to(bg_emb.device)
192
+ bgy = torch.cat([bg_cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)
193
+
194
+ style_modulations = self.style_encoder(
195
+ torch.cat([z_ref, z_bg]),
196
+ timesteps = torch.zeros((2,), dtype=torch.long, device=z_ref.device),
197
+ context = torch.cat([emb, bg_emb]),
198
+ y = torch.cat([y, bgy])
199
+ )
200
+
201
+ for idx, m in enumerate(style_modulations):
202
+ fg, bg = m.chunk(2)
203
+ m = fg * (1-merge_scale) + merge_scale * bg
204
+ style_modulations[idx] = expand_to_batch_size(m, bs).to(self.dtype)
205
+
206
+ else:
207
+ z_bg = None
208
+ bg_emb = None
209
+ bgy = None
210
+ style_modulations = self.style_encoder(
211
+ z_ref,
212
+ timesteps = torch.zeros((1,), dtype=torch.long, device=z_ref.device),
213
+ context = emb,
214
+ y = y,
215
+ )
216
+ style_modulations = [expand_to_batch_size(m, bs).to(self.dtype) for m in style_modulations]
217
+
218
+ return style_modulations, z_bg, bg_emb, bgy
219
+
220
+
221
        def prepare_background_latents(z_bg, bg_emb, bgy):
            """Produce background hidden states for background enhancement.

            Two modes:
              * latent inpaint with an explicit background — ``hs_bg`` is the
                first-stage latent of the resized-and-cropped background;
              * otherwise — ``hs_bg`` comes from ``self.bg_encoder`` fed with
                the background latent plus downsampled sketch/reference masks.

            Args:
                z_bg:  background latent reused from the style pass, or None.
                bg_emb: background image embedding, or None.
                bgy:   background conditioning vector, or None.

            Returns:
                Tuple ``(hs_bg, bg_emb)``.

            Reads closure variables: ``background``, ``reference``, ``h``,
            ``w``, ``height``, ``width``, ``low_vram``, ``latent_inpaint``,
            ``ct``, ``cl``, ``rmask``, ``smask``, ``mask_threshold_ref``.
            """
            # Background enhancement part: crop scale derived from the
            # background's (or reference's) native aspect vs. the target.
            bgh, bgw = background.shape[2:] if exists(background) else reference.shape[2:]
            ch, cw = get_crop_scale(h, w, bgh, bgw)

            if low_vram:
                self.low_vram_shift(["first", "cond", "img_embedder"])
            if latent_inpaint and exists(background):
                # Inpaint mode: hs_bg is a plain first-stage latent, not an
                # encoder feature list.
                hs_bg = self.get_first_stage_encoding(resize_and_crop(background, ch, cw, height, width))
                bg_emb, cls_emb = self.get_learned_embedding(background)

            else:
                if not exists(z_bg):
                    # Style pass did not supply a background latent — build
                    # bgy / z_bg / bg_emb here.
                    # NOTE(review): torch.tensor([ct, cl, ch, cw]) mixes what
                    # appear to be tensors (ct, cl) with crop scalars — verify
                    # this constructs the intended 1-D tensor on current torch.
                    bgy = torch.cat(
                        self.scalar_embedder(torch.tensor([ct, cl, ch, cw])).chunk(4), 1
                        # self.scalar_embedder(torch.tensor([bgh / bgw, h / w, ct, cl, ch, cw])).chunk(6), 1
                    ).to(self.dtype).cuda()

                    if exists(background):
                        # bgh, bgw = self.rescale_background_size(background, height, width)
                        z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
                        bg_emb, cls_emb = self.get_learned_embedding(background)
                        # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([bgh, bgw, ct, cl, h, w])).chunk(6), 1).cuda()
                        # bgy = torch.cat([cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)
                    else:
                        # No explicit background: synthesize one by whiting
                        # out the foreground of the reference via rmask.
                        xbg = torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference))
                        z_bg = self.get_first_stage_encoding(warp_resize(xbg, (height, width)))
                        bg_emb, cls_emb = self.get_learned_embedding(xbg)

                if low_vram:
                    self.low_vram_shift(["bg_encoder"])
                # Encoder input: background latent concatenated with the
                # sketch and reference masks downsampled to latent resolution
                # (0.125 = VAE spatial factor).
                hs_bg = self.bg_encoder(
                    x = torch.cat([
                        z_bg,
                        # torch.where(
                        #     smask > mask_threshold_sketch,
                        #     torch.zeros_like(smask),
                        #     F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                        # )
                        F.interpolate(warp_resize(smask, (height, width)), scale_factor=0.125),
                        F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                    ], 1),
                    timesteps = torch.zeros((1,), dtype=torch.long, device=z_bg.device),
                    # context = bg_emb,
                    y = bgy.to(self.dtype),
                )
            return hs_bg, bg_emb
268
+
269
        # --- Tail of the enclosing condition-preparation method (its `def`
        # line is above this excerpt). Builds the conditional dict `c` and the
        # two unconditional dicts `uc` consumed by the sampler. ---
        self.loras.recover_lora()
        # prepare reference embedding
        # manipulate = self.check_manipulate(target_scales)
        c = {}
        # uc[0]: zeroed context (drop reference); uc[1]: zeroed control
        # (drop sketch) — see the final three .update() calls.
        uc = [{}, {}]
        self.loras.switch_lora(False)
        # self.loras.recover_lora()

        if exists(reference):
            emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch)
            # rh, rw = reference.shape[2:]
            # rh, rw = self.rescale_background_size(reference, height, width)

        else:
            # No reference: use a zeroed embedding of the sketch as a
            # placeholder with the correct shape.
            emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))
            # rh, rw = torch.Tensor([height]), torch.Tensor([width])

        # Crop offsets default to 0 (top-left corner).
        ct, cl = torch.Tensor([0]), torch.Tensor([0])
        # h, w = torch.Tensor([height]), torch.Tensor([width])
        # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1).cuda()
        # y = torch.cat([cls_emb.squeeze(1), scalar_embed], 1)
        # y = self.scalar_embedder((h*w)**0.5).cuda()
        # y = torch.cat(self.scalar_embedder(torch.cat([h, w])).chunk(2), 1).cuda()
        # y embeds sqrt(target area) and a fixed aesthetic score of 7.0.
        h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
        y = torch.cat(self.scalar_embedder(torch.cat([(h * w) ** 0.5, score])).cuda().chunk(2), 1)

        z_bg, bg_emb, bgy = None, None, None

        # Style enhance part
        if style_enhance:
            style_modulations, z_bg, bg_emb, bgy = prepare_style_modulations(y)
            for d in [c] + uc:
                d.update({"style_modulations": style_modulations})

        # Foreground enhance part
        if fg_enhance:
            assert exists(smask) and exists(rmask)
            self.loras.switch_lora(True, "foreground")
            if low_vram:
                self.low_vram_shift(["first"])
            # Foreground latent: reference with the background whited out via
            # rmask, scaled by the disentangle factor.
            z_fg = self.get_first_stage_encoding(warp_resize(
                torch.where(rmask >= mask_threshold_ref, reference, torch.ones_like(reference)),
                (height, width)
            )) * fg_disentangle_scale
            # z_ref = default(z_ref, self.get_first_stage_encoding(warp_resize(reference, (height, width))))
            # self.loras.switch_lora(True, False)
            # NOTE(review): fg_disentangle_scale is applied to both the latent
            # above and the LoRA scales here — confirm double application is
            # intended.
            self.loras.adjust_lora_scales(fg_disentangle_scale, "foreground")
            if low_vram:
                self.low_vram_shift(["fg_encoder"])
            hs_fg = self.fg_encoder(
                z_fg,
                timesteps = torch.zeros((1,), dtype=torch.long, device=z_fg.device),
            )
            # hs_fg = [hs * fg_disentangle_scale for hs in hs_fg]
            hs_fg = expand_to_batch_size(hs_fg, bs)
            for d in [c] + uc:
                d.update({
                    "hs_fg": hs_fg,
                    "inject_mask": expand_to_batch_size(smask, bs),
                })
            # for d in [c] + uc:
            #     d.update({"z_fg": expand_to_batch_size(z_fg, bs)})

        # Background enhance part
        if bg_enhance:
            assert exists(rmask) and exists(smask)
            # if not self.controller.hooked:
            #     self.controller.register("read", self.model.diffusion_model)
            # self.loras.switch_lora(False, True)
            # Reuse style-pass outputs when available; fall back to y for bgy.
            hs_bg, bg_emb = prepare_background_latents(z_bg, bg_emb, default(bgy, y))
            self.loras.switch_lora(True, "background")
            if latent_inpaint and exists(background):
                # Inpaint: background latent goes only to the conditional dict.
                hs_bg = expand_to_batch_size(hs_bg, bs)
                c.update({"inpaint_bg": hs_bg})
            elif exists(self.controller):
                # self.loras.merge_lora()
                # Controller-driven injection: features were captured during
                # the bg_encoder pass; just refresh the controller state.
                self.controller.update()
            else:
                hs_bg = expand_to_batch_size(hs_bg, bs)
                for d in [c] + uc:
                    d.update({"hs_bg": hs_bg})

        elif exists(self.controller):
            # self.controller.reader_restore()
            # No background enhancement: drop any stale captured features.
            self.controller.clean()

        if fg_enhance or bg_enhance:
            # need to activate mask-guided split cross-attention: concat the
            # background (or duplicated reference) embedding onto the context.
            emb = torch.cat([emb, default(bg_emb, emb)], 1)
            smask = expand_to_batch_size(smask.to(self.dtype), bs)
            for d in [c] + uc:
                d.update({"mask": F.interpolate(smask, scale_factor=0.125), "threshold": mask_threshold_sketch})

        # if fg_enhance and bg_enhance:
        #     self.loras.switch_lora(True, True)
        sketch = sketch.to(self.dtype)
        context = expand_to_batch_size(emb, bs).to(self.dtype)
        y = expand_to_batch_size(y, bs).float()
        uc_context = torch.zeros_like(context)

        control = []
        uc_control = []
        if low_vram:
            self.low_vram_shift(["control_encoder"])
        # Encode sketch and its "null" counterpart (-1s) in one batched pass,
        # then split each level into conditional / unconditional halves.
        encoded_sketch = self.control_encoder(
            torch.cat([sketch, -torch.ones_like(sketch)], 0)
        )
        for idx, es in enumerate(encoded_sketch):
            es = es * control_scale[idx]
            ec, uec = es.chunk(2)
            control.append(expand_to_batch_size(ec, bs))
            uc_control.append(expand_to_batch_size(uec, bs))

        self.loras.merge_lora()
        c.update({"control": control, "context": [context], "y": [y]})
        uc[0].update({"control": control, "context": [uc_context], "y": [y]})
        uc[1].update({"control": uc_control, "context": [context], "y": [y]})
        return c, uc