Update webUI_ExtraSchedulers/scripts/clybius_dpmpp_4m_sde.py
Browse files
webUI_ExtraSchedulers/scripts/clybius_dpmpp_4m_sde.py
CHANGED
|
@@ -1,124 +1,110 @@
|
|
| 1 |
-
# by Clybius : github.com/Clybius/ComfyUI-Extra-Samplers/
|
| 2 |
-
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
sde_diff = momentum_func(phi_2 * d, vel_sde, time)
|
| 112 |
-
vel_sde = sde_diff
|
| 113 |
-
x = x + vel_sde
|
| 114 |
-
|
| 115 |
-
if eta:
|
| 116 |
-
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
| 117 |
-
|
| 118 |
-
denoised_1, denoised_2, denoised_3 = denoised, denoised_1, denoised_2
|
| 119 |
-
h_1, h_2, h_3 = h, h_1, h_2
|
| 120 |
-
|
| 121 |
-
if callback is not None:
|
| 122 |
-
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
| 123 |
-
|
| 124 |
-
return x
|
|
|
|
| 1 |
+
# by Clybius : github.com/Clybius/ComfyUI-Extra-Samplers/
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from tqdm.auto import trange
|
| 5 |
+
|
| 6 |
+
from k_diffusion.sampling import default_noise_sampler
|
| 7 |
+
|
| 8 |
+
@torch.no_grad()
|
| 9 |
+
def sample_clyb_4m_sde_momentumized(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None, momentum=0.0):
|
| 10 |
+
"""DPM-Solver++(3M) SDE, modified with an extra SDE, and momentumized in both the SDE and ODE(?). 'its a first' - Clybius 2023
|
| 11 |
+
The expression for d1 is derived from the extrapolation formula given in the paper “Diffusion Monte Carlo with stochastic Hamiltonians” by M. Foulkes, L. Mitas, R. Needs, and G. Rajagopal. The formula is given as follows:
|
| 12 |
+
d1 = d1_0 + (d1_0 - d1_1) * r2 / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r2 + r1) * (r0 + r1))
|
| 13 |
+
(if this is an incorrect citing, we blame Google's Bard and OpenAI's ChatGPT for this and NOT me :^) )
|
| 14 |
+
|
| 15 |
+
where d1_0, d1_1, and d1_2 are defined as follows:
|
| 16 |
+
d1_0 = (denoised - denoised_1) / r2
|
| 17 |
+
d1_1 = (denoised_1 - denoised_2) / r1
|
| 18 |
+
d1_2 = (denoised_2 - denoised_3) / r0
|
| 19 |
+
|
| 20 |
+
The variables r0, r1, and r2 are defined as follows:
|
| 21 |
+
r0 = h_3 / h_2
|
| 22 |
+
r1 = h_2 / h
|
| 23 |
+
r2 = h / h_1
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def momentum_func(diff, velocity, timescale=1.0, offset=-momentum / 2.0): # Diff is current diff, vel is previous diff
|
| 27 |
+
if velocity is None:
|
| 28 |
+
momentum_vel = diff
|
| 29 |
+
else:
|
| 30 |
+
momentum_vel = momentum * (timescale + offset) * velocity + (1 - momentum * (timescale + offset)) * diff
|
| 31 |
+
return momentum_vel
|
| 32 |
+
|
| 33 |
+
sigma_max = sigmas.max()
|
| 34 |
+
|
| 35 |
+
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
| 36 |
+
|
| 37 |
+
extra_args = {} if extra_args is None else extra_args
|
| 38 |
+
s_in = x.new_ones([x.shape[0]])
|
| 39 |
+
|
| 40 |
+
denoised_1, denoised_2, denoised_3 = None, None, None
|
| 41 |
+
h_1, h_2, h_3 = None, None, None
|
| 42 |
+
vel, vel_sde = None, None
|
| 43 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
| 44 |
+
time = sigmas[i] / sigma_max
|
| 45 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
| 46 |
+
|
| 47 |
+
if sigmas[i + 1] == 0:
|
| 48 |
+
# Denoising step
|
| 49 |
+
x = denoised
|
| 50 |
+
else:
|
| 51 |
+
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
| 52 |
+
h = s - t
|
| 53 |
+
h_eta = h * (eta + 1)
|
| 54 |
+
x_diff = momentum_func((-h_eta).expm1().neg() * denoised, vel, time)
|
| 55 |
+
vel = x_diff
|
| 56 |
+
x = torch.exp(-h_eta) * x + vel
|
| 57 |
+
|
| 58 |
+
if h_3 is not None:
|
| 59 |
+
r0 = h_1 / h
|
| 60 |
+
r1 = h_2 / h
|
| 61 |
+
r2 = h_3 / h
|
| 62 |
+
d1_0 = (denoised - denoised_1) / r0
|
| 63 |
+
d1_1 = (denoised_1 - denoised_2) / r1
|
| 64 |
+
d1_2 = (denoised_2 - denoised_3) / r2
|
| 65 |
+
# d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) + ((d1_0 - d1_1) * r2 / (r1 + r2) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r1 + r2) * (r0 + r1))
|
| 66 |
+
# d2 = (d1_0 - d1_1) / (r0 + r1) + ((d1_0 - d1_1) * r2 / (r1 + r2) - (d1_1 - d1_2) * r1 / (r0 + r1)) / ((r1 + r2) * (r0 + r1))
|
| 67 |
+
|
| 68 |
+
# r0 = h_3 / h_2
|
| 69 |
+
# r1 = h_2 / h
|
| 70 |
+
# r2 = h / h_1
|
| 71 |
+
# d1_0 = (denoised - denoised_1) / r2
|
| 72 |
+
# d1_1 = (denoised_1 - denoised_2) / r1
|
| 73 |
+
# d1_2 = (denoised_2 - denoised_3) / r0
|
| 74 |
+
d1 = d1_0 + (d1_0 - d1_1) * r2 / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r2 + r1) * (r0 + r1))
|
| 75 |
+
d2 = (d1_0 - d1_1) / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) / ((r2 + r1) * (r0 + r1))
|
| 76 |
+
phi_3 = h_eta.neg().expm1() / h_eta + 1
|
| 77 |
+
phi_4 = phi_3 / h_eta - 0.5
|
| 78 |
+
sde_diff = momentum_func(phi_3 * d1 - phi_4 * d2, vel_sde, time)
|
| 79 |
+
vel_sde = sde_diff
|
| 80 |
+
x = x + vel_sde
|
| 81 |
+
elif h_2 is not None:
|
| 82 |
+
r0 = h_1 / h
|
| 83 |
+
r1 = h_2 / h
|
| 84 |
+
d1_0 = (denoised - denoised_1) / r0
|
| 85 |
+
d1_1 = (denoised_1 - denoised_2) / r1
|
| 86 |
+
d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
|
| 87 |
+
d2 = (d1_0 - d1_1) / (r0 + r1)
|
| 88 |
+
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
| 89 |
+
phi_3 = phi_2 / h_eta - 0.5
|
| 90 |
+
sde_diff = momentum_func(phi_2 * d1 - phi_3 * d2, vel_sde, time)
|
| 91 |
+
vel_sde = sde_diff
|
| 92 |
+
x = x + vel_sde
|
| 93 |
+
elif h_1 is not None:
|
| 94 |
+
r = h_1 / h
|
| 95 |
+
d = (denoised - denoised_1) / r
|
| 96 |
+
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
| 97 |
+
sde_diff = momentum_func(phi_2 * d, vel_sde, time)
|
| 98 |
+
vel_sde = sde_diff
|
| 99 |
+
x = x + vel_sde
|
| 100 |
+
|
| 101 |
+
if eta:
|
| 102 |
+
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
| 103 |
+
|
| 104 |
+
denoised_1, denoised_2, denoised_3 = denoised, denoised_1, denoised_2
|
| 105 |
+
h_1, h_2, h_3 = h, h_1, h_2
|
| 106 |
+
|
| 107 |
+
if callback is not None:
|
| 108 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
| 109 |
+
|
| 110 |
+
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|