Spaces:
Running on Zero
Add refnet/models and ldm/models source files previously excluded by .gitignore
Browse files- ldm/models/autoencoder.py +37 -0
- refnet/models/basemodel.py +439 -0
- refnet/models/colorizerXL.py +201 -0
- refnet/models/v2-colorizerXL.py +386 -0
ldm/models/autoencoder.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AutoencoderKL(torch.nn.Module):
    """KL-regularized image autoencoder (inference-only slim variant).

    Pipeline: Encoder -> 1x1 quant conv -> diagonal Gaussian posterior,
    and post-quant 1x1 conv -> Decoder on the way back. No loss terms or
    training logic are included here.
    """

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # double_z is mandatory: the encoder must emit 2 * z_channels maps
        # (mean and log-variance) so the posterior below can be built.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Encode image batch ``x`` into a DiagonalGaussianDistribution posterior."""
        moments = self.quant_conv(self.encoder(x))
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode latent batch ``z`` back into image space."""
        return self.decoder(self.post_quant_conv(z))

    def get_last_layer(self):
        # Final decoder conv weight; conventionally used for adaptive loss
        # weighting by callers.
        return self.decoder.conv_out.weight

    @property
    def dtype(self):
        # Mirror the decoder's parameter dtype (fp16/fp32 switching is done
        # externally by casting the module).
        return self.decoder.conv_out.weight.dtype
|
refnet/models/basemodel.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from refnet.util import exists, fitting_weights, instantiate_from_config, load_weights, delete_states
|
| 4 |
+
from refnet.ldm import LatentDiffusion
|
| 5 |
+
from typing import Union
|
| 6 |
+
from refnet.sampling import (
|
| 7 |
+
UnetHook,
|
| 8 |
+
KDiffusionSampler,
|
| 9 |
+
DiffuserDenoiser,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GuidanceFlag:
    """Integer codes describing which classifier-free-guidance branches are on.

    The values are digit-positional so the flag can be built additively from
    two booleans: reference (1) + sketch (10) == both (11).
    """

    none = 0
    reference = 1
    sketch = 10
    both = 11
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def reconstruct_cond(cond, uncond):
    """Concatenate unconditional batches onto each conditional entry, in place.

    For every key in ``cond`` (except ``"inpaint_bg"``, which is shared and
    never CFG-expanded), each unconditional dict in ``uncond`` is appended
    along the batch dimension. Tensor entries are concatenated directly;
    list entries are concatenated element-wise.

    Returns the mutated ``cond`` dict.
    """
    unconds = uncond if isinstance(uncond, list) else [uncond]
    for key in cond:
        if key == "inpaint_bg":
            continue
        for uc in unconds:
            value = cond[key]
            if isinstance(value, list):
                cond[key] = [
                    torch.cat([value[i], uc[key][i]]) for i in range(len(value))
                ]
            elif isinstance(value, torch.Tensor):
                cond[key] = torch.cat([value, uc[key]])
    return cond
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CustomizedLDM(LatentDiffusion):
    """LatentDiffusion extended with low-VRAM module shifting, k-diffusion /
    diffusers sampling backends, and bookkeeping of the reference-attention
    transformer blocks used by the colorizer models."""

    def __init__(
        self,
        dtype = torch.float32,
        sigma_max = None,
        sigma_min = None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.dtype = dtype
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min

        # Named registry consumed by low_vram_shift() to move submodules
        # between CPU and CUDA on demand.
        self.model_list = {
            "first": self.first_stage_model,
            "cond": self.cond_stage_model,
            "unet": self.model,
        }
        self.switch_cond_modules = ["cond"]
        self.switch_main_modules = ["unet"]
        self.retrieve_attn_modules()
        self.retrieve_attn_layers()

    def init_from_ckpt(
        self,
        path,
        only_model = False,
        logging = False,
        make_it_fit = False,
        ignore_keys: list[str] = (),
    ):
        """Load a checkpoint from ``path`` and report missing/unexpected keys.

        Keys belonging to separately-initialized towers ("cond_stage_model",
        "img_embedder", "fg") are filtered out of the report; they are
        expected to be absent from the checkpoint.
        """
        sd = delete_states(load_weights(path), ignore_keys)
        if make_it_fit:
            sd = fitting_weights(self, sd)

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model \
            else self.model.load_state_dict(sd, strict=False)

        # Idiom fix: substring membership instead of chained `.find(s) > -1`.
        def _unrelated(key, needles):
            # True when the key belongs to none of the filtered submodules.
            return all(n not in key for n in needles)

        filtered_missing = [
            k for k in missing
            if _unrelated(k, ("cond_stage_model", "img_embedder", "fg"))
        ]
        filtered_unexpect = [
            k for k in unexpected
            if _unrelated(k, ("cond_stage_model", "img_embedder"))
        ]

        print(
            f"Restored from {path} with {len(filtered_missing)} filtered missing and "
            f"{len(filtered_unexpect)} filtered unexpected keys")
        if logging:
            if len(missing) > 0:
                print(f"Filtered missing Keys: {filtered_missing}")
            if len(unexpected) > 0:
                print(f"Filtered unexpected Keys: {filtered_unexpect}")

    def sample(
        self,
        cond: dict,
        uncond: Union[dict, list[dict]] = None,
        cfg_scale: Union[float, list[float]] = 1.,
        bs: int = 1,
        shape: Union[tuple, list] = None,
        step: int = 20,
        sampler = "DPM++ 3M SDE",
        scheduler = "Automatic",
        device = "cuda",
        x_T = None,
        seed = None,
        deterministic = False,
        **kwargs
    ):
        """Run the diffusion sampler and return latent samples.

        ``uncond``, when given, is folded into ``cond`` so conditional and
        unconditional branches run in one batched forward (see
        reconstruct_cond).
        """
        shape = shape or (self.channels, self.image_size, self.image_size)
        # BUGFIX: the original `x = x_T or torch.randn(...)` raises
        # "Boolean value of Tensor with more than one element is ambiguous"
        # whenever a real starting latent is supplied. Test existence instead.
        x = x_T if exists(x_T) else torch.randn(bs, *shape, device=device)

        if exists(uncond):
            cond = reconstruct_cond(cond, uncond)

        if sampler.startswith("diffuser"):
            # Using huggingface diffusers noise sampler and scheduler.
            sampler = DiffuserDenoiser(
                sampler,
                prediction_type = "v_prediction" if self.parameterization == "v" else "epsilon",
                use_karras = scheduler == "Karras"
            )
            samples = sampler(
                x,
                cond,
                cond_scale=cfg_scale,
                unet=self,
                timesteps=step,
                generator=torch.manual_seed(seed) if exists(seed) else None,
                device=device
            )
        else:
            # Using k-diffusion sampler and noise scheduler.
            seed = seed or torch.seed()
            sampler = KDiffusionSampler(sampler, scheduler, self, device)

            sigmas = sampler.get_sigmas(step)
            extra_args = {
                "cond": cond,
                "cond_scale": cfg_scale,
            }
            # Deterministic mode pins one seed per batch element.
            seed = [seed for _ in range(bs)] if deterministic else seed
            samples = sampler(x, sigmas, extra_args, seed, deterministic, step)

        return samples

    def switch_to_fp16(self):
        """Cast the UNet's conv/attention trunks to half precision in place."""
        unet = self.model.diffusion_model
        unet.input_blocks = unet.input_blocks.to(self.half_precision_dtype)
        unet.middle_block = unet.middle_block.to(self.half_precision_dtype)
        unet.output_blocks = unet.output_blocks.to(self.half_precision_dtype)
        self.dtype = self.half_precision_dtype
        unet.dtype = self.half_precision_dtype

    def switch_to_fp32(self):
        """Restore the UNet trunks to full precision in place."""
        unet = self.model.diffusion_model
        unet.input_blocks = unet.input_blocks.float()
        unet.middle_block = unet.middle_block.float()
        unet.output_blocks = unet.output_blocks.float()
        self.dtype = torch.float32
        unet.dtype = torch.float32

    def switch_vae_to_fp16(self):
        self.first_stage_model = self.first_stage_model.to(self.half_precision_dtype)

    def switch_vae_to_fp32(self):
        self.first_stage_model = self.first_stage_model.float()

    def low_vram_shift(self, cuda_list: Union[str, list[str]]):
        """Move the named submodules to CUDA and everything else to CPU.

        ``cuda_list`` holds keys of ``self.model_list`` (a bare string is
        treated as a one-element list).
        """
        if not isinstance(cuda_list, list):
            cuda_list = [cuda_list]

        cpu_list = self.model_list.keys() - cuda_list
        for model in cpu_list:
            self.model_list[model] = self.model_list[model].cpu()
        torch.cuda.empty_cache()

        for model in cuda_list:
            self.model_list[model] = self.model_list[model].cuda()

    def retrieve_attn_modules(self):
        """Collect the UNet's BasicTransformerBlocks and group them by depth.

        Index groups ("high"/"low"/"bottom"/"encoder"/"decoder") are
        hard-coded for this UNet layout; each depth level also gets a
        spatial scale factor used by the reference attention.
        """
        from refnet.modules.transformer import BasicTransformerBlock
        from refnet.sampling import torch_dfs

        scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}

        attn_modules = [
            module for module in torch_dfs(self.model.diffusion_model)
            if isinstance(module, BasicTransformerBlock)
        ]

        self.attn_modules = {
            "high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69],
            "low": [i for i in range(4, 24)] + [i for i in range(34, 64)],
            "bottom": [i for i in range(24, 34)],
            "encoder": [i for i in range(24)],
            "decoder": [i for i in range(34, len(attn_modules))]
        }
        self.attn_modules["modules"] = attn_modules

        for k in ["high", "low", "bottom"]:
            scale_factor = scale_factor_levels[k]
            for attn in self.attn_modules[k]:
                attn_modules[attn].scale_factor = scale_factor

    def retrieve_attn_layers(self):
        """Cache every cross-attention layer (attn2) of the collected blocks."""
        self.attn_layers = []
        for module in self.attn_modules["modules"]:
            if hasattr(module, "attn2") and exists(getattr(module, "attn2")):
                self.attn_layers.append(module.attn2)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class CustomizedColorizer(CustomizedLDM):
    """CustomizedLDM plus a sketch control encoder and an embedding projector."""

    def __init__(
        self,
        control_encoder_config,
        proj_config,
        token_type = "full",
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.control_encoder = instantiate_from_config(control_encoder_config)
        self.proj = instantiate_from_config(proj_config)
        self.token_type = token_type
        # Register the new submodules so low_vram_shift() can manage them too.
        self.model_list.update({
            "control_encoder": self.control_encoder,
            "proj": self.proj,
        })
        self.switch_cond_modules += ["control_encoder", "proj"]

    def switch_to_fp16(self):
        """Cast the control encoder along with the UNet trunks."""
        self.control_encoder = self.control_encoder.to(self.half_precision_dtype)
        super().switch_to_fp16()

    def switch_to_fp32(self):
        """Restore full precision for the control encoder and the UNet trunks."""
        self.control_encoder = self.control_encoder.float()
        super().switch_to_fp32()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
from refnet.modules.unet import hack_inference_forward
|
| 244 |
+
class CustomizedWrapper:
    """Inference-time mixin: CFG-window gating in apply_model(), reference /
    fg-bg scale adjustment, and the user-facing generate() entry point.

    Mixed into a CustomizedLDM subclass; relies on attributes provided there
    (self.model, self.attn_modules, self.attn_layers, self.num_timesteps, ...).
    """

    def __init__(self):
        self.scaling_sample = False
        # Fraction-of-schedule windows (0..1). Guidance runs inside
        # guidance_steps and is suppressed inside no_guidance_steps.
        self.guidance_steps = (0, 1)
        self.no_guidance_steps = (-0.05, 0.05)
        hack_inference_forward(self.model.diffusion_model)

    def adjust_reference_scale(self, scale_kwargs):
        """Write reference attention scales onto the transformer blocks.

        ``scale_kwargs`` is either a dict ({"level_control": bool,
        "scales": ...}) or a single scalar applied to every block.
        """
        if isinstance(scale_kwargs, dict):
            if scale_kwargs["level_control"]:
                # Per-level scales keyed by depth group name.
                for key in scale_kwargs["scales"]:
                    if key == "middle":
                        continue
                    for idx in self.attn_modules[key]:
                        self.attn_modules["modules"][idx].reference_scale = scale_kwargs["scales"][key]
            else:
                # Flat per-module list of scales.
                for idx, s in enumerate(scale_kwargs["scales"]):
                    self.attn_modules["modules"][idx].reference_scale = s
        else:
            for module in self.attn_modules["modules"]:
                module.reference_scale = scale_kwargs

    def adjust_fgbg_scale(self, fg_scale, bg_scale, merge_scale, mask_threshold):
        """Propagate foreground/background attention scales to every attn2."""
        for layer in self.attn_layers:
            layer.fg_scale = fg_scale
            layer.bg_scale = bg_scale
            layer.merge_scale = merge_scale
            layer.mask_threshold = mask_threshold

    def apply_model(self, x_noisy, t, cond):
        """UNet forward with the cross-attention context zeroed outside the
        configured guidance window (tr = normalized progress, 0 -> start)."""
        tr = 1 - t[0] / (self.num_timesteps - 1)
        crossattn = cond["context"][0]
        if ((tr < self.guidance_steps[0] or tr > self.guidance_steps[1]) or
                (self.no_guidance_steps[0] <= tr <= self.no_guidance_steps[1])):
            # Keep a single zero token so tensor shapes stay valid downstream.
            crossattn = torch.zeros_like(crossattn)[:, :1]
            cond["context"] = [crossattn]

        # "inpaint_bg" is consumed elsewhere, never by the UNet itself.
        model_cond = {k: v for k, v in cond.items() if k != "inpaint_bg"}
        return self.model(x_noisy, t, **model_cond)

    def prepare_conditions(self, *args, **kwargs):
        raise NotImplementedError("Inputs preprocessing function is not implemented.")

    def check_manipulate(self, scales):
        """True when at least one manipulation scale is positive."""
        if exists(scales) and len(scales) > 0:
            for scale in scales:
                if scale > 0:
                    return True
        return False

    @torch.inference_mode()
    def generate(
        self,
        # Conditional inputs
        cond: dict,
        ctl_scale: Union[float, list[float]],
        merge_scale: float,
        mask_scale: float,
        mask_thresh: float,
        mask_thresh_sketch: float,

        # Sampling settings
        sampler,
        scheduler,
        step: int,
        bs: int,
        gs: list[float],
        strength: Union[float, list[float]],
        fg_strength: float,
        bg_strength: float,
        seed: int,
        start_step: float = 0.0,
        end_step: float = 1.0,
        no_start_step: float = -0.05,
        no_end_step: float = -0.05,
        deterministic: bool = False,
        style_enhance: bool = False,
        bg_enhance: bool = False,
        fg_enhance: bool = False,
        latent_inpaint: bool = False,
        height: int = 512,
        width: int = 512,

        # Injection settings
        injection: bool = False,
        injection_cfg: float = 0.5,
        injection_control: float = 0,
        injection_start_step: float = 0,
        hook_xr: torch.Tensor = None,
        hook_xs: torch.Tensor = None,

        # Additional settings
        low_vram: bool = True,
        return_intermediate = False,
        manipulation_params = None,
        **kwargs,
    ):
        """
        User interface function.
        """
        hook_unet = UnetHook()

        self.guidance_steps = (start_step, end_step)
        self.no_guidance_steps = (no_start_step, no_end_step)
        self.adjust_reference_scale(strength)
        self.adjust_fgbg_scale(fg_strength, bg_strength, merge_scale, mask_thresh_sketch)

        if low_vram:
            self.low_vram_shift(self.switch_cond_modules)
        else:
            self.low_vram_shift(list(self.model_list.keys()))

        c, uc = self.prepare_conditions(
            bs = bs,
            control_scale = ctl_scale,
            merge_scale = merge_scale,
            mask_scale = mask_scale,
            mask_threshold_ref = mask_thresh,
            mask_threshold_sketch = mask_thresh_sketch,
            style_enhance = style_enhance,
            bg_enhance = bg_enhance,
            fg_enhance = fg_enhance,
            latent_inpaint = latent_inpaint,
            height = height,
            width = width,
            bg_strength = bg_strength,
            low_vram = low_vram,
            **cond,
            # BUGFIX: the default manipulation_params=None crashed on
            # **-unpacking with a TypeError; treat None as "no params".
            **(manipulation_params or {}),
            **kwargs
        )

        # Build the guidance flag additively from the two CFG scales.
        cfg = int(gs[0] > 1) * GuidanceFlag.reference + int(gs[1] > 1) * GuidanceFlag.sketch
        gr_indice = [] if (cfg == GuidanceFlag.none or cfg == GuidanceFlag.sketch) else [i for i in range(bs, bs*2)]
        repeat = 1
        if cfg == GuidanceFlag.none:
            gs = 1
            uc = None
        if cfg == GuidanceFlag.reference:
            gs = gs[0]
            uc = uc[0]
            repeat = 2
        if cfg == GuidanceFlag.sketch:
            gs = gs[1]
            uc = uc[1]
            repeat = 2
        if cfg == GuidanceFlag.both:
            repeat = 3

        if low_vram:
            self.low_vram_shift("first")

        if injection:
            rx = self.get_first_stage_encoding(hook_xr.to(self.first_stage_model.dtype))
            hook_unet.enhance_reference(
                model = self.model,
                ldm = self,
                bs = bs * repeat,
                s = -hook_xr.to(self.dtype),
                r = rx,
                style_cfg = injection_cfg,
                control_cfg = injection_control,
                gr_indice = gr_indice,
                start_step = injection_start_step,
            )

        if low_vram:
            self.low_vram_shift(self.switch_main_modules)

        z = self.sample(
            cond = c,
            uncond = uc,
            bs = bs,
            shape = (self.channels, height // 8, width // 8),
            cfg_scale = gs,
            step = step,
            sampler = sampler,
            scheduler = scheduler,
            seed = seed,
            deterministic = deterministic,
            return_intermediate = return_intermediate,
        )

        if injection:
            hook_unet.restore(self.model)

        if low_vram:
            self.low_vram_shift("first")
        return self.decode_first_stage(z.to(self.first_stage_model.dtype))
|
refnet/models/colorizerXL.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from ..modules.reference_net import hack_inference_forward
|
| 6 |
+
from ..models.basemodel import CustomizedColorizer, CustomizedWrapper
|
| 7 |
+
from ..modules.lora import LoraModules
|
| 8 |
+
from ..util import exists, expand_to_batch_size, instantiate_from_config, get_crop_scale, resize_and_crop
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class InferenceWrapper(CustomizedWrapper, CustomizedColorizer):
    """SDXL colorizer inference wrapper: adds scalar/image embedders, optional
    LoRA modules, and the prepare_conditions() implementation."""

    def __init__(
        self,
        scalar_embedder_config,
        img_embedder_config,
        lora_config = None,
        logits_embed = False,
        *args,
        **kwargs
    ):
        CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
        CustomizedWrapper.__init__(self)

        self.scalar_embedder = instantiate_from_config(scalar_embedder_config)
        self.img_embedder = instantiate_from_config(img_embedder_config)
        self.loras = LoraModules(self, **lora_config) if exists(lora_config) else None
        self.logits_embed = logits_embed

        # Register the extra conditioning modules for low-VRAM shifting.
        new_model_list = {
            "scalar_embedder": self.scalar_embedder,
            "img_embedder": self.img_embedder,
        }
        self.switch_cond_modules += list(new_model_list.keys())
        self.model_list.update(new_model_list)

    def retrieve_attn_modules(self):
        """Collect BasicTransformerBlocks and group them by UNet depth
        (same hard-coded index layout as the base class)."""
        scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}

        from refnet.modules.transformer import BasicTransformerBlock
        from refnet.sampling import torch_dfs

        attn_modules = [
            m for m in torch_dfs(self.model.diffusion_model)
            if isinstance(m, BasicTransformerBlock)
        ]

        self.attn_modules = {
            "high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69],
            "low": [i for i in range(4, 24)] + [i for i in range(34, 64)],
            "bottom": [i for i in range(24, 34)],
            "encoder": [i for i in range(24)],
            "decoder": [i for i in range(34, len(attn_modules))]
        }
        self.attn_modules["modules"] = attn_modules

        for level in ["high", "low", "bottom"]:
            factor = scale_factor_levels[level]
            for idx in self.attn_modules[level]:
                attn_modules[idx].scale_factor = factor

    def adjust_reference_scale(self, scale_kwargs):
        # Unlike the base class, a single ("encoder") scale drives all blocks.
        for module in self.attn_modules["modules"]:
            module.reference_scale = scale_kwargs["scales"]["encoder"]

    def adjust_masked_attn(self, scale, mask_threshold, merge_scale):
        """Propagate masked-attention parameters to every cross-attn layer."""
        for layer in self.attn_layers:
            layer.mask_scale = scale
            layer.mask_threshold = mask_threshold
            layer.merge_scale = merge_scale

    def rescale_size(self, x: torch.Tensor, height, width):
        """Compute an aspect-preserving size at least (height, width).

        Returns the target (ih, iw) wrapped in 1-element float tensors, as
        downstream embedders expect tensor scalars.
        """
        oh, ow = x.shape[2:]
        if oh < height or ow < width:
            dh, dw = height - oh, width - ow
            if dh > dw:
                iw = ow + int(dh * ow / oh)
                ih = height
            else:
                ih = oh + int(dw * oh / ow)
                iw = width
        else:
            ih, iw = oh, ow
        return torch.Tensor([ih]), torch.Tensor([iw])

    def get_learned_embedding(self, c, bg=False, mapping=False, sketch=None, *args, **kwargs):
        """Project CLIP + tagger embeddings of image ``c`` into UNet context.

        Returns (projected embedding, CLIP class token).
        """
        clip_emb = self.cond_stage_model.encode(c, "full").detach()
        wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
        cls_emb, local_emb = clip_emb[:, :1], clip_emb[:, 1:]

        if mapping:
            _, sketch_logits = self.img_embedder.encode(-sketch, pooled=False, return_logits=True)
            # NOTE(review): this result is discarded in the original code —
            # possibly meant to be assigned. Preserved as-is.
            sketch_logits.mean(dim=1, keepdim=True)
            logits = self.img_embedder.geometry_update(logits, sketch_logits)
        emb = self.proj(clip_emb, logits if self.logits_embed else wd_emb, bg)
        return emb, cls_emb

    def prepare_conditions(
        self,
        bs,
        sketch,
        reference,
        height,
        width,
        control_scale = (1., 1., 1., 1.),
        merge_scale = 0,
        mask_scale = 1.,
        fg_scale = 1.,
        bg_scale = 1.,
        smask = None,
        rmask = None,
        mask_threshold_ref = 0.,
        mask_threshold_sketch = 0.,
        style_enhance = False,
        fg_enhance = False,
        bg_enhance = False,
        background = None,
        targets = None,
        anchors = None,
        controls = None,
        target_scales = None,
        enhances = None,
        thresholds_list = None,
        geometry_map = False,
        latent_inpaint = False,
        low_vram = False,
        *args,
        **kwargs
    ):
        """Build (cond, [uncond_ref, uncond_sketch]) dicts for sampling."""
        c = {}
        uc = [{}, {}]

        # Reference embedding (zeros when no reference image is given).
        if exists(reference):
            emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch, mapping=geometry_map)
        else:
            emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))

        # SDXL-style scalar conditioning: resolution scale and a quality score.
        h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
        y = torch.cat(self.scalar_embedder(torch.cat([(h*w)**0.5, score])).cuda().chunk(2), 1)

        if bg_enhance:
            # NOTE(review): nesting of this region reconstructed from data
            # flow (rmask/bg_emb are only defined on this path) — confirm
            # against the original file.
            assert exists(rmask) and exists(smask)

            if low_vram:
                self.low_vram_shift(["first", "cond", "img_embedder", "proj"])

            if latent_inpaint and exists(background):
                # Encode the background into latents for inpainting.
                bgh, bgw = background.shape[2:]
                ch, cw = get_crop_scale(torch.tensor([height]), torch.tensor([width]), bgh, bgw)
                hs_bg = self.get_first_stage_encoding(
                    resize_and_crop(background, ch, cw, height, width).to(self.first_stage_model.dtype)
                )
                bg_emb, _ = self.get_learned_embedding(background, bg=True)
                hs_bg = expand_to_batch_size(hs_bg, bs)
                c.update({"inpaint_bg": hs_bg})
            else:
                if exists(background):
                    bg_emb, _ = self.get_learned_embedding(background, bg=True)
                else:
                    # Mask the reference to fabricate a background embedding.
                    bg_emb, _ = self.get_learned_embedding(
                        torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference)),
                        True
                    )
            emb = torch.cat([emb, bg_emb], 1)

        if fg_enhance and exists(self.loras):
            self.loras.switch_lora(True, "foreground")
            if not bg_enhance:
                # Keep token count consistent with the bg_enhance layout.
                emb = emb.repeat(1, 2, 1)

        if fg_enhance or bg_enhance:
            # Sketch mask for cross-attention (latent resolution = 1/8).
            smask = expand_to_batch_size(smask.to(self.dtype), bs)
            for d in [c] + uc:
                d.update({"mask": F.interpolate(smask, scale_factor=0.125)})
        elif exists(self.loras):
            self.loras.switch_lora(False)

        sketch = sketch.to(self.dtype)
        context = expand_to_batch_size(emb, bs).to(self.dtype)
        y = expand_to_batch_size(y, bs)
        uc_context = torch.zeros_like(context)

        # Encode the sketch and its "blank" counterpart in one batch, then
        # split into conditional / unconditional control feature lists.
        control = []
        uc_control = []
        if low_vram:
            self.low_vram_shift(["control_encoder"])
        encoded_sketch = self.control_encoder(
            torch.cat([sketch, -torch.ones_like(sketch)], 0)
        )
        for idx, es in enumerate(encoded_sketch):
            es = es * control_scale[idx]
            ec, uec = es.chunk(2)
            control.append(expand_to_batch_size(ec, bs))
            uc_control.append(expand_to_batch_size(uec, bs))

        c.update({"control": control, "context": [context], "y": [y]})
        uc[0].update({"control": control, "context": [uc_context], "y": [y]})
        uc[1].update({"control": uc_control, "context": [context], "y": [y]})
        return c, uc
