Commit
·
e47d403
1
Parent(s):
239ab44
Upload k_diffusion_dpmpp.diff
Browse files- k_diffusion_dpmpp.diff +145 -0
k_diffusion_dpmpp.diff
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/README.md b/README.md
|
| 2 |
+
index 4f7c92f..e386624 100644
|
| 3 |
+
--- a/README.md
|
| 4 |
+
+++ b/README.md
|
| 5 |
+
@@ -1,3 +1,12 @@
|
| 6 |
+
+# THIS IS A FORK
|
| 7 |
+
+
|
| 8 |
+
+Forked from https://github.com/crowsonkb/k-diffusion
|
| 9 |
+
+
|
| 10 |
+
+Changes:
|
| 11 |
+
+
|
| 12 |
+
+1. Add DPM++ 2M sampling fix by @hallatore https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
|
| 13 |
+
+2. Add MPS fix for MacOS by @brkirch https://github.com/brkirch/k-diffusion
|
| 14 |
+
+
|
| 15 |
+
# k-diffusion
|
| 16 |
+
|
| 17 |
+
An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
|
| 18 |
+
diff --git a/k_diffusion/external.py b/k_diffusion/external.py
|
| 19 |
+
index 79b51ce..b41d0eb 100644
|
| 20 |
+
--- a/k_diffusion/external.py
|
| 21 |
+
+++ b/k_diffusion/external.py
|
| 22 |
+
@@ -79,7 +79,9 @@ class DiscreteSchedule(nn.Module):
|
| 23 |
+
|
| 24 |
+
def t_to_sigma(self, t):
|
| 25 |
+
t = t.float()
|
| 26 |
+
- low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
| 27 |
+
+ low_idx = t.floor().long()
|
| 28 |
+
+ high_idx = t.ceil().long()
|
| 29 |
+
+ w = t - low_idx if t.device.type == 'mps' else t.frac()
|
| 30 |
+
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
| 31 |
+
return log_sigma.exp()
|
| 32 |
+
|
| 33 |
+
diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
|
| 34 |
+
index f050f88..9f859d4 100644
|
| 35 |
+
--- a/k_diffusion/sampling.py
|
| 36 |
+
+++ b/k_diffusion/sampling.py
|
| 37 |
+
@@ -16,7 +16,7 @@ def append_zero(x):
|
| 38 |
+
|
| 39 |
+
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
|
| 40 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 41 |
+
- ramp = torch.linspace(0, 1, n)
|
| 42 |
+
+ ramp = torch.linspace(0, 1, n, device=device)
|
| 43 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 44 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 45 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 46 |
+
@@ -400,7 +400,13 @@ class DPMSolver(nn.Module):
|
| 47 |
+
|
| 48 |
+
for i in range(len(orders)):
|
| 49 |
+
eps_cache = {}
|
| 50 |
+
- t, t_next = ts[i], ts[i + 1]
|
| 51 |
+
+
|
| 52 |
+
+ # MacOS fix
|
| 53 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
| 54 |
+
+ t, t_next = ts[i].detach().clone(), ts[i + 1].detach().clone()
|
| 55 |
+
+ else:
|
| 56 |
+
+ t, t_next = ts[i], ts[i + 1]
|
| 57 |
+
+
|
| 58 |
+
if eta:
|
| 59 |
+
sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
|
| 60 |
+
t_next_ = torch.minimum(t_end, self.t(sd))
|
| 61 |
+
@@ -512,7 +518,12 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
| 62 |
+
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
| 63 |
+
s_in = x.new_ones([x.shape[0]])
|
| 64 |
+
sigma_fn = lambda t: t.neg().exp()
|
| 65 |
+
- t_fn = lambda sigma: sigma.log().neg()
|
| 66 |
+
+
|
| 67 |
+
+ # MacOS fix
|
| 68 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
| 69 |
+
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
|
| 70 |
+
+ else:
|
| 71 |
+
+ t_fn = lambda sigma: sigma.log().neg()
|
| 72 |
+
|
| 73 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
| 74 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
| 75 |
+
@@ -547,7 +558,12 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
| 76 |
+
extra_args = {} if extra_args is None else extra_args
|
| 77 |
+
s_in = x.new_ones([x.shape[0]])
|
| 78 |
+
sigma_fn = lambda t: t.neg().exp()
|
| 79 |
+
- t_fn = lambda sigma: sigma.log().neg()
|
| 80 |
+
+
|
| 81 |
+
+ # MacOS fix
|
| 82 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
| 83 |
+
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
|
| 84 |
+
+ else:
|
| 85 |
+
+ t_fn = lambda sigma: sigma.log().neg()
|
| 86 |
+
|
| 87 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
| 88 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
| 89 |
+
@@ -587,7 +603,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
| 90 |
+
extra_args = {} if extra_args is None else extra_args
|
| 91 |
+
s_in = x.new_ones([x.shape[0]])
|
| 92 |
+
sigma_fn = lambda t: t.neg().exp()
|
| 93 |
+
- t_fn = lambda sigma: sigma.log().neg()
|
| 94 |
+
+
|
| 95 |
+
+ # MacOS fix
|
| 96 |
+
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
| 97 |
+
+ t_fn = lambda sigma: sigma.detach().clone().log().neg()
|
| 98 |
+
+ else:
|
| 99 |
+
+ t_fn = lambda sigma: sigma.log().neg()
|
| 100 |
+
+
|
| 101 |
+
old_denoised = None
|
| 102 |
+
|
| 103 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
| 104 |
+
@@ -596,12 +618,22 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
| 105 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
| 106 |
+
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
| 107 |
+
h = t_next - t
|
| 108 |
+
+
|
| 109 |
+
+ t_min = min(sigma_fn(t_next), sigma_fn(t))
|
| 110 |
+
+ t_max = max(sigma_fn(t_next), sigma_fn(t))
|
| 111 |
+
+
|
| 112 |
+
if old_denoised is None or sigmas[i + 1] == 0:
|
| 113 |
+
- x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
| 114 |
+
+ x = (t_min / t_max) * x - (-h).expm1() * denoised
|
| 115 |
+
else:
|
| 116 |
+
h_last = t - t_fn(sigmas[i - 1])
|
| 117 |
+
- r = h_last / h
|
| 118 |
+
+
|
| 119 |
+
+ h_min = min(h_last, h)
|
| 120 |
+
+ h_max = max(h_last, h)
|
| 121 |
+
+ r = h_max / h_min
|
| 122 |
+
+
|
| 123 |
+
+ h_d = (h_max + h_min) / 2
|
| 124 |
+
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
| 125 |
+
- x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
| 126 |
+
+ x = (t_min / t_max) * x - (-h_d).expm1() * denoised_d
|
| 127 |
+
+
|
| 128 |
+
old_denoised = denoised
|
| 129 |
+
return x
|
| 130 |
+
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
|
| 131 |
+
index 9afedb9..ce6014b 100644
|
| 132 |
+
--- a/k_diffusion/utils.py
|
| 133 |
+
+++ b/k_diffusion/utils.py
|
| 134 |
+
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
|
| 135 |
+
dims_to_append = target_dims - x.ndim
|
| 136 |
+
if dims_to_append < 0:
|
| 137 |
+
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
| 138 |
+
- return x[(...,) + (None,) * dims_to_append]
|
| 139 |
+
+ expanded = x[(...,) + (None,) * dims_to_append]
|
| 140 |
+
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
|
| 141 |
+
+ # https://github.com/pytorch/pytorch/issues/84364
|
| 142 |
+
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def n_params(module):
|