Commit
·
ad28782
1
Parent(s):
34d028e
Update repositories/k-diffusion/k_diffusion/sampling.py
Browse files
repositories/k-diffusion/k_diffusion/sampling.py
CHANGED
|
@@ -649,3 +649,39 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|
| 649 |
old_denoised = denoised
|
| 650 |
h_last = h
|
| 651 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
old_denoised = denoised
|
| 650 |
h_last = h
|
| 651 |
return x
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
@torch.no_grad()
|
| 655 |
+
def sample_dpmpp_2m_test(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
| 656 |
+
"""DPM-Solver++(2M)."""
|
| 657 |
+
extra_args = {} if extra_args is None else extra_args
|
| 658 |
+
s_in = x.new_ones([x.shape[0]])
|
| 659 |
+
sigma_fn = lambda t: t.neg().exp()
|
| 660 |
+
t_fn = lambda sigma: sigma.log().neg()
|
| 661 |
+
old_denoised = None
|
| 662 |
+
|
| 663 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
| 664 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
| 665 |
+
if callback is not None:
|
| 666 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
| 667 |
+
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
| 668 |
+
h = t_next - t
|
| 669 |
+
|
| 670 |
+
t_min = min(sigma_fn(t_next), sigma_fn(t))
|
| 671 |
+
t_max = max(sigma_fn(t_next), sigma_fn(t))
|
| 672 |
+
|
| 673 |
+
if old_denoised is None or sigmas[i + 1] == 0:
|
| 674 |
+
x = (t_min / t_max) * x - (-h).expm1() * denoised
|
| 675 |
+
else:
|
| 676 |
+
h_last = t - t_fn(sigmas[i - 1])
|
| 677 |
+
|
| 678 |
+
h_min = min(h_last, h)
|
| 679 |
+
h_max = max(h_last, h)
|
| 680 |
+
r = h_max / h_min
|
| 681 |
+
|
| 682 |
+
h_d = (h_max + h_min) / 2
|
| 683 |
+
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
| 684 |
+
x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
|
| 685 |
+
|
| 686 |
+
old_denoised = denoised
|
| 687 |
+
return x
|