dikdimon commited on
Commit
236682d
·
verified ·
1 Parent(s): 09e9f65

Update webUI_ExtraSchedulers/scripts/clybius_dpmpp_4m_sde.py

Browse files
webUI_ExtraSchedulers/scripts/clybius_dpmpp_4m_sde.py CHANGED
@@ -1,124 +1,110 @@
1
- # by Clybius : github.com/Clybius/ComfyUI-Extra-Samplers/
2
-
3
- import math
4
-
5
- import torch
6
- from torch import nn, FloatTensor
7
- import torchsde
8
- import kornia
9
- from tqdm.auto import trange, tqdm
10
- import numpy as np
11
-
12
- import sample
13
-
14
- from k_diffusion.sampling import BrownianTreeNoiseSampler, PIDStepSizeController, get_ancestral_step, to_d, default_noise_sampler, DPMSolver
15
-
16
-
17
- # copied from kdiffusion/sampling.py and utils.py
18
- def default_noise_sampler(x):
19
- return lambda sigma, sigma_next: torch.randn_like(x)
20
-
21
-
22
- @torch.no_grad()
23
- def sample_clyb_4m_sde_momentumized(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None, momentum=0.0):
24
- """DPM-Solver++(3M) SDE, modified with an extra SDE, and momentumized in both the SDE and ODE(?). 'its a first' - Clybius 2023
25
- The expression for d1 is derived from the extrapolation formula given in the paper “Diffusion Monte Carlo with stochastic Hamiltonians” by M. Foulkes, L. Mitas, R. Needs, and G. Rajagopal. The formula is given as follows:
26
- d1 = d1_0 + (d1_0 - d1_1) * r2 / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r2 + r1) * (r0 + r1))
27
- (if this is an incorrect citing, we blame Google's Bard and OpenAI's ChatGPT for this and NOT me :^) )
28
-
29
- where d1_0, d1_1, and d1_2 are defined as follows:
30
- d1_0 = (denoised - denoised_1) / r2
31
- d1_1 = (denoised_1 - denoised_2) / r1
32
- d1_2 = (denoised_2 - denoised_3) / r0
33
-
34
- The variables r0, r1, and r2 are defined as follows:
35
- r0 = h_3 / h_2
36
- r1 = h_2 / h
37
- r2 = h / h_1
38
- """
39
-
40
- def momentum_func(diff, velocity, timescale=1.0, offset=-momentum / 2.0): # Diff is current diff, vel is previous diff
41
- if velocity is None:
42
- momentum_vel = diff
43
- else:
44
- momentum_vel = momentum * (timescale + offset) * velocity + (1 - momentum * (timescale + offset)) * diff
45
- return momentum_vel
46
-
47
- sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
48
-
49
- noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
50
-
51
- extra_args = {} if extra_args is None else extra_args
52
- s_in = x.new_ones([x.shape[0]])
53
-
54
- denoised_1, denoised_2, denoised_3 = None, None, None
55
- h_1, h_2, h_3 = None, None, None
56
- vel, vel_sde = None, None
57
- for i in trange(len(sigmas) - 1, disable=disable):
58
- time = sigmas[i] / sigma_max
59
- denoised = model(x, sigmas[i] * s_in, **extra_args)
60
-
61
- if sigmas[i + 1] == 0:
62
- # Denoising step
63
- x = denoised
64
- else:
65
- t, s = -sigmas[i].log(), -sigmas[i + 1].log()
66
- h = s - t
67
- h_eta = h * (eta + 1)
68
- x_diff = momentum_func((-h_eta).expm1().neg() * denoised, vel, time)
69
- vel = x_diff
70
- x = torch.exp(-h_eta) * x + vel
71
-
72
- if h_3 is not None:
73
- r0 = h_1 / h
74
- r1 = h_2 / h
75
- r2 = h_3 / h
76
- d1_0 = (denoised - denoised_1) / r0
77
- d1_1 = (denoised_1 - denoised_2) / r1
78
- d1_2 = (denoised_2 - denoised_3) / r2
79
- # d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) + ((d1_0 - d1_1) * r2 / (r1 + r2) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r1 + r2) * (r0 + r1))
80
- # d2 = (d1_0 - d1_1) / (r0 + r1) + ((d1_0 - d1_1) * r2 / (r1 + r2) - (d1_1 - d1_2) * r1 / (r0 + r1)) / ((r1 + r2) * (r0 + r1))
81
-
82
- # r0 = h_3 / h_2
83
- # r1 = h_2 / h
84
- # r2 = h / h_1
85
- # d1_0 = (denoised - denoised_1) / r2
86
- # d1_1 = (denoised_1 - denoised_2) / r1
87
- # d1_2 = (denoised_2 - denoised_3) / r0
88
- d1 = d1_0 + (d1_0 - d1_1) * r2 / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r2 + r1) * (r0 + r1))
89
- d2 = (d1_0 - d1_1) / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) / ((r2 + r1) * (r0 + r1))
90
- phi_3 = h_eta.neg().expm1() / h_eta + 1
91
- phi_4 = phi_3 / h_eta - 0.5
92
- sde_diff = momentum_func(phi_3 * d1 - phi_4 * d2, vel_sde, time)
93
- vel_sde = sde_diff
94
- x = x + vel_sde
95
- elif h_2 is not None:
96
- r0 = h_1 / h
97
- r1 = h_2 / h
98
- d1_0 = (denoised - denoised_1) / r0
99
- d1_1 = (denoised_1 - denoised_2) / r1
100
- d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
101
- d2 = (d1_0 - d1_1) / (r0 + r1)
102
- phi_2 = h_eta.neg().expm1() / h_eta + 1
103
- phi_3 = phi_2 / h_eta - 0.5
104
- sde_diff = momentum_func(phi_2 * d1 - phi_3 * d2, vel_sde, time)
105
- vel_sde = sde_diff
106
- x = x + vel_sde
107
- elif h_1 is not None:
108
- r = h_1 / h
109
- d = (denoised - denoised_1) / r
110
- phi_2 = h_eta.neg().expm1() / h_eta + 1
111
- sde_diff = momentum_func(phi_2 * d, vel_sde, time)
112
- vel_sde = sde_diff
113
- x = x + vel_sde
114
-
115
- if eta:
116
- x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
117
-
118
- denoised_1, denoised_2, denoised_3 = denoised, denoised_1, denoised_2
119
- h_1, h_2, h_3 = h, h_1, h_2
120
-
121
- if callback is not None:
122
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
123
-
124
- return x
 