|
refnet/models/v2-colorizerXL.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from refnet.models.basemodel import CustomizedColorizer, CustomizedWrapper
|
| 2 |
+
from refnet.util import *
|
| 3 |
+
from refnet.modules.lora import LoraModules
|
| 4 |
+
from refnet.modules.reference_net import hack_unet_forward, hack_inference_forward
|
| 5 |
+
from refnet.sampling.hook import ReferenceAttentionControl
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class InferenceWrapperXL(CustomizedWrapper, CustomizedColorizer):
    """SDXL reference-based colorizer inference wrapper.

    Combines the base colorizer (SDXL backbone) with optional auxiliary
    encoders (foreground / background / style), a scalar embedder, an image
    embedder, LoRA modules and an optional reference-attention controller.
    """

    def __init__(
        self,
        scalar_embedder_config,
        img_embedder_config,
        fg_encoder_config=None,
        bg_encoder_config=None,
        style_encoder_config=None,
        lora_config=None,
        logits_embed=False,
        controller=False,
        *args,
        **kwargs
    ):
        """Instantiate sub-modules from configs and hack UNet forwards.

        Args:
            scalar_embedder_config: config for the scalar (size/score) embedder.
            img_embedder_config: config for the image embedder.
            fg_encoder_config: optional config; foreground encoder is skipped if None.
            bg_encoder_config: optional config; background encoder is skipped if None.
            style_encoder_config: optional config for the style encoder.
            lora_config: optional kwargs for LoraModules (empty dict if None).
            logits_embed: if True, projection uses embedder logits instead of features.
            controller: if True, build a ReferenceAttentionControl over the bg encoder.
        """
        # Initialise both bases explicitly; CustomizedColorizer builds the SDXL backbone.
        CustomizedColorizer.__init__(self, version="sdxl", *args, **kwargs)
        CustomizedWrapper.__init__(self)

        self.logits_embed = logits_embed

        # Optional sub-modules stay None when their config is absent.
        (
            self.scalar_embedder,
            self.img_embedder,
            self.fg_encoder,
            self.bg_encoder,
            self.style_encoder,
        ) = map(
            lambda t: instantiate_from_config(t) if exists(t) else None,
            (
                scalar_embedder_config,
                img_embedder_config,
                fg_encoder_config,
                bg_encoder_config,
                style_encoder_config,
            ),
        )
        # Fix: lora_config defaults to None, and `**None` raises TypeError —
        # fall back to an empty kwargs dict.
        self.loras = LoraModules(self, **(lora_config or {}))

        if controller:
            # Reference attention: bg encoder writes, diffusion UNet reads.
            self.controller = ReferenceAttentionControl(
                reader_module=self.model.diffusion_model,
                writer_module=self.bg_encoder,
            )
        else:
            self.controller = None

        new_model_list = {
            "scalar_embedder": self.scalar_embedder,
            "img_embedder": self.img_embedder,
        }

        # Patch forward passes to support the extra conditioning inputs.
        hack_unet_forward(self.model.diffusion_model)
        if exists(self.fg_encoder):
            hack_inference_forward(self.fg_encoder)
            new_model_list["fg_encoder"] = self.fg_encoder
        if exists(self.bg_encoder):
            hack_inference_forward(self.bg_encoder)
            new_model_list["bg_encoder"] = self.bg_encoder

        # Register the new modules for device/precision switching.
        self.switch_cond_modules += list(new_model_list.keys())
        self.model_list.update(new_model_list)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def switch_to_fp16(self):
    """Move the colorizer to half precision for inference.

    Runs the base-class switch first, then casts the custom UNet
    sub-modules and the optional fg/bg encoders. Each encoder's
    ``time_embed`` is cast back to fp32 — presumably because timestep
    embeddings are numerically sensitive (kept as in the original).
    """
    super().switch_to_fp16()

    half = self.half_precision_dtype
    unet = self.model.diffusion_model
    for module in (unet.map_modules, unet.warp_modules, unet.style_modules, unet.conv_fg):
        module.to(half)

    for encoder in (self.fg_encoder, self.bg_encoder):
        if exists(encoder):
            encoder.to(half)
            encoder.dtype = half
            encoder.time_embed.float()
