File size: 9,760 Bytes
3dabe4a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 | from importlib import import_module
from tqdm.auto import trange
import torch
# Host detection: figure out which UI this sampler file was loaded into and
# bind ``sampling`` to that host's k-diffusion sampling module.
sampling = None
BACKEND = None
INITIALIZED = False

# Each entry is (backend name, modules that must all import). The LAST module
# of an entry is the one exposed as ``sampling``. Entries are probed in order
# and the first one whose modules all import wins.
_BACKEND_PROBES = (
    ("WebUI", ("modules.sd_samplers_kdiffusion", "k_diffusion.sampling")),
    ("ComfyUI", ("comfy.k_diffusion.sampling",)),
)

for _backend_name, _module_paths in _BACKEND_PROBES:
    try:
        _loaded = [import_module(_path) for _path in _module_paths]
    except ImportError:
        # This host is not available; deliberately best-effort, try the next.
        continue
    sampling, BACKEND = _loaded[-1], _backend_name
    break
class _Rescaler:
    """Temporarily resize backend-held tensors to the spatial size of ``x``.

    On entry, every captured tensor that is not ``None`` is interpolated to
    ``x.shape[2:4]`` with ``mode`` and written back; on exit the original
    objects are put back on ``model``. Which tensors are handled depends on
    the module-level ``BACKEND``:

    * ``"WebUI"``: ``model.init_latent``, ``model.mask``, ``model.nmask``
    * ``"ComfyUI"``: ``model.latent_image``, ``model.noise`` and the
      ``"denoise_mask"`` entry of ``extra_args``

    For any other ``BACKEND`` value the context manager does nothing.

    NOTE(review): these look like the img2img / inpainting tensors the
    wrapped model blends with its input -- confirm against the host code.

    Args:
        model: Wrapped denoiser object carrying the attributes listed above.
        x: Tensor whose dimensions 2 and 3 define the target size.
        mode: Interpolation mode forwarded to
            ``torch.nn.functional.interpolate``.
        **extra_args: Keyword arguments intended for the model call. They are
            stored as ``self.extra_args``; callers should pass that dict to
            the model so a resized ``denoise_mask`` is picked up.
    """

    def __init__(self, model, x, mode, **extra_args):
        self.model = model
        self.x = x
        self.mode = mode
        # ``**extra_args`` is already a fresh dict, so replacing
        # "denoise_mask" in it never leaks back into the caller's dict.
        self.extra_args = extra_args
        if BACKEND == "WebUI":
            self.init_latent, self.mask, self.nmask = model.init_latent, model.mask, model.nmask
        if BACKEND == "ComfyUI":
            self.latent_image, self.noise = model.latent_image, model.noise
            self.denoise_mask = self.extra_args.get("denoise_mask", None)

    def _resize(self, tensor, unbatched=False):
        """Interpolate ``tensor`` to the spatial size of ``self.x``.

        With ``unbatched=True`` a leading dimension is added before the
        interpolation and removed afterwards.
        """
        if unbatched:
            tensor = tensor.unsqueeze(0)
        resized = torch.nn.functional.interpolate(input=tensor, size=self.x.shape[2:4], mode=self.mode)
        return resized.squeeze(0) if unbatched else resized

    def __enter__(self):
        if BACKEND == "WebUI":
            if self.init_latent is not None:
                self.model.init_latent = self._resize(self.init_latent)
            if self.mask is not None:
                self.model.mask = self._resize(self.mask, unbatched=True)
            if self.nmask is not None:
                self.model.nmask = self._resize(self.nmask, unbatched=True)
        if BACKEND == "ComfyUI":
            if self.latent_image is not None:
                self.model.latent_image = self._resize(self.latent_image)
            if self.noise is not None:
                # Fix: resize the noise tensor itself. The previous code
                # resized ``latent_image`` here, replacing the model's noise
                # with the latent image for the duration of the block.
                self.model.noise = self._resize(self.noise)
            if self.denoise_mask is not None:
                self.extra_args["denoise_mask"] = self._resize(self.denoise_mask)
        return self

    def __exit__(self, type, value, traceback):
        # Drop the resized tensors first, then restore the saved originals.
        # Returns None, so exceptions raised inside the block propagate.
        if BACKEND == "WebUI":
            del self.model.init_latent, self.model.mask, self.model.nmask
            self.model.init_latent, self.model.mask, self.model.nmask = self.init_latent, self.mask, self.nmask
        if BACKEND == "ComfyUI":
            del self.model.latent_image, self.model.noise
            self.model.latent_image, self.model.noise = self.latent_image, self.noise