1
+ # by Clybius : github.com/Clybius/ComfyUI-Extra-Samplers/
2
+
3
+ import torch
4
+ from tqdm.auto import trange
5
+
6
+ from k_diffusion.sampling import default_noise_sampler
7
+
8
+ @torch.no_grad()
9
+ def sample_clyb_4m_sde_momentumized(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None, momentum=0.0):
10
+ """DPM-Solver++(3M) SDE, modified with an extra SDE, and momentumized in both the SDE and ODE(?). 'its a first' - Clybius 2023
11
+ The expression for d1 is derived from the extrapolation formula given in the paper “Diffusion Monte Carlo with stochastic Hamiltonians” by M. Foulkes, L. Mitas, R. Needs, and G. Rajagopal. The formula is given as follows:
12
+ d1 = d1_0 + (d1_0 - d1_1) * r2 / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r2 + r1) * (r0 + r1))
13
+ (if this is an incorrect citing, we blame Google's Bard and OpenAI's ChatGPT for this and NOT me :^) )
14
+
15
+ where d1_0, d1_1, and d1_2 are defined as follows:
16
+ d1_0 = (denoised - denoised_1) / r2
17
+ d1_1 = (denoised_1 - denoised_2) / r1
18
+ d1_2 = (denoised_2 - denoised_3) / r0
19
+
20
+ The variables r0, r1, and r2 are defined as follows:
21
+ r0 = h_3 / h_2
22
+ r1 = h_2 / h
23
+ r2 = h / h_1
24
+ """
25
+
26
+ def momentum_func(diff, velocity, timescale=1.0, offset=-momentum / 2.0): # Diff is current diff, vel is previous diff
27
+ if velocity is None:
28
+ momentum_vel = diff
29
+ else:
30
+ momentum_vel = momentum * (timescale + offset) * velocity + (1 - momentum * (timescale + offset)) * diff
31
+ return momentum_vel
32
+
33
+ sigma_max = sigmas.max()
34
+
35
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
36
+
37
+ extra_args = {} if extra_args is None else extra_args
38
+ s_in = x.new_ones([x.shape[0]])
39
+
40
+ denoised_1, denoised_2, denoised_3 = None, None, None
41
+ h_1, h_2, h_3 = None, None, None
42
+ vel, vel_sde = None, None
43
+ for i in trange(len(sigmas) - 1, disable=disable):
44
+ time = sigmas[i] / sigma_max
45
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
46
+
47
+ if sigmas[i + 1] == 0:
48
+ # Denoising step
49
+ x = denoised
50
+ else:
51
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
52
+ h = s - t
53
+ h_eta = h * (eta + 1)
54
+ x_diff = momentum_func((-h_eta).expm1().neg() * denoised, vel, time)
55
+ vel = x_diff
56
+ x = torch.exp(-h_eta) * x + vel
57
+
58
+ if h_3 is not None:
59
+ r0 = h_1 / h
60
+ r1 = h_2 / h
61
+ r2 = h_3 / h
62
+ d1_0 = (denoised - denoised_1) / r0
63
+ d1_1 = (denoised_1 - denoised_2) / r1
64
+ d1_2 = (denoised_2 - denoised_3) / r2
65
+ # d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) + ((d1_0 - d1_1) * r2 / (r1 + r2) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r1 + r2) * (r0 + r1))
66
+ # d2 = (d1_0 - d1_1) / (r0 + r1) + ((d1_0 - d1_1) * r2 / (r1 + r2) - (d1_1 - d1_2) * r1 / (r0 + r1)) / ((r1 + r2) * (r0 + r1))
67
+
68
+ # r0 = h_3 / h_2
69
+ # r1 = h_2 / h
70
+ # r2 = h / h_1
71
+ # d1_0 = (denoised - denoised_1) / r2
72
+ # d1_1 = (denoised_1 - denoised_2) / r1
73
+ # d1_2 = (denoised_2 - denoised_3) / r0
74
+ d1 = d1_0 + (d1_0 - d1_1) * r2 / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) * r2 / ((r2 + r1) * (r0 + r1))
75
+ d2 = (d1_0 - d1_1) / (r2 + r1) + ((d1_0 - d1_1) * r2 / (r2 + r1) - (d1_1 - d1_2) * r1 / (r0 + r1)) / ((r2 + r1) * (r0 + r1))
76
+ phi_3 = h_eta.neg().expm1() / h_eta + 1
77
+ phi_4 = phi_3 / h_eta - 0.5
78
+ sde_diff = momentum_func(phi_3 * d1 - phi_4 * d2, vel_sde, time)
79
+ vel_sde = sde_diff
80
+ x = x + vel_sde
81
+ elif h_2 is not None:
82
+ r0 = h_1 / h
83
+ r1 = h_2 / h
84
+ d1_0 = (denoised - denoised_1) / r0
85
+ d1_1 = (denoised_1 - denoised_2) / r1
86
+ d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
87
+ d2 = (d1_0 - d1_1) / (r0 + r1)
88
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
89
+ phi_3 = phi_2 / h_eta - 0.5
90
+ sde_diff = momentum_func(phi_2 * d1 - phi_3 * d2, vel_sde, time)
91
+ vel_sde = sde_diff
92
+ x = x + vel_sde
93
+ elif h_1 is not None:
94
+ r = h_1 / h
95
+ d = (denoised - denoised_1) / r
96
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
97
+ sde_diff = momentum_func(phi_2 * d, vel_sde, time)
98
+ vel_sde = sde_diff
99
+ x = x + vel_sde
100
+
101
+ if eta:
102
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
103
+
104
+ denoised_1, denoised_2, denoised_3 = denoised, denoised_1, denoised_2
105
+ h_1, h_2, h_3 = h, h_1, h_2
106
+
107
+ if callback is not None:
108
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
109
+
110
+ return x