|
| 95 |
+
|
| 96 |
+
def switch_to_fp32(self):
    """Restore full (fp32) precision on all custom sub-modules.

    Fixes two inconsistencies with ``switch_to_fp16``:
    - fg/bg encoders are optional (their configs may be None in __init__),
      so they must be guarded with ``exists`` — the original crashed with
      AttributeError (`None.float()`) when an encoder was absent.
    - ``conv_fg`` is cast in fp16 but was not restored here; restore it too.
    """
    super().switch_to_fp32()
    self.model.diffusion_model.map_modules.float()
    self.model.diffusion_model.warp_modules.float()
    self.model.diffusion_model.style_modules.float()
    self.model.diffusion_model.conv_fg.float()

    if exists(self.fg_encoder):
        self.fg_encoder.float()
        self.fg_encoder.dtype = torch.float32
    if exists(self.bg_encoder):
        self.bg_encoder.float()
        self.bg_encoder.dtype = torch.float32
|
| 109 |
+
|
| 110 |
+
def rescale_size(self, x: torch.Tensor, height, width):
    """Return a (ih, iw) size that covers (height, width) while keeping x's aspect ratio.

    If x (NCHW) is already at least as large as the target in both
    dimensions, its own size is returned unchanged.

    Bug fix: the original chose the scaling dimension by absolute deficit
    (``dh > dw``), which can leave one output dimension below the target
    (e.g. oh=100, ow=10 scaled to height=160 gives iw=16 < width=59).
    The covering scale must be chosen by ratio: max(height/oh, width/ow).

    Returns:
        Two 1-element integer tensors (ih, iw).
    """
    oh, ow = x.shape[2:]
    if oh < height or ow < width:
        # Compare height/oh vs width/ow without float division.
        if height * ow >= width * oh:
            # Height ratio dominates: pin height, scale width proportionally.
            ih = height
            iw = max(width, int(ow * height / oh))
        else:
            # Width ratio dominates: pin width, scale height proportionally.
            iw = width
            ih = max(height, int(oh * width / ow))
    else:
        ih, iw = oh, ow
    return torch.tensor([ih]), torch.tensor([iw])
