| import torch |
| from tqdm.auto import trange |
|
|
| from lib_es.compat import get_ancestral_step, to_d |
| from lib_es.utils import default_noise_sampler, sampler_metadata |
|
|
|
|
def sigma_fn(t):
    """Return ``exp(-t)`` element-wise for the tensor ``t``.

    This undoes ``t_fn`` (which computes ``-log(sigma)``), mapping a value in
    negative-log-sigma space back to sigma space.
    """
    negated = -t
    return negated.exp()
|
|
|
|
def t_fn(sigma):
    """Return ``-log(sigma)`` element-wise for the tensor ``sigma``.

    A zero sigma maps to ``+inf`` because ``log(0)`` is ``-inf``.
    """
    log_sigma = sigma.log()
    return -log_sigma
|
|
|
|
def phi1_fn(t):
    """Return ``phi1(t) = (exp(t) - 1) / t`` element-wise.

    The plain quotient evaluates ``0 / 0`` at ``t == 0`` and yields NaN even
    though the function has the finite limit ``1`` there. Zero entries are
    therefore masked out of the division and filled with that limit; every
    non-zero entry is computed exactly as ``torch.expm1(t) / t``.

    Args:
        t: Tensor of step sizes (any shape, including 0-dim).

    Returns:
        Tensor with the same shape as ``t`` holding ``phi1(t)``.
    """
    is_zero = t == 0
    # Swap in a harmless denominator where t is zero so no 0/0 is evaluated.
    safe_t = torch.where(is_zero, torch.ones_like(t), t)
    ratio = torch.expm1(safe_t) / safe_t
    return torch.where(is_zero, torch.ones_like(ratio), ratio)
|
|
|
|
def phi2_fn(t):
    """Return ``phi2(t) = (phi1(t) - 1) / t`` element-wise.

    Here ``phi1(t) = (exp(t) - 1) / t``. The plain quotient is NaN at
    ``t == 0`` although the function has the finite limit ``1/2`` there. Zero
    entries are masked out of the divisions and filled with that limit; every
    non-zero entry is computed exactly as
    ``(torch.expm1(t) / t - 1.0) / t``.

    Args:
        t: Tensor of step sizes (any shape, including 0-dim).

    Returns:
        Tensor with the same shape as ``t`` holding ``phi2(t)``.
    """
    is_zero = t == 0
    # Swap in a harmless denominator where t is zero so no 0/0 is evaluated.
    safe_t = torch.where(is_zero, torch.ones_like(t), t)
    phi1_val = torch.expm1(safe_t) / safe_t
    ratio = (phi1_val - 1.0) / safe_t
    return torch.where(is_zero, torch.full_like(ratio, 0.5), ratio)
