|
|
import tqdm
|
|
|
import torch
|
|
|
import comfy.k_diffusion.sampling
|
|
|
from tqdm.auto import trange
|
|
|
from comfy.k_diffusion.sampling import to_d
|
|
|
|
|
|
|
|
|
# Module-level flag; it is never read or written anywhere in this file —
# presumably toggled by external registration/setup code. TODO confirm.
INITIALIZED = False
|
|
|
|
|
|
|
|
|
def default_noise_sampler(x):
    """Build the default noise sampler for latent *x*.

    Returns a callable ``sampler(sigma, sigma_next)`` that ignores both sigma
    arguments and draws fresh standard-Gaussian noise with the same shape,
    dtype and device as *x*.
    """
    def sampler(sigma, sigma_next):
        return comfy.k_diffusion.sampling.torch.randn_like(x)

    return sampler
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
    """TCD (Trajectory Consistency Distillation) sampler, DDPM/DDIM-style update.

    Walks ``sigmas`` from high to low; each step converts the model's denoised
    prediction into a noise estimate, applies a DDIM-like update in the
    variance-preserving parameterization, and (for gamma > 0) re-injects
    posterior-variance noise on all but the final/last steps.

    Args:
        model: denoiser wrapper; ``model.inner_model.inner_model.model_sampling``
            provides sigma<->timestep conversion (ComfyUI convention —
            TODO confirm the expected wrapper depth).
        x: current latent. NOTE: mutated in place (``x /=``, ``x +=``, ``x *=``),
            so the caller's tensor is modified.
        sigmas: 1-D sequence of noise levels, decreasing; last entry may be 0.
        extra_args: extra kwargs forwarded to ``model`` (default ``{}``).
        callback: optional per-step callback receiving a progress dict.
        disable: passed to ``trange`` to suppress the progress bar.
        noise_sampler: ``f(sigma, sigma_next) -> noise``; defaults to Gaussian
            noise shaped like ``x``.
        gamma: TCD stochasticity parameter; 0 disables noise re-injection.

    Returns:
        The latent after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    # Per-sample sigma broadcast vector (one entry per batch element).
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        sigma_from, sigma_to = sigmas[i], sigmas[i+1]

        # NOTE(review): t, t_s and sigma_to_s are computed but never used below —
        # the gamma-shifted target sigma of TCD is apparently not applied in this
        # variant (compare sample_tcd_euler_a, which does use it). Looks like
        # dead or unfinished code; confirm before removing.
        t = model.inner_model.inner_model.model_sampling.timestep(sigma_from)
        t_s = (1 - gamma) * t
        sigma_to_s = model.inner_model.inner_model.model_sampling.sigma(t_s)

        # Epsilon (noise) estimate recovered from the denoised prediction.
        noise_est = (x - denoised) / sigma_from
        # Rescale from the sigma-space ("EDM"-style) latent into the
        # variance-preserving parameterization for the DDPM-style update below.
        x /= torch.sqrt(1.0 + sigma_from ** 2.0)

        # alpha_bar values implied by the sigmas: alpha_bar = 1 / (sigma^2 + 1).
        alpha_cumprod = 1 / ((sigma_from * sigma_from) + 1)
        alpha_cumprod_prev = 1 / ((sigma_to * sigma_to) + 1)
        alpha = (alpha_cumprod / alpha_cumprod_prev)

        # DDIM/DDPM mean update toward the previous (less noisy) timestep.
        x = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())

        # NOTE(review): misleading name — this is True on the FINAL step
        # (sigma_to == 0), not the first one.
        first_step = sigma_to == 0
        last_step = i == len(sigmas) - 2

        # Re-inject posterior-variance noise, except when stepping to sigma 0,
        # on the last index, or when gamma disables stochasticity.
        if not first_step:
            if gamma > 0 and not last_step:
                noise = noise_sampler(sigma_from, sigma_to)
                # DDPM posterior variance for the implied alpha_bar schedule.
                variance = ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) * (1 - alpha_cumprod / alpha_cumprod_prev)
                x += variance.sqrt() * noise

        # Rescale back into sigma-space at the new noise level.
        x *= torch.sqrt(1.0 + sigma_to ** 2.0)

    return x
|
|
|
|
|
|
|
|
|
def sample_tcd_euler_a(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
    """TCD sampler with an Euler-ancestral style update.

    Each step takes a deterministic Euler step from ``sigma_from`` down to the
    gamma-shifted ``sigma_down`` and, while noise remains, re-injects the
    ancestral noise component ``sigma_up`` via the model's ``noise_scaling``.

    Args:
        model: denoiser wrapper; ``model.inner_model.inner_model.model_sampling``
            provides sigma<->timestep conversion and ``noise_scaling``
            (ComfyUI convention — TODO confirm the expected wrapper depth).
        x: current latent. NOTE: updated in place (``x += d * dt``), so the
            caller's tensor is mutated, matching the original implementation.
        sigmas: 1-D sequence of noise levels, decreasing; last entry may be 0.
        extra_args: extra kwargs forwarded to ``model`` (default ``{}``).
        callback: optional per-step callback receiving a progress dict.
        disable: passed to ``trange`` to suppress the progress bar.
        noise_sampler: ``f(sigma, sigma_next) -> noise``; defaults to Gaussian
            noise shaped like ``x``.
        gamma: TCD stochasticity parameter in [0, 1]; 0 gives fully
            deterministic steps.

    Returns:
        The latent after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    # Per-sample sigma broadcast vector (one entry per batch element).
    s_in = x.new_ones([x.shape[0]])
    # Hoisted out of the loop: the sigma<->timestep converter is loop-invariant.
    model_sampling = model.inner_model.inner_model.model_sampling
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        sigma_from = sigmas[i]
        sigma_to = sigmas[i + 1]

        # TCD: aim the deterministic step at the gamma-shifted timestep
        # (1 - gamma) * t instead of stepping all the way to sigma_to.
        t = model_sampling.timestep(sigma_from)
        down_t = (1 - gamma) * t
        sigma_down = model_sampling.sigma(down_t)

        # Clamp so we never target a sigma above the schedule's next value;
        # this also keeps the radicand below non-negative.
        if sigma_down > sigma_to:
            sigma_down = sigma_to

        # Ancestral split: the stochastic component that restores the total
        # noise level to sigma_to after the deterministic step to sigma_down.
        sigma_up = (sigma_to ** 2 - sigma_down ** 2) ** 0.5

        # Euler step toward sigma_down. (Bug fix: the original computed
        # `to_d(x, sigmas[i], denoised)` twice with identical arguments;
        # the redundant first call was removed.)
        d = to_d(x, sigma_from, denoised)
        dt = sigma_down - sigma_from
        x += d * dt

        # Re-inject noise unless this is the final (sigma_to == 0) step or
        # gamma disables stochasticity.
        if sigma_to > 0 and gamma > 0:
            x = model_sampling.noise_scaling(sigma_up, noise_sampler(sigma_from, sigma_to), x)

    return x