|
| 123 |
+
|
| 124 |
+
def rescale_background_size(self, x, height, width):
    """Return a biased-up (ih, iw) size for a background reference image.

    If x (NCHW) is smaller than the target in either dimension, its height
    is enlarged by max(height, width) (a deliberate bias — see original
    comment) and the width follows the aspect ratio; otherwise the original
    size is returned.

    Bug fix: ``iw = ow / oh * ih`` produced a Python float, so the returned
    tensors had mixed dtypes (int64 height vs float32 width), inconsistent
    with the sibling ``rescale_size``. Truncate to int as that method does.
    Also renamed the misleading local ``mind`` (it holds a max, not a min).

    Returns:
        Two 1-element integer tensors (ih, iw).
    """
    oh, ow = x.shape[2:]
    if oh < height or ow < width:
        # A simple bias to avoid deterioration caused by reference resolution
        bias = max(height, width)
        ih = oh + bias
        iw = int(ow / oh * ih)  # keep aspect ratio; int for dtype consistency
    else:
        ih, iw = oh, ow
    return torch.tensor([ih]), torch.tensor([iw])
|
| 135 |
+
|
| 136 |
+
def get_learned_embedding(self, c, bg=False, sketch=None, mapping=False, *args, **kwargs):
    """Encode image `c` into conditioning embeddings.

    Encodes `c` with both the CLIP conditioning model and the image
    embedder, then projects either the embedder logits (when
    ``self.logits_embed``) or its features through ``self.proj``.
    When logits are used together with a sketch and ``mapping=True``,
    the logits are first geometry-updated against the (inverted) sketch.

    Improvement: removed the unused local ``local_emb`` (the patch tokens
    ``clip_emb[:, 1:]`` were sliced but never read).

    Returns:
        (emb, cls_emb): projected embedding and the CLIP CLS-token
        embedding, both cast to ``self.dtype``.
    """
    clip_emb = self.cond_stage_model.encode(c, "full").detach()
    wd_emb, logits = self.img_embedder.encode(c, pooled=False, return_logits=True)
    cls_emb = clip_emb[:, :1]  # first token only; remaining tokens are unused here

    if self.logits_embed and exists(sketch) and mapping:
        # NOTE(review): the sketch pass uses pooled=True while the image pass
        # uses pooled=False — confirm this asymmetry is intended.
        _, sketch_logits = self.img_embedder.encode(-sketch, pooled=True, return_logits=True)
        logits = self.img_embedder.geometry_update(logits, sketch_logits)

    if self.logits_embed:
        emb = self.proj(clip_emb, logits, bg)[0]
    else:
        emb = self.proj(clip_emb, wd_emb, bg)
    return emb.to(self.dtype), cls_emb.to(self.dtype)