@torch.no_grad()
def dy_sampling_step(x, model, dt, sigma_hat, **extra_args):
    """Euler-update only the bottom-right pixel of every 2x2 tile of ``x``.

    The pixels at odd row and odd column indices (position ``[1, 1]`` of each
    non-overlapping 2x2 tile) are gathered into a half-resolution tensor,
    denoised by ``model`` at ``sigma_hat``, advanced by one Euler step of size
    ``dt`` and written back to the same positions. Every other pixel,
    including a trailing odd row and/or column, is returned unchanged.

    Args:
        x: 4-D tensor; dimensions 2 and 3 are treated as height and width.
        model: Callable invoked as ``model(c, sigma, **extra_args)``.
        dt: Euler step size.
        sigma_hat: Noise level passed to the model and to ``sampling.to_d``.
        **extra_args: Extra keyword arguments for the model call.

    Returns:
        A new tensor with the same shape as ``x``; ``x`` itself is never
        modified.
    """
    m, n = x.shape[2] // 2, x.shape[3] // 2
    # Indices 1, 3, ..., 2m-1 (resp. 2n-1): the [1, 1] element of each tile.
    rows, cols = slice(1, 2 * m, 2), slice(1, 2 * n, 2)

    c = x[:, :, rows, cols].contiguous()
    with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
        denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
    d = sampling.to_d(c, sigma_hat, denoised)
    c = c + d * dt

    # Write into a clone rather than rebuilding the tensor from tiles. This
    # keeps a trailing odd row/column (and their shared corner pixel) exactly
    # as they were -- the previous reassembly overwrote that corner with the
    # pixel above it -- and guarantees the caller's tensor is not mutated.
    out = x.clone()
    out[:, :, rows, cols] = c
    return out
