| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from tqdm.auto import trange |
|
|
|
|
@torch.no_grad()
def sample_dpmpp_2m_alt(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    *,
    adjustment_strength=0.15,
):
    """DPM-Solver++(2M) alt.

    Multistep DPM-Solver++(2M)-style update with one modification: the
    denoised estimate carried into the next step is scaled by
    ``1 + adjustment_strength * (i / len(sigmas)) ** 2``. Later steps
    therefore extrapolate against a slightly inflated history term.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``. It
            returns the denoised estimate for ``x``. ``sigma`` is passed as a
            vector of length ``x.shape[0]`` filled with the current level.
        x: Starting latent tensor.
        sigmas: Sequence of noise levels (tensor elements) to step through.
            One model call is made per consecutive pair.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable, invoked once per step with a dict of
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``tqdm.auto.trange`` (progress bar control).
        adjustment_strength: Keyword-only scale of the history adjustment.
            The default ``0.15`` reproduces the original behaviour. ``0``
            disables the adjustment, so the unscaled previous ``denoised``
            is used as history.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        # Inverse of t_fn: map solver time back to a noise level.
        return t.neg().exp()

    def t_fn(sigma):
        # Solver time is -log(sigma); sigma == 0 maps to +inf.
        return sigma.log().neg()

    num_sigmas = len(sigmas)
    old_denoised = None

    for i in trange(num_sigmas - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (no history yet) or the step down to sigma 0: update
            # from the current estimate only. With sigma_next == 0,
            # sigma_fn(t_next) == 0 and expm1(-inf) == -1, so x becomes `denoised`.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Multistep update: blend the current estimate with the (scaled)
            # estimate from the previous step, weighted by the step-size ratio.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        # Progress is i / len(sigmas) (not len - 1), so it stays below 1.
        sigma_progress = i / num_sigmas
        adjustment_factor = 1 + (adjustment_strength * (sigma_progress * sigma_progress))
        old_denoised = denoised * adjustment_factor
    return x
|
|
|
|
def add_sample_dpmpp_2m_alt_comfy() -> None:
    """Register ``dpmpp_2m_alt`` as a ComfyUI sampler when ComfyUI is importable.

    Attaches :func:`sample_dpmpp_2m_alt` to the ``k_diffusion_sampling`` module
    re-exported by ``comfy.samplers``. Then lists ``"dpmpp_2m_alt"`` directly
    after ``"dpmpp_2m"`` in ``KSampler.SAMPLERS``.

    Does nothing in these cases:
    - ``comfy.samplers`` cannot be imported.
    - The name is already registered.
    - ``"dpmpp_2m"`` is not present to anchor the insertion.

    Repeated calls are therefore harmless.
    """
    try:
        from comfy.samplers import KSampler, k_diffusion_sampling
    except ImportError:
        return  # Not running inside ComfyUI.

    samplers = KSampler.SAMPLERS
    if "dpmpp_2m_alt" in samplers:
        return  # Already registered.
    try:
        idx = samplers.index("dpmpp_2m")
    except ValueError:
        # Anchor sampler missing; deliberately skip registration (best effort).
        return

    # Attach the function before advertising its name, so the list never
    # refers to a sampler that does not exist. The module is deliberately not
    # reloaded: importlib.reload re-executes it and rebinds every function it
    # defines, which discards patches other extensions applied. A name added
    # with setattr would survive a reload anyway, so the reload bought nothing.
    setattr(k_diffusion_sampling, "sample_dpmpp_2m_alt", sample_dpmpp_2m_alt)
    samplers.insert(idx + 1, "dpmpp_2m_alt")
|
|
|
|
def add_custom_samplers():
    """Invoke each custom-sampler registration hook in this module, in order."""
    registration_hooks = (add_sample_dpmpp_2m_alt_comfy,)
    for register in registration_hooks:
        register()
|
|