dikdimon committed on
Commit
448a955
·
verified ·
1 Parent(s): 1056502

Update webUI_ExtraSchedulers/scripts/samplers_cfgpp.py

Browse files
webUI_ExtraSchedulers/scripts/samplers_cfgpp.py CHANGED
@@ -1,264 +1,278 @@
1
- import torch
2
- from tqdm.auto import trange
3
-
4
- # copied from kdiffusion/sampling.py and utils.py
5
- def default_noise_sampler(x):
6
- return lambda sigma, sigma_next: torch.randn_like(x)
7
- def get_ancestral_step(sigma_from, sigma_to, eta=1.):
8
- """Calculates the noise level (sigma_down) to step down to and the amount
9
- of noise to add (sigma_up) when doing an ancestral sampling step."""
10
- if not eta:
11
- return sigma_to, 0.
12
- sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
13
- sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
14
- return sigma_down, sigma_up
15
- def append_dims(x, target_dims):
16
- """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
17
- dims_to_append = target_dims - x.ndim
18
- if dims_to_append < 0:
19
- raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
20
- return x[(...,) + (None,) * dims_to_append]
21
- def to_d(x, sigma, denoised):
22
- """Converts a denoiser output to a Karras ODE derivative."""
23
- return (x - denoised) / append_dims(sigma, x.ndim)
24
-
25
-
26
- @torch.no_grad()
27
- def sample_euler_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
28
- """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
29
- extra_args = {} if extra_args is None else extra_args
30
- model.need_last_noise_uncond = True
31
- model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
32
- s_in = x.new_ones([x.shape[0]])
33
-
34
- for i in trange(len(sigmas) - 1, disable=disable):
35
- gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
36
- eps = torch.randn_like(x) * s_noise
37
- sigma_hat = sigmas[i] * (gamma + 1)
38
- if gamma > 0:
39
- x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
40
- denoised = model(x, sigma_hat * s_in, **extra_args)
41
- d = model.last_noise_uncond
42
-
43
- if callback is not None:
44
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
45
-
46
- # Euler method
47
- x = denoised + d * sigmas[i+1]
48
- return x
49
-
50
- class _Rescaler:
51
- def __init__(self, model, x, mode, **extra_args):
52
- self.model = model
53
- self.x = x
54
- self.mode = mode
55
- self.extra_args = extra_args
56
- self.init_latent, self.mask, self.nmask = model.init_latent, model.mask, model.nmask
57
-
58
- def __enter__(self):
59
- if self.init_latent is not None:
60
- self.model.init_latent = torch.nn.functional.interpolate(input=self.init_latent, size=self.x.shape[2:4], mode=self.mode)
61
- if self.mask is not None:
62
- self.model.mask = torch.nn.functional.interpolate(input=self.mask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
63
- if self.nmask is not None:
64
- self.model.nmask = torch.nn.functional.interpolate(input=self.nmask.unsqueeze(0), size=self.x.shape[2:4], mode=self.mode).squeeze(0)
65
-
66
- return self
67
-
68
- def __exit__(self, type, value, traceback):
69
- del self.model.init_latent, self.model.mask, self.model.nmask
70
- self.model.init_latent, self.model.mask, self.model.nmask = self.init_latent, self.mask, self.nmask
71
-
72
- @torch.no_grad()
73
- def dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args):
74
- original_shape = x.shape
75
- batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
76
- extra_row = x.shape[2] % 2 == 1
77
- extra_col = x.shape[3] % 2 == 1
78
-
79
- if extra_row:
80
- extra_row_content = x[:, :, -1:, :]
81
- x = x[:, :, :-1, :]
82
- if extra_col:
83
- extra_col_content = x[:, :, :, -1:]
84
- x = x[:, :, :, :-1]
85
-
86
- a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
87
- c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)
88
-
89
- with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
90
- denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
91
- d = model.last_noise_uncond
92
- c = denoised + d * sigma_hat
93
-
94
- d_list = c.view(batch_size, channels, m * n, 1, 1)
95
- a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
96
- x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)
97
-
98
- if extra_row or extra_col:
99
- x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
100
- x_expanded[:, :, :2 * m, :2 * n] = x
101
- if extra_row:
102
- x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
103
- if extra_col:
104
- x_expanded[:, :, :2 * m, -1:] = extra_col_content
105
- if extra_row and extra_col:
106
- x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
107
- x = x_expanded
108
-
109
- return x
110
-
111
- @torch.no_grad()
112
- def smea_sampling_step_cfgpp(x, model, sigma_hat, **extra_args):
113
- m, n = x.shape[2], x.shape[3]
114
- x = torch.nn.functional.interpolate(input=x, scale_factor=(1.25, 1.25), mode='nearest-exact')
115
- with _Rescaler(model, x, 'nearest-exact', **extra_args) as rescaler:
116
- denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
117
- d = model.last_noise_uncond
118
- x = denoised + d * sigma_hat
119
- x = torch.nn.functional.interpolate(input=x, size=(m,n), mode='nearest-exact')
120
- return x
121
-
122
-
123
- @torch.no_grad()
124
- def sample_euler_dy_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
125
- """CFG++ version of Euler Dy by KoishiStar."""
126
- extra_args = {} if extra_args is None else extra_args
127
- model.need_last_noise_uncond = True
128
- model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
129
- s_in = x.new_ones([x.shape[0]])
130
-
131
- for i in trange(len(sigmas) - 1, disable=disable):
132
- gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
133
- eps = torch.randn_like(x) * s_noise
134
- sigma_hat = sigmas[i] * (gamma + 1)
135
- if gamma > 0:
136
- x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
137
- denoised = model(x, sigma_hat * s_in, **extra_args)
138
- d = model.last_noise_uncond
139
-
140
- if callback is not None:
141
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
142
-
143
- # Euler method
144
- x = denoised + d * sigmas[i+1]
145
-
146
- if sigmas[i + 1] > 0:
147
- if i // 2 == 1:
148
- x = dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)
149
-
150
- return x
151
-
152
- @torch.no_grad()
153
- def sample_euler_negative_dy_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
154
- """CFG++ version of Euler Negative Dy by KoishiStar."""
155
- extra_args = {} if extra_args is None else extra_args
156
- model.need_last_noise_uncond = True
157
- model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
158
- s_in = x.new_ones([x.shape[0]])
159
-
160
- for i in trange(len(sigmas) - 1, disable=disable):
161
- gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
162
- eps = torch.randn_like(x) * s_noise
163
- sigma_hat = sigmas[i] * (gamma + 1)
164
- if gamma > 0:
165
- x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
166
- denoised = model(x, sigma_hat * s_in, **extra_args)
167
- d = model.last_noise_uncond
168
-
169
- if callback is not None:
170
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
171
-
172
- # Euler method
173
- if sigmas[i + 1] > 0 and i // 2 == 1:
174
- x = -denoised - d * sigmas[i+1]
175
- else:
176
- x = denoised + d * sigmas[i+1]
177
-
178
- if sigmas[i + 1] > 0:
179
- if i // 2 == 1:
180
- x = dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)
181
-
182
- return x
183
-
184
- @torch.no_grad()
185
- def sample_euler_negative_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
186
- """based on Euler Negative by KoishiStar"""
187
- extra_args = {} if extra_args is None else extra_args
188
- model.need_last_noise_uncond = True
189
- model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
190
- s_in = x.new_ones([x.shape[0]])
191
-
192
- for i in trange(len(sigmas) - 1, disable=disable):
193
- gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
194
- eps = torch.randn_like(x) * s_noise
195
- sigma_hat = sigmas[i] * (gamma + 1)
196
- if gamma > 0:
197
- x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
198
- denoised = model(x, sigma_hat * s_in, **extra_args)
199
- d = model.last_noise_uncond
200
-
201
- if callback is not None:
202
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
203
-
204
- # Euler method
205
- if sigmas[i + 1] > 0 and i // 2 == 1:
206
- x = -denoised - d * sigmas[i+1]
207
- else:
208
- x = denoised + d * sigmas[i+1]
209
- return x
210
-
211
-
212
- @torch.no_grad()
213
- def sample_euler_smea_dy_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
214
- """CFG++ version of Euler SMEA Dy by KoishiStar."""
215
- extra_args = {} if extra_args is None else extra_args
216
- model.need_last_noise_uncond = True
217
- model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
218
- s_in = x.new_ones([x.shape[0]])
219
-
220
- for i in trange(len(sigmas) - 1, disable=disable):
221
- gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
222
- eps = torch.randn_like(x) * s_noise
223
- sigma_hat = sigmas[i] * (gamma + 1)
224
- if gamma > 0:
225
- x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
226
- denoised = model(x, sigma_hat * s_in, **extra_args)
227
- d = model.last_noise_uncond
228
-
229
- if callback is not None:
230
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
231
-
232
- # Euler method
233
- x = denoised + d * sigmas[i+1]
234
-
235
- if sigmas[i + 1] > 0:
236
- if i + 1 // 2 == 1: # ?? this is i == 1; why not if i // 2 == 1 same as Euler Dy
237
- x = dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)
238
- if i + 1 // 2 == 0: # ?? this is i == 0
239
- x = smea_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)
240
- return x
241
-
242
- @torch.no_grad()
243
- def sample_euler_ancestral_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
244
- """Ancestral sampling with Euler method steps."""
245
- extra_args = {} if extra_args is None else extra_args
246
- noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
247
- model.need_last_noise_uncond = True
248
- model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
249
- s_in = x.new_ones([x.shape[0]])
250
-
251
- for i in trange(len(sigmas) - 1, disable=disable):
252
- denoised = model(x, sigmas[i] * s_in, **extra_args)
253
- d = model.last_noise_uncond
254
-
255
- sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
256
-
257
- if callback is not None:
258
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
259
-
260
- # Euler method
261
- x = denoised + d * sigma_down
262
- if sigmas[i + 1] > 0:
263
- x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
264
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm.auto import trange
3
+
4
+ from k_diffusion.sampling import (
5
+ default_noise_sampler,
6
+ get_ancestral_step,
7
+ )
8
+
9
+
10
@torch.no_grad()
def sample_euler_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    CFG++ variant: the step direction is the model's last *unconditional*
    noise estimate (``model.last_noise_uncond``) rather than the usual
    Karras ODE derivative, and each step recombines it with the denoised
    prediction as ``x = denoised + d * sigma_next``.

    Args:
        model: denoiser wrapper; must expose ``last_noise_uncond`` after a
            call once ``need_last_noise_uncond`` is set (Forge/WebUI wrapper).
        x: initial noisy latent, shape (batch, channels, H, W).
        sigmas: 1-D tensor of noise levels, highest first.
        extra_args: extra keyword args forwarded to every model call.
        callback: optional per-step progress callable.
        disable: disables the tqdm progress bar.
        s_churn, s_tmin, s_tmax, s_noise: Karras stochastic churn controls.

    Returns:
        The final latent after ``len(sigmas) - 1`` steps.
    """
    extra_args = {} if extra_args is None else extra_args
    # Ask the wrapper to record the uncond noise prediction on every call;
    # CFG++ needs it even at cfg scale 1, so that fast path is disabled.
    model.need_last_noise_uncond = True
    model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
    s_in = x.new_ones([x.shape[0]])

    if s_churn > 0.0:
        # Deterministic CPU generator for churn noise, seeded from one value
        # of the initial latent so a fixed input reproduces the same churn.
        seed = (int(x[0,0,0,0].item()) * 1234567890) % 65536
        generator = torch.Generator(device='cpu').manual_seed(seed)
    else:
        generator = None

    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Churn: add noise (drawn on CPU, moved to x's device/dtype) to
            # raise the effective sigma from sigmas[i] to sigma_hat, in place.
            eps = torch.randn(x.shape, generator=generator).to(x) * s_noise
            x.add_(eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5)
        denoised = model(x, sigma_hat * s_in, **extra_args)
        # CFG++: unconditional noise estimate recorded by the wrapper.
        d = model.last_noise_uncond

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})

        # Euler method (CFG++ form): x_{i+1} = denoised + d * sigma_{i+1}
        x = denoised + d * sigmas[i+1]
    return x
39
+
40
class _Rescaler:
    """Context manager that temporarily resizes the model's inpainting
    tensors (``init_latent``, ``mask``, ``nmask``) to match a working
    latent's spatial size, restoring the originals on exit.
    """

    def __init__(self, model, x, mode, **extra_args):
        self.model = model
        self.x = x
        self.mode = mode
        self.extra_args = extra_args
        # Keep references to the originals so __exit__ can restore them.
        self.init_latent = model.init_latent
        self.mask = model.mask
        self.nmask = model.nmask

    def _fit(self, tensor):
        # Resize a tensor to the (H, W) of the working latent.
        return torch.nn.functional.interpolate(input=tensor, size=self.x.shape[2:4], mode=self.mode)

    def __enter__(self):
        if self.init_latent is not None:
            self.model.init_latent = self._fit(self.init_latent)
        if self.mask is not None:
            self.model.mask = self._fit(self.mask.unsqueeze(0)).squeeze(0)
        if self.nmask is not None:
            self.model.nmask = self._fit(self.nmask.unsqueeze(0)).squeeze(0)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Drop the resized stand-ins, then put the originals back.
        del self.model.init_latent, self.model.mask, self.model.nmask
        self.model.init_latent = self.init_latent
        self.model.mask = self.mask
        self.model.nmask = self.nmask
61
+
62
@torch.no_grad()
def dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args):
    """Runs one CFG++ denoising pass on a half-resolution sub-grid of ``x``
    (the "Dy" step).

    Collects the bottom-right pixel of every 2x2 patch into a quarter-size
    latent, denoises it at ``sigma_hat``, and writes the result back into
    the full latent. A trailing odd row/column is split off beforehand and
    re-attached unchanged afterwards.
    """
    original_shape = x.shape
    # m, n: half the (even-cropped) spatial dimensions.
    batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
    extra_row = x.shape[2] % 2 == 1
    extra_col = x.shape[3] % 2 == 1

    # Peel off any trailing odd row/column so the rest tiles into 2x2 patches.
    if extra_row:
        extra_row_content = x[:, :, -1:, :]
        x = x[:, :, :-1, :]
    if extra_col:
        extra_col_content = x[:, :, :, -1:]
        x = x[:, :, :, :-1]

    # a_list: (B, C, m*n, 2, 2), one entry per 2x2 patch;
    # c: the [1, 1] (bottom-right) element of each patch as a (B, C, m, n) latent.
    a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
    c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)

    # Temporarily rescale the model's inpainting latents/masks to the
    # quarter-size grid for the duration of the model call.
    with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
        denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
    # CFG++ update at constant sigma using the uncond noise estimate.
    d = model.last_noise_uncond
    c = denoised + d * sigma_hat

    # Write the updated values back into the [1, 1] slot of each patch...
    d_list = c.view(batch_size, channels, m * n, 1, 1)
    a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
    # ...and fold the patches back into a (B, C, 2m, 2n) latent.
    x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)

    if extra_row or extra_col:
        # Re-attach the peeled-off odd row/column unchanged.
        x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
        x_expanded[:, :, :2 * m, :2 * n] = x
        if extra_row:
            x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
        if extra_col:
            x_expanded[:, :, :2 * m, -1:] = extra_col_content
        if extra_row and extra_col:
            # Corner pixel taken from the last row of the column content.
            x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
        x = x_expanded

    return x
100
+
101
@torch.no_grad()
def smea_sampling_step_cfgpp(x, model, sigma_hat, **extra_args):
    """Denoise once at 1.25x upscaled resolution, then shrink back
    (the "SMEA" step).

    The latent is enlarged by a factor of 1.25, denoised at ``sigma_hat``
    with the CFG++ update (``denoised + last_noise_uncond * sigma``), and
    interpolated back to its original spatial size.
    """
    height, width = x.shape[2], x.shape[3]
    upscaled = torch.nn.functional.interpolate(input=x, scale_factor=(1.25, 1.25), mode='nearest-exact')
    # Temporarily rescale the model's inpainting latents/masks to the
    # enlarged grid while the model is called.
    with _Rescaler(model, upscaled, 'nearest-exact', **extra_args) as rescaler:
        denoised = model(upscaled, sigma_hat * upscaled.new_ones([upscaled.shape[0]]), **rescaler.extra_args)
    noise_uncond = model.last_noise_uncond
    refined = denoised + noise_uncond * sigma_hat
    return torch.nn.functional.interpolate(input=refined, size=(height, width), mode='nearest-exact')
111
+
112
+
113
@torch.no_grad()
def sample_euler_dy_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """CFG++ version of Euler Dy by KoishiStar.

    Euler sampling with the CFG++ step rule (``x = denoised + d * sigma_next``
    where ``d`` is the model's last unconditional noise estimate), plus the
    half-resolution "Dy" refinement (``dy_sampling_step_cfgpp``) applied on
    the steps where ``i // 2 == 1`` (i.e. i == 2 or i == 3).

    Args and return value match ``sample_euler_cfgpp``.
    """
    extra_args = {} if extra_args is None else extra_args
    # Record the uncond noise prediction on every call; CFG++ needs it even
    # at cfg scale 1, so that fast path is disabled.
    model.need_last_noise_uncond = True
    model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
    s_in = x.new_ones([x.shape[0]])

    if s_churn > 0.0:
        # Deterministic CPU generator for churn noise, seeded from the latent
        # so a fixed input reproduces the same churn.
        seed = (int(x[0,0,0,0].item()) * 1234567890) % 65536
        generator = torch.Generator(device='cpu').manual_seed(seed)
    else:
        generator = None

    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Churn: add noise to raise the effective sigma to sigma_hat.
            # (Fixed a stray space — was `x .add_(...)` — to match siblings.)
            eps = torch.randn(x.shape, generator=generator).to(x) * s_noise
            x.add_(eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5)
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = model.last_noise_uncond

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})

        # Euler method (CFG++ form)
        x = denoised + d * sigmas[i+1]

        if sigmas[i + 1] > 0:
            if i // 2 == 1:
                # Half-resolution Dy refinement on early steps 2 and 3.
                x = dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)

    return x
147
+
148
@torch.no_grad()
def sample_euler_negative_dy_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """CFG++ version of Euler Negative Dy by KoishiStar.

    Like ``sample_euler_dy_cfgpp``, but on the steps where the Dy
    refinement fires (``i // 2 == 1``, i.e. i == 2 or i == 3, with a
    nonzero next sigma) the Euler update is negated before the
    refinement is applied.
    """
    extra_args = {} if extra_args is None else extra_args
    # Record the uncond noise prediction on every call; CFG++ needs it even
    # at cfg scale 1, so that fast path is disabled.
    model.need_last_noise_uncond = True
    model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
    s_in = x.new_ones([x.shape[0]])

    if s_churn > 0.0:
        # Deterministic CPU generator for churn noise, seeded from the latent.
        seed = (int(x[0,0,0,0].item()) * 1234567890) % 65536
        generator = torch.Generator(device='cpu').manual_seed(seed)
    else:
        generator = None

    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Churn: add noise in place to raise the effective sigma.
            eps = torch.randn(x.shape, generator=generator).to(x) * s_noise
            x.add_(eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5)
        denoised = model(x, sigma_hat * s_in, **extra_args)
        # CFG++: unconditional noise estimate recorded by the wrapper.
        d = model.last_noise_uncond

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})

        # Euler method (CFG++ form); negated on steps 2 and 3.
        if sigmas[i + 1] > 0 and i // 2 == 1:
            x = -denoised - d * sigmas[i+1]
        else:
            x = denoised + d * sigmas[i+1]

        if sigmas[i + 1] > 0:
            if i // 2 == 1:
                # Half-resolution Dy refinement on the same (negated) steps.
                x = dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)

    return x
185
+
186
@torch.no_grad()
def sample_euler_negative_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """based on Euler Negative by KoishiStar

    CFG++ Euler sampling in which the update is negated
    (``x = -denoised - d * sigma_next``) on steps where ``i // 2 == 1``
    (i.e. i == 2 or i == 3) with a nonzero next sigma. No Dy refinement.
    """
    extra_args = {} if extra_args is None else extra_args
    # Record the uncond noise prediction on every call; CFG++ needs it even
    # at cfg scale 1, so that fast path is disabled.
    model.need_last_noise_uncond = True
    model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
    s_in = x.new_ones([x.shape[0]])

    if s_churn > 0.0:
        # Deterministic CPU generator for churn noise, seeded from the latent.
        seed = (int(x[0,0,0,0].item()) * 1234567890) % 65536
        generator = torch.Generator(device='cpu').manual_seed(seed)
    else:
        generator = None

    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Churn: add noise in place to raise the effective sigma.
            eps = torch.randn(x.shape, generator=generator).to(x) * s_noise
            x.add_(eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5)
        denoised = model(x, sigma_hat * s_in, **extra_args)
        # CFG++: unconditional noise estimate recorded by the wrapper.
        d = model.last_noise_uncond

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})

        # Euler method (CFG++ form); negated on steps 2 and 3.
        if sigmas[i + 1] > 0 and i // 2 == 1:
            x = -denoised - d * sigmas[i+1]
        else:
            x = denoised + d * sigmas[i+1]
    return x
218
+
219
+
220
@torch.no_grad()
def sample_euler_smea_dy_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """CFG++ version of Euler SMEA Dy by KoishiStar.

    Euler sampling with the CFG++ step rule, applying the 1.25x-upscaled
    SMEA step after step 0 and the half-resolution Dy step after step 1
    (see the precedence note at the conditionals below).
    """
    extra_args = {} if extra_args is None else extra_args
    # Record the uncond noise prediction on every call; CFG++ needs it even
    # at cfg scale 1, so that fast path is disabled.
    model.need_last_noise_uncond = True
    model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
    s_in = x.new_ones([x.shape[0]])

    if s_churn > 0.0:
        # Deterministic CPU generator for churn noise, seeded from the latent.
        seed = (int(x[0,0,0,0].item()) * 1234567890) % 65536
        generator = torch.Generator(device='cpu').manual_seed(seed)
    else:
        generator = None

    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Churn: add noise in place to raise the effective sigma.
            eps = torch.randn(x.shape, generator=generator).to(x) * s_noise
            x.add_(eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5)
        denoised = model(x, sigma_hat * s_in, **extra_args)
        # CFG++: unconditional noise estimate recorded by the wrapper.
        d = model.last_noise_uncond

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})

        # Euler method (CFG++ form)
        x = denoised + d * sigmas[i+1]

        if sigmas[i + 1] > 0:
            # NOTE(review): `i + 1 // 2` parses as `i + (1 // 2)` == i, so
            # these conditions are effectively `i == 1` and `i == 0`.
            # Possibly `(i + 1) // 2` or `i // 2` (as in Euler Dy) was
            # intended — kept as-is to preserve current behavior; confirm
            # against upstream before changing.
            if i + 1 // 2 == 1: # ?? this is i == 1; why not if i // 2 == 1 same as Euler Dy
                x = dy_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)
            if i + 1 // 2 == 0: # ?? this is i == 0
                x = smea_sampling_step_cfgpp(x, model, sigma_hat, **extra_args)
    return x
255
+
256
@torch.no_grad()
def sample_euler_ancestral_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    CFG++ variant: each step takes a deterministic Euler step down to
    ``sigma_down`` using the model's last unconditional noise estimate,
    then (while the next sigma is nonzero) re-injects fresh noise scaled
    by ``sigma_up``.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    # Record the uncond noise prediction on every call; CFG++ needs it even
    # at cfg scale 1, so that fast path is disabled.
    model.need_last_noise_uncond = True
    model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
    sigma_batch = x.new_ones([x.shape[0]])

    total_steps = len(sigmas) - 1
    for i in trange(total_steps, disable=disable):
        denoised = model(x, sigmas[i] * sigma_batch, **extra_args)
        d = model.last_noise_uncond

        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        # Deterministic Euler step (CFG++ form) down to sigma_down...
        x = denoised + d * sigma_down
        if sigmas[i + 1] > 0:
            # ...then the ancestral part: add noise back up to sigma_{i+1}.
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x