|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
from tqdm.auto import trange
|
|
|
|
|
|
|
|
|
|
|
|
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    # The derivative is the (noise) residual divided by sigma, with sigma
    # broadcast up to the rank of x.
    sigma_b = append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_b
|
|
|
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    extra = target_dims - x.ndim
    if extra < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # Indexing with Ellipsis followed by `extra` Nones adds that many
    # trailing singleton dimensions.
    index = [Ellipsis]
    index.extend([None] * extra)
    return x[tuple(index)]
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_gradient_e(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    sigmas = sigmas.to(x.device)
    prev_d = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        dt = sigmas[i + 1] - sigmas[i]
        if prev_d is None:
            # First step: no history yet, take a plain Euler step.
            x = x + d * dt
        else:
            # Blend the current derivative with the previous one
            # (gradient estimation; ge_gamma == 1 reduces to Euler).
            d_bar = ge_gamma * d + (1 - ge_gamma) * prev_d
            x = x + d_bar * dt
        prev_d = d
    return x
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_gradient_e_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """Gradient-estimation sampler, CFG++ variant. Paper: https://openreview.net/pdf?id=o2ND9v0CeK

    Steps along the model's unconditional noise prediction
    (``model.last_noise_uncond``) instead of the Karras ODE derivative,
    reconstructing ``x`` from ``denoised`` at each step.

    Args:
        model: denoiser wrapper; must record ``last_noise_uncond`` on each call.
        x: initial noisy sample tensor.
        sigmas: 1-D tensor of noise levels for the schedule.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional dict-taking progress callback invoked each step.
        disable: forwarded to tqdm to disable the progress bar.
        ge_gamma: blend factor between the current and previous directions.

    Returns:
        The sampled tensor after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_d = None

    # NOTE(review): presumably this tells the wrapper to store the uncond
    # noise prediction, and keeps the uncond pass alive even when cfg == 1
    # so `last_noise_uncond` is populated — confirm against the wrapper.
    model.need_last_noise_uncond = True
    model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True

    # Keep sigmas on the same device as x, matching sample_gradient_e;
    # otherwise sigmas[i] * s_in can raise a device-mismatch error.
    sigmas = sigmas.to(x.device)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # CFG++ direction: the unconditional noise prediction from this call.
        d = model.last_noise_uncond
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if i == 0:
            # First step: no previous direction to blend with.
            x = denoised + d * sigmas[i + 1]
        else:
            # Blend current and previous directions (gradient estimation).
            d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
            x = denoised + d_bar * sigmas[i + 1]
        old_d = d
    return x
|
|
|
|