|
|
|
|
@torch.no_grad()
def res_multistep(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_noise=1.0,
    noise_sampler=None,
    eta=1.0,
    cfg_pp=False,
):
    """Run the sampling loop shared by every "Res Multistep" sampler variant.

    For each consecutive pair in ``sigmas`` the model is evaluated once and
    ``x`` is advanced. The first iteration, and any iteration whose
    ``sigma_down`` is zero, uses a single-step update. All other iterations
    use a two-term update that weights the current and the previous denoised
    prediction with coefficients built from ``phi1_fn`` / ``phi2_fn``.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``.
            When ``cfg_pp`` is true, ``model.need_last_noise_uncond`` is set
            to ``True`` and ``model.last_noise_uncond`` is cleared before and
            read after every call.
        x: Starting latent tensor; its first dimension is used as batch size.
        sigmas: Indexable sequence of noise levels, iterated pairwise.
        extra_args: Optional keyword arguments forwarded to ``model``.
        callback: Optional callable that receives a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` once per
            step, after the model call and before ``x`` is updated.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_noise: Multiplier applied to the noise added after each step.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x)``.
        eta: Forwarded to ``get_ancestral_step``.
        cfg_pp: Use the CFG++ variant, which steps with the unconditional
            prediction read from ``model.last_noise_uncond``.

    Returns:
        The latent tensor after the final step.

    Raises:
        RuntimeError: If ``cfg_pp`` is true and the model left
            ``last_noise_uncond`` as ``None`` after a call.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    batch_ones = x.new_ones([x.shape[0]])

    # History carried from one iteration to the next for the two-term update.
    prev_sigma_down = None
    prev_denoised = None

    if cfg_pp:
        model.need_last_noise_uncond = True

    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[step]
        sigma_next = sigmas[step + 1]

        if cfg_pp:
            # Clear any stale value so a missing update is detected below.
            model.last_noise_uncond = None

        denoised = model(x, sigma_cur * batch_ones, **extra_args)
        # sigma_down is the target of the deterministic update; sigma_up
        # scales the noise added afterwards.
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigma_next, eta=eta)

        if callback is not None:
            callback({"x": x, "i": step, "sigma": sigma_cur, "sigma_hat": sigma_cur, "denoised": denoised})

        uncond_d = None
        uncond_denoised = None
        if cfg_pp:
            uncond_d = model.last_noise_uncond
            if uncond_d is None:
                raise RuntimeError("CFG++ path requires model.last_noise_uncond")
            # Append singleton axes so the per-sample sigma broadcasts over x.
            trailing_axes = (None,) * (x.ndim - 1)
            sigma_view = (sigma_cur * batch_ones)[(...,) + trailing_axes]
            uncond_denoised = x - uncond_d * sigma_view

        # Prediction used by the two-term update and remembered for the next
        # iteration: the unconditional one under CFG++, the plain one otherwise.
        tracked_denoised = uncond_denoised if cfg_pp else denoised

        if sigma_down == 0 or prev_denoised is None or prev_sigma_down is None:
            # Single-step update: no usable history yet, or the target is 0.
            if cfg_pp:
                x = denoised + uncond_d * sigma_down
            else:
                x = x + to_d(x, sigma_cur, denoised) * (sigma_down - sigma_cur)
        else:
            # Two-term update computed in t = -log(sigma) space.
            t = t_fn(sigma_cur)
            t_old = t_fn(prev_sigma_down)
            t_next = t_fn(sigma_down)
            t_prev = t_fn(sigmas[step - 1])
            h = t_next - t
            c2 = (t_prev - t_old) / h

            phi1_val = phi1_fn(-h)
            phi2_val = phi2_fn(-h)
            phi2_scaled = phi2_val / c2
            # NaN weights (e.g. from 0/0) are zeroed; infinities are clamped
            # to finite extremes by nan_to_num's defaults.
            b1 = torch.nan_to_num(phi1_val - phi2_scaled, nan=0.0)
            b2 = torch.nan_to_num(phi2_scaled, nan=0.0)

            if cfg_pp:
                x = x + (denoised - uncond_denoised)
            x = sigma_fn(h) * x + h * (b1 * tracked_denoised + b2 * prev_denoised)

        if sigma_next > 0:
            x = x + noise_sampler(sigma_cur, sigma_next) * s_noise * sigma_up

        prev_denoised = tracked_denoised
        prev_sigma_down = sigma_down

    return x
|
|
|
|
@sampler_metadata(
    "Res Multistep",
    {"scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep(
    model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None
):
    """Sample with ``res_multistep`` using ``eta=0.0`` and CFG++ disabled.

    All arguments are forwarded unchanged to ``res_multistep``; its return
    value is returned as-is.
    """
    forwarded = {
        "extra_args": extra_args,
        "callback": callback,
        "disable": disable,
        "s_noise": s_noise,
        "noise_sampler": noise_sampler,
    }
    return res_multistep(model, x, sigmas, eta=0.0, cfg_pp=False, **forwarded)
|
|
|
|
@sampler_metadata(
    "Res Multistep CFG++",
    {"scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep_cfg_pp(
    model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None
):
    """Sample with ``res_multistep`` using ``eta=0.0`` and CFG++ enabled.

    All arguments are forwarded unchanged to ``res_multistep``; its return
    value is returned as-is.
    """
    forwarded = {
        "extra_args": extra_args,
        "callback": callback,
        "disable": disable,
        "s_noise": s_noise,
        "noise_sampler": noise_sampler,
    }
    return res_multistep(model, x, sigmas, eta=0.0, cfg_pp=True, **forwarded)
|
|
|
|
@sampler_metadata(
    "Res Multistep Ancestral",
    {"uses_ensd": True, "scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep_ancestral(
    model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None
):
    """Sample with ``res_multistep`` using the caller's ``eta``, CFG++ disabled.

    All arguments, including ``eta`` (default ``1.0``), are forwarded
    unchanged to ``res_multistep``; its return value is returned as-is.
    """
    forwarded = {
        "extra_args": extra_args,
        "callback": callback,
        "disable": disable,
        "s_noise": s_noise,
        "noise_sampler": noise_sampler,
        "eta": eta,
    }
    return res_multistep(model, x, sigmas, cfg_pp=False, **forwarded)
|
|
|
|
@sampler_metadata(
    "Res Multistep Ancestral CFG++",
    {"uses_ensd": True, "scheduler": "sgm_uniform"},
)
@torch.no_grad()
def sample_res_multistep_ancestral_cfg_pp(
    model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None
):
    """Sample with ``res_multistep`` using the caller's ``eta``, CFG++ enabled.

    All arguments, including ``eta`` (default ``1.0``), are forwarded
    unchanged to ``res_multistep``; its return value is returned as-is.
    """
    forwarded = {
        "extra_args": extra_args,
        "callback": callback,
        "disable": disable,
        "s_noise": s_noise,
        "noise_sampler": noise_sampler,
        "eta": eta,
    }
    return res_multistep(model, x, sigmas, cfg_pp=True, **forwarded)
|
|