|
| 150 |
+
|
| 151 |
+
def prepare_conditions(
        self,
        bs,
        sketch,
        reference,
        height,
        width,
        control_scale = 1,
        mask_scale = 1,
        merge_scale = 0.,
        cond_aug = 0.,
        background = None,
        smask = None,
        rmask = None,
        mask_threshold_ref = 0.,
        mask_threshold_sketch = 0.,
        style_enhance = False,
        fg_enhance = False,
        bg_enhance = False,
        latent_inpaint = False,
        fg_disentangle_scale = 1.,
        targets = None,
        anchors = None,
        controls = None,
        target_scales = None,
        enhances = None,
        thresholds_list = None,
        low_vram = False,
        *args,
        **kwargs
):
    """Assemble conditional and unconditional inputs for sampling.

    Builds `c` (a dict) and `uc` (a list of two dicts: null-reference and
    null-control variants) containing "control", "context", "y" and, depending
    on the enhance flags, style modulations, fg/bg encoder features and
    cross-attention masks.

    Notes (review):
    - `control_scale` is indexed per control level near the end, so callers
      presumably pass a sequence despite the scalar default — TODO confirm.
    - `mask_scale`, `cond_aug`, `targets`, `anchors`, `controls`,
      `target_scales`, `enhances` and `thresholds_list` are accepted but not
      read in this body.
    - The two nested helpers close over `emb`, `ct`, `cl`, `h`, `w`, which are
      assigned *after* the helpers are defined but before they are called —
      valid because closures resolve names at call time.
    - With `low_vram`, modules are shifted on/off the GPU between stages, so
      the statement order is load-bearing.

    Returns:
        (c, uc): conditioning dict and list of two unconditional dicts.
    """
    def prepare_style_modulations(y):
        # Style enhancement part: run the style encoder on the reference latent
        # (optionally paired with the background latent) at timestep 0.
        z_ref = self.get_first_stage_encoding(warp_resize(reference, (height, width)))
        if exists(background) and merge_scale > 0:
            # Encode background too, then blend fg/bg modulations by merge_scale.
            rh, rw = self.rescale_size(background, height, width)
            z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
            bg_emb, bg_cls_emb = self.get_learned_embedding(background)
            # Scalar conditioning: resized size, crop offsets and target size,
            # embedded separately and concatenated along the feature dim.
            scalar_embed = torch.cat(
                self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1
            ).to(bg_emb.device)
            bgy = torch.cat([bg_cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)

            # Batch reference and background through the style encoder together.
            style_modulations = self.style_encoder(
                torch.cat([z_ref, z_bg]),
                timesteps = torch.zeros((2,), dtype=torch.long, device=z_ref.device),
                context = torch.cat([emb, bg_emb]),
                y = torch.cat([y, bgy])
            )

            # Linear blend of the fg/bg halves of each modulation tensor.
            for idx, m in enumerate(style_modulations):
                fg, bg = m.chunk(2)
                m = fg * (1-merge_scale) + merge_scale * bg
                style_modulations[idx] = expand_to_batch_size(m, bs).to(self.dtype)

        else:
            # Reference only — no background blending.
            z_bg = None
            bg_emb = None
            bgy = None
            style_modulations = self.style_encoder(
                z_ref,
                timesteps = torch.zeros((1,), dtype=torch.long, device=z_ref.device),
                context = emb,
                y = y,
            )
            style_modulations = [expand_to_batch_size(m, bs).to(self.dtype) for m in style_modulations]

        return style_modulations, z_bg, bg_emb, bgy


    def prepare_background_latents(z_bg, bg_emb, bgy):
        # Background enhancement part: produce bg features either by direct
        # latent inpainting or through the background encoder.
        bgh, bgw = background.shape[2:] if exists(background) else reference.shape[2:]
        ch, cw = get_crop_scale(h, w, bgh, bgw)

        if low_vram:
            self.low_vram_shift(["first", "cond", "img_embedder"])
        if latent_inpaint and exists(background):
            # Inpainting path: hs_bg is the cropped background latent itself.
            hs_bg = self.get_first_stage_encoding(resize_and_crop(background, ch, cw, height, width))
            bg_emb, cls_emb = self.get_learned_embedding(background)

        else:
            if not exists(z_bg):
                # z_bg/bgy were not prepared by the style stage — build them here.
                # NOTE(review): ct/cl are 1-element tensors; confirm
                # torch.tensor([...]) over them behaves as intended.
                bgy = torch.cat(
                    self.scalar_embedder(torch.tensor([ct, cl, ch, cw])).chunk(4), 1
                    # self.scalar_embedder(torch.tensor([bgh / bgw, h / w, ct, cl, ch, cw])).chunk(6), 1
                ).to(self.dtype).cuda()

                if exists(background):
                    # bgh, bgw = self.rescale_background_size(background, height, width)
                    z_bg = self.get_first_stage_encoding(warp_resize(background, (height, width)))
                    bg_emb, cls_emb = self.get_learned_embedding(background)
                    # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([bgh, bgw, ct, cl, h, w])).chunk(6), 1).cuda()
                    # bgy = torch.cat([cls_emb.squeeze(1), scalar_embed], 1).to(self.dtype)
                else:
                    # No explicit background: mask the reference below the
                    # threshold and treat the remainder as background.
                    xbg = torch.where(rmask < mask_threshold_ref, reference, torch.ones_like(reference))
                    z_bg = self.get_first_stage_encoding(warp_resize(xbg, (height, width)))
                    bg_emb, cls_emb = self.get_learned_embedding(xbg)

            if low_vram:
                self.low_vram_shift(["bg_encoder"])
            # Encoder input: bg latent plus downsampled sketch/reference masks
            # (0.125 matches the 8x latent downsampling factor).
            hs_bg = self.bg_encoder(
                x = torch.cat([
                    z_bg,
                    # torch.where(
                    #     smask > mask_threshold_sketch,
                    #     torch.zeros_like(smask),
                    #     F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                    # )
                    F.interpolate(warp_resize(smask, (height, width)), scale_factor=0.125),
                    F.interpolate(warp_resize(rmask, (height, width)), scale_factor=0.125)
                ], 1),
                timesteps = torch.zeros((1,), dtype=torch.long, device=z_bg.device),
                # context = bg_emb,
                y = bgy.to(self.dtype),
            )
        return hs_bg, bg_emb

    # Reset LoRA state before building conditions.
    self.loras.recover_lora()
    # prepare reference embedding
    # manipulate = self.check_manipulate(target_scales)
    c = {}
    uc = [{}, {}]
    self.loras.switch_lora(False)
    # self.loras.recover_lora()

    if exists(reference):
        emb, cls_emb = self.get_learned_embedding(reference, sketch=sketch)
        # rh, rw = reference.shape[2:]
        # rh, rw = self.rescale_background_size(reference, height, width)

    else:
        # No reference: use zeroed embeddings shaped like a sketch encoding.
        emb, cls_emb = map(lambda t: torch.zeros_like(t), self.get_learned_embedding(sketch))
        # rh, rw = torch.Tensor([height]), torch.Tensor([width])

    # Crop offsets are fixed at (0, 0) here.
    ct, cl = torch.Tensor([0]), torch.Tensor([0])
    # h, w = torch.Tensor([height]), torch.Tensor([width])
    # scalar_embed = torch.cat(self.scalar_embedder(torch.cat([rh, rw, ct, cl, h, w])).chunk(6), 1).cuda()
    # y = torch.cat([cls_emb.squeeze(1), scalar_embed], 1)
    # y = self.scalar_embedder((h*w)**0.5).cuda()
    # y = torch.cat(self.scalar_embedder(torch.cat([h, w])).chunk(2), 1).cuda()
    # Vector conditioning y: geometric-mean resolution and a fixed quality
    # score of 7 — presumably an aesthetic-score conditioning; TODO confirm.
    h, w, score = torch.Tensor([height]), torch.Tensor([width]), torch.Tensor([7.])
    y = torch.cat(self.scalar_embedder(torch.cat([(h * w) ** 0.5, score])).cuda().chunk(2), 1)

    z_bg, bg_emb, bgy = None, None, None

    # Style enhance part
    if style_enhance:
        style_modulations, z_bg, bg_emb, bgy = prepare_style_modulations(y)
        for d in [c] + uc:
            d.update({"style_modulations": style_modulations})

    # Foreground enhance part
    if fg_enhance:
        assert exists(smask) and exists(rmask)
        self.loras.switch_lora(True, "foreground")
        if low_vram:
            self.low_vram_shift(["first"])
        # Foreground latent: reference with the sub-threshold region whited out,
        # scaled by the disentangle factor.
        z_fg = self.get_first_stage_encoding(warp_resize(
            torch.where(rmask >= mask_threshold_ref, reference, torch.ones_like(reference)),
            (height, width)
        )) * fg_disentangle_scale
        # z_ref = default(z_ref, self.get_first_stage_encoding(warp_resize(reference, (height, width))))
        # self.loras.switch_lora(True, False)
        self.loras.adjust_lora_scales(fg_disentangle_scale, "foreground")
        if low_vram:
            self.low_vram_shift(["fg_encoder"])
        hs_fg = self.fg_encoder(
            z_fg,
            timesteps = torch.zeros((1,), dtype=torch.long, device=z_fg.device),
        )
        # hs_fg = [hs * fg_disentangle_scale for hs in hs_fg]
        hs_fg = expand_to_batch_size(hs_fg, bs)
        for d in [c] + uc:
            d.update({
                "hs_fg": hs_fg,
                "inject_mask": expand_to_batch_size(smask, bs),
            })
        # for d in [c] + uc:
        #     d.update({"z_fg": expand_to_batch_size(z_fg, bs)})

    # Background enhance part
    if bg_enhance:
        assert exists(rmask) and exists(smask)
        # if not self.controller.hooked:
        #     self.controller.register("read", self.model.diffusion_model)
        # self.loras.switch_lora(False, True)
        hs_bg, bg_emb = prepare_background_latents(z_bg, bg_emb, default(bgy, y))
        self.loras.switch_lora(True, "background")
        if latent_inpaint and exists(background):
            # Inpaint latent goes only into the conditional branch.
            hs_bg = expand_to_batch_size(hs_bg, bs)
            c.update({"inpaint_bg": hs_bg})
        elif exists(self.controller):
            # Features were written into the controller during the bg encoder pass.
            # self.loras.merge_lora()
            self.controller.update()
        else:
            hs_bg = expand_to_batch_size(hs_bg, bs)
            for d in [c] + uc:
                d.update({"hs_bg": hs_bg})

    elif exists(self.controller):
        # No bg enhancement this call: drop any stale controller state.
        # self.controller.reader_restore()
        self.controller.clean()

    if fg_enhance or bg_enhance:
        # need to activate mask-guided split cross-attention
        emb = torch.cat([emb, default(bg_emb, emb)], 1)
        smask = expand_to_batch_size(smask.to(self.dtype), bs)
        for d in [c] + uc:
            d.update({"mask": F.interpolate(smask, scale_factor=0.125), "threshold": mask_threshold_sketch})

    # if fg_enhance and bg_enhance:
    #     self.loras.switch_lora(True, True)
    sketch = sketch.to(self.dtype)
    context = expand_to_batch_size(emb, bs).to(self.dtype)
    y = expand_to_batch_size(y, bs).float()
    uc_context = torch.zeros_like(context)

    # Encode sketch and its negative (blank) together; split afterwards into
    # conditional and unconditional control features, scaled per level.
    control = []
    uc_control = []
    if low_vram:
        self.low_vram_shift(["control_encoder"])
    encoded_sketch = self.control_encoder(
        torch.cat([sketch, -torch.ones_like(sketch)], 0)
    )
    for idx, es in enumerate(encoded_sketch):
        es = es * control_scale[idx]
        ec, uec = es.chunk(2)
        control.append(expand_to_batch_size(ec, bs))
        uc_control.append(expand_to_batch_size(uec, bs))

    self.loras.merge_lora()
    # uc[0]: null reference context; uc[1]: null control. Both keep y.
    c.update({"control": control, "context": [context], "y": [y]})
    uc[0].update({"control": control, "context": [uc_context], "y": [y]})
    uc[1].update({"control": uc_control, "context": [context], "y": [y]})
    return c, uc
|