@torch.no_grad()
def sample_euler_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
                    s_tmax=float('inf'), s_noise=1.):
    """Euler sampler that applies ``dy_sampling_step`` on steps 2 and 3.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``.
        x: Starting tensor; updated step by step and returned.
        sigmas: Sequence of noise levels; ``len(sigmas) - 1`` steps are run.
        extra_args: Optional dict of keyword arguments for the model.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step, before
            the Euler update of that step.
        disable: Forwarded to ``trange`` as ``disable``.
        s_churn: Churn amount. The per-step ``gamma`` is the larger of
            ``s_churn / steps`` and ``sqrt(2) - 1``, so it is positive on
            every step whose sigma lies in ``[s_tmin, s_tmax]``.
        s_tmin: Lower sigma bound for applying churn.
        s_tmax: Upper sigma bound for applying churn.
        s_noise: Scale of the noise drawn for the churn.

    Returns:
        The tensor after the last step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        # i is 0 on the first step.
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = max(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Drawn on every step, even when gamma is 0 and the noise is unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        dt = sigma_next - sigma_hat
        if gamma > 0:
            x = x - eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = sampling.to_d(x, sigma_hat, denoised)
        # ``i // 2 == 1`` holds exactly for i == 2 and i == 3.
        if sigma_next > 0 and i in (2, 3):
            x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler step; ``d`` was computed before the optional dy step above.
        x = x + d * dt
    return x
@torch.no_grad()
def smea_sampling_step(x, model, dt, sigma_hat, **extra_args):
    """Run one Euler step on a 1.25x upscaled copy of ``x`` and scale it back.

    ``x`` is enlarged with ``nearest-exact`` interpolation, denoised by
    ``model`` at ``sigma_hat``, advanced by ``dt`` along the derivative from
    ``sampling.to_d`` and finally interpolated back to its original height
    and width.

    Args:
        x: 4-D tensor; dimensions 2 and 3 are treated as height and width.
        model: Callable invoked as ``model(x, sigma, **extra_args)``.
        dt: Euler step size.
        sigma_hat: Noise level passed to the model and to ``sampling.to_d``.
        **extra_args: Extra keyword arguments for the model call.

    Returns:
        A tensor with the same height and width as the input ``x``.
    """
    resize = torch.nn.functional.interpolate
    height, width = x.shape[2:4]

    upscaled = resize(input=x, scale_factor=(1.25, 1.25), mode='nearest-exact')
    with _Rescaler(model, upscaled, 'nearest-exact', **extra_args) as rescaler:
        sigma_batch = sigma_hat * upscaled.new_ones([upscaled.shape[0]])
        denoised = model(upscaled, sigma_batch, **rescaler.extra_args)
    d = sampling.to_d(upscaled, sigma_hat, denoised)
    upscaled = upscaled + d * dt
    return resize(input=upscaled, size=(height, width), mode='nearest-exact')
@torch.no_grad()
def sample_euler_smea_dy(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
                         s_tmax=float('inf'), s_noise=1.):
    """Euler sampler with an extra SMEA step on step 0 and a dy step on step 1.

    Each iteration performs the regular Euler update first. While the next
    sigma is positive, step 0 is followed by ``smea_sampling_step`` and
    step 1 by ``dy_sampling_step``.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``.
        x: Starting tensor; updated step by step and returned.
        sigmas: Sequence of noise levels; ``len(sigmas) - 1`` steps are run.
        extra_args: Optional dict of keyword arguments for the model.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` at the end of each step.
        disable: Forwarded to ``trange`` as ``disable``.
        s_churn: Churn amount. The per-step ``gamma`` is the larger of
            ``s_churn / steps`` and ``sqrt(2) - 1``, so it is positive on
            every step whose sigma lies in ``[s_tmin, s_tmax]``.
        s_tmin: Lower sigma bound for applying churn.
        s_tmax: Upper sigma bound for applying churn.
        s_noise: Scale of the noise drawn for the churn.

    Returns:
        The tensor after the last step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = max(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Drawn on every step, even when gamma is 0 and the noise is unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        dt = sigma_next - sigma_hat
        if gamma > 0:
            x = x - eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = sampling.to_d(x, sigma_hat, denoised)
        # Euler step
        x = x + d * dt
        if sigma_next > 0:
            # NOTE(review): these conditions were written ``i + 1 // 2 == 1``
            # and ``i + 1 // 2 == 0``. ``//`` binds tighter than ``+``, so
            # ``1 // 2`` is 0 and they reduce to the plain comparisons below.
            # If ``(i + 1) // 2`` was intended the schedule would differ --
            # confirm with the author before changing; behaviour is kept.
            if i == 1:
                x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
            if i == 0:
                x = smea_sampling_step(x, model, dt, sigma_hat, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
    return x
@torch.no_grad()
def sample_euler_negative(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
                          s_tmax=float('inf'), s_noise=1.):
    """Euler sampler that uses a sign-flipped update on steps 2 and 3.

    On steps 2 and 3 (while the next sigma is positive) the update is
    ``x = -x - d * dt``; on every other step it is the regular Euler update
    ``x = x + d * dt``.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``.
        x: Starting tensor; updated step by step and returned.
        sigmas: Sequence of noise levels; ``len(sigmas) - 1`` steps are run.
        extra_args: Optional dict of keyword arguments for the model.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step, before
            the update of that step.
        disable: Forwarded to ``trange`` as ``disable``.
        s_churn: Churn amount. The per-step ``gamma`` is the larger of
            ``s_churn / steps`` and ``sqrt(2) - 1``, so it is positive on
            every step whose sigma lies in ``[s_tmin, s_tmax]``.
        s_tmin: Lower sigma bound for applying churn.
        s_tmax: Upper sigma bound for applying churn.
        s_noise: Scale of the noise drawn for the churn.

    Returns:
        The tensor after the last step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        # i is 0 on the first step.
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = max(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Drawn on every step, even when gamma is 0 and the noise is unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        dt = sigma_next - sigma_hat
        if gamma > 0:
            x = x - eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = sampling.to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # ``i // 2 == 1`` holds exactly for i == 2 and i == 3.
        if sigma_next > 0 and i in (2, 3):
            x = -x - d * dt
        else:
            # Euler step
            x = x + d * dt
    return x
@torch.no_grad()
def sample_euler_dy_negative(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
                             s_tmax=float('inf'), s_noise=1.):
    """Euler sampler combining the dy step with a sign-flipped update.

    On steps 2 and 3 (while the next sigma is positive) ``x`` first goes
    through ``dy_sampling_step`` and is then updated as ``x = -x - d * dt``;
    on every other step the regular Euler update ``x = x + d * dt`` is used.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``.
        x: Starting tensor; updated step by step and returned.
        sigmas: Sequence of noise levels; ``len(sigmas) - 1`` steps are run.
        extra_args: Optional dict of keyword arguments for the model.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step, before
            the update of that step.
        disable: Forwarded to ``trange`` as ``disable``.
        s_churn: Churn amount. The per-step ``gamma`` is the larger of
            ``s_churn / steps`` and ``sqrt(2) - 1``, so it is positive on
            every step whose sigma lies in ``[s_tmin, s_tmax]``.
        s_tmin: Lower sigma bound for applying churn.
        s_tmax: Upper sigma bound for applying churn.
        s_noise: Scale of the noise drawn for the churn.

    Returns:
        The tensor after the last step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        # i is 0 on the first step.
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = max(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Drawn on every step, even when gamma is 0 and the noise is unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        dt = sigma_next - sigma_hat
        if gamma > 0:
            x = x - eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = sampling.to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # ``i // 2 == 1`` holds exactly for i == 2 and i == 3.
        if sigma_next > 0 and i in (2, 3):
            x = dy_sampling_step(x, model, dt, sigma_hat, **extra_args)
            x = -x - d * dt
        else:
            # Euler step
            x = x + d * dt
    return x
|