# Provenance (hosting-page header preserved as comments so the file parses):
# uploader: dikdimon — "Upload 58 files" — commit 97923d1 (verified)
import torch
from tqdm.auto import trange
from lib_es.compat import get_ancestral_step, to_d
from lib_es.utils import default_noise_sampler, sampler_metadata
def sigma_fn(t):
    """Map log-space time t to noise level: sigma = exp(-t)."""
    return torch.exp(-t)
def t_fn(sigma):
    """Inverse of sigma_fn: t = -log(sigma)."""
    return -torch.log(sigma)
def phi1_fn(t):
    """First phi function of exponential integrators: (e^t - 1) / t."""
    return t.expm1() / t
def phi2_fn(t):
    """Second phi function of exponential integrators: (phi1(t) - 1) / t."""
    # phi1 inlined: phi1(t) = expm1(t) / t — same operations as calling the helper.
    return (torch.expm1(t) / t - 1.0) / t
@torch.no_grad()
def res_multistep(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_noise=1.0,
    noise_sampler=None,
    eta=1.0,
    cfg_pp=False,
):
    """Residual multistep sampler (second-order exponential integrator).

    Steps x through the noise schedule `sigmas`. The first step (and any step
    where `sigma_down == 0` or no history exists) is a plain Euler step; later
    steps use a two-step exponential-integrator update built from phi1/phi2
    coefficients and the previous step's denoised output.

    Args:
        model: denoiser callable `model(x, sigma_batch, **extra_args)`; when
            `cfg_pp` is True it is also expected to expose
            `need_last_noise_uncond` / `last_noise_uncond` attributes.
        x: initial latent tensor.
        sigmas: 1-D tensor of noise levels; iteration runs over consecutive pairs.
        extra_args: extra kwargs forwarded to `model` (default `{}`).
        callback: optional per-step callable receiving a progress dict.
        disable: passed to `trange` to suppress the progress bar.
        s_noise: scale applied to the ancestral re-noising term.
        noise_sampler: callable `(sigma_from, sigma_to) -> noise`; defaults to
            `default_noise_sampler(x)`.
        eta: ancestral step strength; 0 makes the sampler deterministic.
        cfg_pp: if True, use the model's unconditional noise prediction
            (CFG++ variant) instead of the combined output for the update.

    Returns:
        The final sampled tensor x.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    # Per-batch-item vector of ones used to broadcast a scalar sigma to the batch.
    s_in = x.new_ones([x.shape[0]])
    # Multistep history: previous step's sigma_down and denoised prediction.
    old_sigma_down = None
    old_denoised = None
    if cfg_pp:
        # Ask the model wrapper to record the unconditional noise prediction.
        model.need_last_noise_uncond = True
    for i in trange(len(sigmas) - 1, disable=disable):
        if cfg_pp:
            # Reset so a stale value from a previous call/step is never reused.
            model.last_noise_uncond = None
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # Split the step into a deterministic part (down to sigma_down) and a
        # stochastic re-noising part (sigma_up), controlled by eta.
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        uncond_d = model.last_noise_uncond if cfg_pp else None
        if cfg_pp and uncond_d is not None:
            # Reconstruct the unconditional denoised image from the recorded
            # noise prediction: x0_uncond = x - eps_uncond * sigma.
            sigma_batch = sigmas[i] * s_in
            sigma_view = sigma_batch[(...,) + (None,) * (x.ndim - 1)]
            uncond_denoised = x - uncond_d * sigma_view
        else:
            uncond_denoised = None
        if sigma_down == 0 or old_denoised is None or old_sigma_down is None:
            # First step, missing history, or final (sigma_down == 0) step:
            # fall back to a first-order Euler step.
            if cfg_pp:
                if uncond_denoised is None:
                    raise RuntimeError("CFG++ path requires model.last_noise_uncond")
                # CFG++ Euler: step along the unconditional noise direction.
                d = uncond_d
                x = denoised + d * sigma_down
            else:
                d = to_d(x, sigmas[i], denoised)
                dt = sigma_down - sigmas[i]
                x = x + d * dt
        else:
            # Second-order update in log-sigma time.  NOTE(review): c2 is built
            # from the previous step's sigma_down and sigmas[i-1] — presumably
            # to account for the ancestral shortening of the prior step; verify
            # against the reference res_multistep derivation.
            t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
            h = t_next - t
            c2 = (t_prev - t_old) / h
            phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
            # nan_to_num guards against c2 == 0 (repeated sigmas) producing NaNs.
            b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
            b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
            if cfg_pp:
                if uncond_denoised is None:
                    raise RuntimeError("CFG++ path requires model.last_noise_uncond")
                # Shift x by the guidance residual before the exponential update.
                x = x + (denoised - uncond_denoised)
                x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
            else:
                # sigma_fn(h) = exp(-h) = sigma_down / sigmas[i]: the decay factor.
                x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
        if sigmas[i + 1] > 0:
            # Ancestral re-noising back up to sigmas[i + 1].
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        # History for the next step uses the same prediction flavor as the update.
        old_denoised = uncond_denoised if cfg_pp else denoised
        old_sigma_down = sigma_down
    return x
@sampler_metadata(
    "Res Multistep",
    {"scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep(
    model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None
):
    """Deterministic res_multistep: eta fixed to 0, CFG++ disabled."""
    common = dict(
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
    )
    return res_multistep(model, x, sigmas, eta=0.0, cfg_pp=False, **common)
@sampler_metadata(
    "Res Multistep CFG++",
    {"scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep_cfg_pp(
    model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None
):
    """Deterministic res_multistep with the CFG++ update path enabled."""
    common = dict(
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
    )
    return res_multistep(model, x, sigmas, eta=0.0, cfg_pp=True, **common)
@sampler_metadata(
    "Res Multistep Ancestral",
    {"uses_ensd": True, "scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep_ancestral(
    model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None
):
    """Ancestral res_multistep: caller-controlled eta, CFG++ disabled."""
    common = dict(
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
    )
    return res_multistep(model, x, sigmas, eta=eta, cfg_pp=False, **common)
@sampler_metadata(
    "Res Multistep Ancestral CFG++",
    {"uses_ensd": True, "scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep_ancestral_cfg_pp(
    model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None
):
    """Ancestral res_multistep: caller-controlled eta, CFG++ enabled."""
    common = dict(
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
    )
    return res_multistep(model, x, sigmas, eta=eta, cfg_pp=True, **common)