dikdimon committed on
Commit
1056502
·
verified ·
1 Parent(s): 2fbbd91

Update webUI_ExtraSchedulers/scripts/res_solver.py

Browse files
webUI_ExtraSchedulers/scripts/res_solver.py CHANGED
@@ -1,398 +1,379 @@
1
- import torch
2
- from torch import no_grad, FloatTensor
3
- from tqdm import tqdm
4
- from itertools import pairwise
5
- from typing import Protocol, Optional, Dict, Any, TypedDict, NamedTuple, Union, List
6
- import math
7
-
8
- from tqdm.auto import trange
9
-
10
- # copied from kdiffusion/sampling.py and utils.py
11
- def default_noise_sampler(x):
12
- return lambda sigma, sigma_next: torch.randn_like(x)
13
- def append_dims(x, target_dims):
14
- """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
15
- dims_to_append = target_dims - x.ndim
16
- if dims_to_append < 0:
17
- raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
18
- return x[(...,) + (None,) * dims_to_append]
19
- def to_d(x, sigma, denoised):
20
- """Converts a denoiser output to a Karras ODE derivative."""
21
- return (x - denoised) / append_dims(sigma, x.ndim)
22
-
23
-
24
- class DenoiserModel(Protocol):
25
- def __call__(self, x: FloatTensor, t: FloatTensor, *args, **kwargs) -> FloatTensor: ...
26
-
27
- class RefinedExpCallbackPayload(TypedDict):
28
- x: FloatTensor
29
- i: int
30
- sigma: FloatTensor
31
- sigma_hat: FloatTensor
32
-
33
- class RefinedExpCallback(Protocol):
34
- def __call__(self, payload: RefinedExpCallbackPayload) -> None: ...
35
-
36
- class NoiseSampler(Protocol):
37
- def __call__(self, x: FloatTensor) -> FloatTensor: ...
38
-
39
- class StepOutput(NamedTuple):
40
- x_next: FloatTensor
41
- denoised: FloatTensor
42
- denoised2: FloatTensor
43
- vel: FloatTensor
44
- vel_2: FloatTensor
45
-
46
- def _gamma(
47
- n: int,
48
- ) -> int:
49
- """
50
- https://en.wikipedia.org/wiki/Gamma_function
51
- for every positive integer n,
52
- Γ(n) = (n-1)!
53
- """
54
- return math.factorial(n-1)
55
-
56
- def _incomplete_gamma(
57
- s: int,
58
- x: float,
59
- gamma_s: Optional[int] = None
60
- ) -> float:
61
- """
62
- https://en.wikipedia.org/wiki/Incomplete_gamma_function#Special_values
63
- if s is a positive integer,
64
- Γ(s, x) = (s-1)!*∑{k=0..s-1}(x^k/k!)
65
- """
66
- if gamma_s is None:
67
- gamma_s = _gamma(s)
68
-
69
- sum_: float = 0
70
- # {k=0..s-1} inclusive
71
- for k in range(s):
72
- numerator: float = x**k
73
- denom: int = math.factorial(k)
74
- quotient: float = numerator/denom
75
- sum_ += quotient
76
- incomplete_gamma_: float = sum_ * math.exp(-x) * gamma_s
77
- return incomplete_gamma_
78
-
79
- # by Katherine Crowson
80
- def _phi_1(neg_h: FloatTensor):
81
- return torch.nan_to_num(torch.expm1(neg_h) / neg_h, nan=1.0)
82
-
83
- # by Katherine Crowson
84
- def _phi_2(neg_h: FloatTensor):
85
- return torch.nan_to_num((torch.expm1(neg_h) - neg_h) / neg_h**2, nan=0.5)
86
-
87
- # by Katherine Crowson
88
- def _phi_3(neg_h: FloatTensor):
89
- return torch.nan_to_num((torch.expm1(neg_h) - neg_h - neg_h**2 / 2) / neg_h**3, nan=1 / 6)
90
-
91
- def _phi(
92
- neg_h: float,
93
- j: int,
94
- ):
95
- """
96
- For j={1,2,3}: you could alternatively use Kat's phi_1, phi_2, phi_3 which perform fewer steps
97
-
98
- Lemma 1
99
- https://arxiv.org/abs/2308.02157
100
- ϕj(-h) = 1/h^j*∫{0..h}(e^(τ-h)*(τ^(j-1))/((j-1)!)dτ)
101
-
102
- https://www.wolframalpha.com/input?i=integrate+e%5E%28%CF%84-h%29*%28%CF%84%5E%28j-1%29%2F%28j-1%29%21%29d%CF%84
103
- = 1/h^j*[(e^(-h)*(-τ)^(-j)*τ(j))/((j-1)!)]{0..h}
104
- https://www.wolframalpha.com/input?i=integrate+e%5E%28%CF%84-h%29*%28%CF%84%5E%28j-1%29%2F%28j-1%29%21%29d%CF%84+between+0+and+h
105
- = 1/h^j*((e^(-h)*(-h)^(-j)*h^j*(Γ(j)-Γ(j,-h)))/(j-1)!)
106
- = (e^(-h)*(-h)^(-j)*h^j*(Γ(j)-Γ(j,-h))/((j-1)!*h^j)
107
- = (e^(-h)*(-h)^(-j)*(Γ(j)-Γ(j,-h))/(j-1)!
108
- = (e^(-h)*(-h)^(-j)*(Γ(j)-Γ(j,-h))/Γ(j)
109
- = (e^(-h)*(-h)^(-j)*(1-Γ(j,-h)/Γ(j))
110
-
111
- requires j>0
112
- """
113
- assert j > 0
114
- gamma_: float = _gamma(j)
115
- incomp_gamma_: float = _incomplete_gamma(j, neg_h, gamma_s=gamma_)
116
-
117
- phi_: float = math.exp(neg_h) * neg_h**-j * (1-incomp_gamma_/gamma_)
118
-
119
- return phi_
120
-
121
- class RESDECoeffsSecondOrder(NamedTuple):
122
- a2_1: float
123
- b1: float
124
- b2: float
125
-
126
- def _de_second_order(
127
- h: float,
128
- c2: float,
129
- simple_phi_calc = False,
130
- ) -> RESDECoeffsSecondOrder:
131
- """
132
- Table 3
133
- https://arxiv.org/abs/2308.02157
134
- ϕi,j := ϕi,j(-h) = ϕi(-cj*h)
135
- a2_1 = c2ϕ1,2
136
- = c2ϕ1(-c2*h)
137
- b1 = ϕ1 - ϕ2/c2
138
- """
139
- if simple_phi_calc:
140
- # Kat computed simpler expressions for phi for cases j={1,2,3}
141
- a2_1: float = c2 * _phi_1(-c2*h)
142
- phi1: float = _phi_1(-h)
143
- phi2: float = _phi_2(-h)
144
- else:
145
- # I computed general solution instead.
146
- # they're close, but there are slight differences. not sure which would be more prone to numerical error.
147
- a2_1: float = c2 * _phi(j=1, neg_h=-c2*h)
148
- phi1: float = _phi(j=1, neg_h=-h)
149
- phi2: float = _phi(j=2, neg_h=-h)
150
- phi2_c2: float = phi2/c2
151
- b1: float = phi1 - phi2_c2
152
- b2: float = phi2_c2
153
- return RESDECoeffsSecondOrder(
154
- a2_1=a2_1,
155
- b1=b1,
156
- b2=b2,
157
- )
158
-
159
- def _refined_exp_sosu_step(
160
- model: DenoiserModel,
161
- x: FloatTensor,
162
- sigma: FloatTensor,
163
- sigma_next: FloatTensor,
164
- c2 = 0.5,
165
- extra_args: Dict[str, Any] = {},
166
- pbar: Optional[tqdm] = None,
167
- simple_phi_calc = False,
168
- momentum = 0.0,
169
- vel = None,
170
- vel_2 = None,
171
- time = None
172
- ) -> StepOutput:
173
- """
174
- Algorithm 1 "RES Second order Single Update Step with c2"
175
- https://arxiv.org/abs/2308.02157
176
-
177
- Parameters:
178
- model (`DenoiserModel`): a k-diffusion wrapped denoiser model (e.g. a subclass of DiscreteEpsDDPMDenoiser)
179
- x (`FloatTensor`): noised latents (or RGB I suppose), e.g. torch.randn((B, C, H, W)) * sigma[0]
180
- sigma (`FloatTensor`): timestep to denoise
181
- sigma_next (`FloatTensor`): timestep+1 to denoise
182
- c2 (`float`, *optional*, defaults to .5): partial step size for solving ODE. .5 = midpoint method
183
- extra_args (`Dict[str, Any]`, *optional*, defaults to `{}`): kwargs to pass to `model#__call__()`
184
- pbar (`tqdm`, *optional*, defaults to `None`): progress bar to update after each model call
185
- simple_phi_calc (`bool`, *optional*, defaults to `True`): True = calculate phi_i,j(-h) via simplified formulae specific to j={1,2}. False = Use general solution that works for any j. Mathematically equivalent, but could be numeric differences.
186
- """
187
-
188
- def momentum_func(diff, velocity, timescale=1.0, offset=-momentum / 2.0): # Diff is current diff, vel is previous diff
189
- if velocity is None:
190
- momentum_vel = diff
191
- else:
192
- momentum_vel = momentum * (timescale + offset) * velocity + (1 - momentum * (timescale + offset)) * diff
193
- return momentum_vel
194
-
195
- lam_next, lam = (s.log().neg() for s in (sigma_next, sigma))
196
-
197
- # type hints aren't strictly true regarding float vs FloatTensor.
198
- # everything gets promoted to `FloatTensor` after interacting with `sigma: FloatTensor`.
199
- # I will use float to indicate any variables which are scalars.
200
- h: float = lam_next - lam
201
- a2_1, b1, b2 = _de_second_order(h=h, c2=c2, simple_phi_calc=simple_phi_calc)
202
-
203
- denoised: FloatTensor = model(x, sigma.repeat(x.size(0)), **extra_args)
204
- # if pbar is not None:
205
- # pbar.update(0.5)
206
-
207
- c2_h: float = c2*h
208
-
209
- diff_2 = momentum_func(a2_1*h*denoised, vel_2, time)
210
- vel_2 = diff_2
211
- x_2: FloatTensor = math.exp(-c2_h)*x + diff_2
212
- lam_2: float = lam + c2_h
213
- sigma_2: float = lam_2.neg().exp()
214
-
215
- denoised2: FloatTensor = model(x_2, sigma_2.repeat(x_2.size(0)), **extra_args)
216
- if pbar is not None:
217
- pbar.update()
218
-
219
- diff = momentum_func(h*(b1*denoised + b2*denoised2), vel, time)
220
- vel = diff
221
-
222
- x_next: FloatTensor = math.exp(-h)*x + diff
223
-
224
- return StepOutput(
225
- x_next=x_next,
226
- denoised=denoised,
227
- denoised2=denoised2,
228
- vel=vel,
229
- vel_2=vel_2,
230
- )
231
-
232
-
233
- @no_grad()
234
- def sample_refined_exp_s(
235
- model: FloatTensor,
236
- x: FloatTensor,
237
- sigmas: FloatTensor,
238
- denoise_to_zero: bool = True,
239
- extra_args: Dict[str, Any] = {},
240
- callback: Optional[RefinedExpCallback] = None,
241
- disable: Optional[bool] = None,
242
- ita: FloatTensor = torch.zeros((1,)),
243
- c2 = .5,
244
- noise_sampler: NoiseSampler = torch.randn_like,
245
- simple_phi_calc = False,
246
- momentum = 0.0,
247
- ):
248
- """
249
- Refined Exponential Solver (S).
250
- Algorithm 2 "RES Single-Step Sampler" with Algorithm 1 second-order step
251
- https://arxiv.org/abs/2308.02157
252
-
253
- Parameters:
254
- model (`DenoiserModel`): a k-diffusion wrapped denoiser model (e.g. a subclass of DiscreteEpsDDPMDenoiser)
255
- x (`FloatTensor`): noised latents (or RGB I suppose), e.g. torch.randn((B, C, H, W)) * sigma[0]
256
- sigmas (`FloatTensor`): sigmas (ideally an exponential schedule!) e.g. get_sigmas_exponential(n=25, sigma_min=model.sigma_min, sigma_max=model.sigma_max)
257
- denoise_to_zero (`bool`, *optional*, defaults to `True`): whether to finish with a first-order step down to 0 (rather than stopping at sigma_min). True = fully denoise image. False = match Algorithm 2 in paper
258
- extra_args (`Dict[str, Any]`, *optional*, defaults to `{}`): kwargs to pass to `model#__call__()`
259
- callback (`RefinedExpCallback`, *optional*, defaults to `None`): you can supply this callback to see the intermediate denoising results, e.g. to preview each step of the denoising process
260
- disable (`bool`, *optional*, defaults to `False`): whether to hide `tqdm`'s progress bar animation from being printed
261
- ita (`FloatTensor`, *optional*, defaults to 0.): degree of stochasticity, η, for each timestep. tensor shape must be broadcastable to 1-dimensional tensor with length `len(sigmas) if denoise_to_zero else len(sigmas)-1`. each element should be from 0 to 1.
262
- - if used: batch noise doesn't match non-batch
263
- c2 (`float`, *optional*, defaults to .5): partial step size for solving ODE. .5 = midpoint method
264
- noise_sampler (`NoiseSampler`, *optional*, defaults to `torch.randn_like`): method used for adding noise
265
- simple_phi_calc (`bool`, *optional*, defaults to `True`): True = calculate phi_i,j(-h) via simplified formulae specific to j={1,2}. False = Use general solution that works for any j. Mathematically equivalent, but could be numeric differences.
266
- """
267
- #assert sigmas[-1] == 0
268
- device = x.device
269
- ita = ita.to(device)
270
- sigmas = sigmas.to(device)
271
-
272
- sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
273
-
274
- vel, vel_2 = None, None
275
- with tqdm(disable=disable, total=len(sigmas)-(1 if denoise_to_zero else 2)) as pbar:
276
- for i, (sigma, sigma_next) in enumerate(pairwise(sigmas[:-1].split(1))):
277
- time = sigmas[i] / sigma_max
278
- if 'sigma' not in locals():
279
- sigma = sigmas[i]
280
- eps = torch.randn_like(x).float()
281
- sigma_hat = sigma * (1 + ita)
282
- x_hat = x + (sigma_hat ** 2 - sigma ** 2).sqrt() * eps
283
- x_next, denoised, denoised2, vel, vel_2 = _refined_exp_sosu_step(
284
- model,
285
- x_hat,
286
- sigma_hat,
287
- sigma_next,
288
- c2=c2,
289
- extra_args=extra_args,
290
- pbar=pbar,
291
- simple_phi_calc=simple_phi_calc,
292
- momentum = momentum,
293
- vel = vel,
294
- vel_2 = vel_2,
295
- time = time
296
- )
297
- if callback is not None:
298
- payload = RefinedExpCallbackPayload(
299
- x=x,
300
- i=i,
301
- sigma=sigma,
302
- sigma_hat=sigma_hat,
303
- denoised=denoised,
304
- denoised2=denoised2,
305
- )
306
- callback(payload)
307
- x = x_next
308
- if denoise_to_zero:
309
- eps = torch.randn_like(x).float()
310
- sigma_hat = sigma * (1 + ita)
311
- x_hat = x + (sigma_hat ** 2 - sigma ** 2).sqrt() * eps
312
- x_next: FloatTensor = model(x_hat, sigma.to(x_hat.device).repeat(x_hat.size(0)), **extra_args)
313
- pbar.update()
314
-
315
- if callback is not None:
316
- payload = RefinedExpCallbackPayload(
317
- x=x,
318
- i=i,
319
- sigma=sigma,
320
- sigma_hat=sigma_hat,
321
- denoised=denoised,
322
- denoised2=denoised2,
323
- )
324
- callback(payload)
325
-
326
-
327
- x = x_next
328
- return x
329
-
330
- # Many thanks to Kat + Birch-San for this wonderful sampler implementation! https://github.com/Birch-san/sdxl-play/commits/res/
331
- def sample_res_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler_type="gaussian", noise_sampler=None, denoise_to_zero=True, simple_phi_calc=False, c2=0.5, ita=torch.Tensor((0.0,)), momentum=0.0):
332
- return sample_refined_exp_s(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, noise_sampler=noise_sampler, denoise_to_zero=denoise_to_zero, simple_phi_calc=simple_phi_calc, c2=c2, ita=ita, momentum=momentum)
333
-
334
-
335
- ## modified from ReForge, original implementation ComfyUI
336
- @torch.no_grad()
337
- def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfgpp=False):
338
- extra_args = {} if extra_args is None else extra_args
339
- seed = extra_args.get("seed", None)
340
- noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
341
- s_in = x.new_ones([x.shape[0]])
342
- sigma_fn = lambda t: t.neg().exp()
343
- t_fn = lambda sigma: sigma.log().neg()
344
- phi1_fn = lambda t: torch.expm1(t) / t
345
- phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
346
- old_denoised = None
347
-
348
- sigmas = sigmas.to(x.device)
349
-
350
- if cfgpp:
351
- model.need_last_noise_uncond = True
352
- model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
353
-
354
- for i in trange(len(sigmas) - 1, disable=disable):
355
- if s_churn > 0:
356
- gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
357
- sigma_hat = sigmas[i] * (gamma + 1)
358
- else:
359
- gamma = 0
360
- sigma_hat = sigmas[i]
361
- if gamma > 0:
362
- eps = torch.randn_like(x) * s_noise
363
- x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
364
- denoised = model(x, sigma_hat * s_in, **extra_args)
365
-
366
- if callback is not None:
367
- callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
368
- if sigmas[i + 1] == 0 or old_denoised is None:
369
- # Euler method
370
- if cfgpp:
371
- d = model.last_noise_uncond
372
- x = denoised + d * sigmas[i + 1]
373
- else:
374
- d = to_d(x, sigma_hat, denoised)
375
- dt = sigmas[i + 1] - sigma_hat
376
- x = x + d * dt
377
- else:
378
- # Second order multistep method in https://arxiv.org/pdf/2308.02157
379
- t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
380
- h = t_next - t
381
- c2 = (t_prev - t) / h
382
- phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
383
- b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
384
- b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
385
- if cfgpp:
386
- d = model.last_noise_uncond
387
- x = denoised + d * sigma_hat
388
-
389
- x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
390
- old_denoised = denoised
391
- return x
392
- @torch.no_grad()
393
- def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
394
- return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfgpp=False)
395
- @torch.no_grad()
396
- def sample_res_multistep_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
397
- return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfgpp=True)
398
-
 
1
+ import torch
2
+ from torch import FloatTensor
3
+ from tqdm import tqdm
4
+ from itertools import pairwise
5
+ from typing import Protocol, Optional, Dict, Any, TypedDict, NamedTuple
6
+ import math
7
+
8
+ from tqdm.auto import trange
9
+
10
+ from k_diffusion.sampling import (
11
+ default_noise_sampler,
12
+ to_d,
13
+ )
14
+
15
class DenoiserModel(Protocol):
    """A k-diffusion style denoiser: maps noised latents `x` at noise level(s) `t` to a denoised prediction."""
    def __call__(self, x: FloatTensor, t: FloatTensor, *args, **kwargs) -> FloatTensor: ...

class RefinedExpCallbackPayload(TypedDict):
    # Payload handed to the per-step callback.
    x: FloatTensor          # current latents
    i: int                  # step index
    sigma: FloatTensor      # scheduled noise level for this step
    sigma_hat: FloatTensor  # noise level after stochastic churn (sigma * (1 + ita))

class RefinedExpCallback(Protocol):
    """Per-step observer invoked with intermediate sampling state."""
    def __call__(self, payload: RefinedExpCallbackPayload) -> None: ...

class NoiseSampler(Protocol):
    """Produces a noise tensor shaped like `x`."""
    def __call__(self, x: FloatTensor) -> FloatTensor: ...

class StepOutput(NamedTuple):
    # Result of one second-order RES step (_refined_exp_sosu_step).
    x_next: FloatTensor    # latents after the full step
    denoised: FloatTensor  # model output at sigma
    denoised2: FloatTensor # model output at the intermediate noise level sigma_2
    vel: FloatTensor       # momentum state for the full-step update
    vel_2: FloatTensor     # momentum state for the intermediate update
36
+
37
+ def _gamma(n: int,) -> int:
38
+ """
39
+ https://en.wikipedia.org/wiki/Gamma_function
40
+ for every positive integer n,
41
+ Γ(n) = (n-1)!
42
+ """
43
+ return math.factorial(n-1)
44
+
45
+ def _incomplete_gamma(s: int, x: float, gamma_s: Optional[int] = None) -> float:
46
+ """
47
+ https://en.wikipedia.org/wiki/Incomplete_gamma_function#Special_values
48
+ if s is a positive integer,
49
+ Γ(s, x) = (s-1)!*∑{k=0..s-1}(x^k/k!)
50
+ """
51
+ if gamma_s is None:
52
+ gamma_s = _gamma(s)
53
+
54
+ sum_: float = 0
55
+ # {k=0..s-1} inclusive
56
+ for k in range(s):
57
+ numerator: float = x**k
58
+ denom: int = math.factorial(k)
59
+ quotient: float = numerator/denom
60
+ sum_ += quotient
61
+ incomplete_gamma_: float = sum_ * math.exp(-x) * gamma_s
62
+ return incomplete_gamma_
63
+
64
+ # by Katherine Crowson
65
+ def _phi_1(neg_h: FloatTensor):
66
+ return torch.nan_to_num(torch.expm1(neg_h) / neg_h, nan=1.0)
67
+
68
+ # by Katherine Crowson
69
+ def _phi_2(neg_h: FloatTensor):
70
+ return torch.nan_to_num((torch.expm1(neg_h) - neg_h) / neg_h**2, nan=0.5)
71
+
72
+ # by Katherine Crowson
73
+ def _phi_3(neg_h: FloatTensor):
74
+ return torch.nan_to_num((torch.expm1(neg_h) - neg_h - neg_h**2 / 2) / neg_h**3, nan=1 / 6)
75
+
76
def _phi(neg_h: float, j: int):
    """General φ_j(-h) via the upper incomplete gamma function.

    For j={1,2,3}, Kat's _phi_1/_phi_2/_phi_3 compute the same quantity
    in fewer steps; this is the closed form valid for any positive j.

    Lemma 1, https://arxiv.org/abs/2308.02157:
        φj(-h) = 1/h^j * ∫{0..h} e^(τ-h) * τ^(j-1)/(j-1)! dτ
    https://www.wolframalpha.com/input?i=integrate+e%5E%28%CF%84-h%29*%28%CF%84%5E%28j-1%29%2F%28j-1%29%21%29d%CF%84+between+0+and+h
        = 1/h^j * (e^(-h) * (-h)^(-j) * h^j * (Γ(j) - Γ(j,-h))) / (j-1)!
        = e^(-h) * (-h)^(-j) * (Γ(j) - Γ(j,-h)) / Γ(j)
        = e^(-h) * (-h)^(-j) * (1 - Γ(j,-h)/Γ(j))

    Requires j > 0.
    """
    assert j > 0
    gamma_j: float = _gamma(j)
    upper_incomplete: float = _incomplete_gamma(j, neg_h, gamma_s=gamma_j)
    return math.exp(neg_h) * neg_h**-j * (1 - upper_incomplete / gamma_j)
102
+
103
class RESDECoeffsSecondOrder(NamedTuple):
    # Tableau coefficients for the second-order RES step (Table 3, https://arxiv.org/abs/2308.02157).
    a2_1: float  # weight of the first denoiser output in the intermediate stage
    b1: float    # full-step weight of the first denoiser output
    b2: float    # full-step weight of the second denoiser output
107
+
108
def _de_second_order(h: float, c2: float, simple_phi_calc=False) -> RESDECoeffsSecondOrder:
    """Coefficients for the second-order RES single update step.

    Table 3, https://arxiv.org/abs/2308.02157
    φi,j := φi,j(-h) = φi(-cj*h)
        a2_1 = c2 * φ1(-c2*h)
        b1   = φ1 - φ2/c2
        b2   = φ2/c2

    simple_phi_calc=True uses Kat's closed-form φ1/φ2 (specific to j={1,2});
    otherwise the general incomplete-gamma solution is used. The two are
    mathematically equivalent up to slight numerical differences.
    """
    if simple_phi_calc:
        # Kat's simpler expressions for the cases j={1,2}
        a2_1 = _phi_1(-c2 * h) * c2
        phi1, phi2 = _phi_1(-h), _phi_2(-h)
    else:
        # General solution; close to the above, with slight numerical differences.
        a2_1 = _phi(j=1, neg_h=-c2 * h) * c2
        phi1, phi2 = _phi(j=1, neg_h=-h), _phi(j=2, neg_h=-h)
    phi2_over_c2 = phi2 / c2
    return RESDECoeffsSecondOrder(
        a2_1=a2_1,
        b1=phi1 - phi2_over_c2,
        b2=phi2_over_c2,
    )
136
+
137
def _refined_exp_sosu_step(
    model: DenoiserModel,
    x: FloatTensor,
    sigma: FloatTensor,
    sigma_next: FloatTensor,
    c2 = 0.5,
    extra_args: Dict[str, Any] = {},
    pbar: Optional[tqdm] = None,
    simple_phi_calc = False,
    momentum = 0.0,
    vel = None,
    vel_2 = None,
    time = None
) -> StepOutput:
    """
    Algorithm 1 "RES Second order Single Update Step with c2"
    https://arxiv.org/abs/2308.02157

    Parameters:
      model (`DenoiserModel`): a k-diffusion wrapped denoiser model (e.g. a subclass of DiscreteEpsDDPMDenoiser)
      x (`FloatTensor`): noised latents (or RGB I suppose), e.g. torch.randn((B, C, H, W)) * sigma[0]
      sigma (`FloatTensor`): timestep to denoise
      sigma_next (`FloatTensor`): timestep+1 to denoise
      c2 (`float`, *optional*, defaults to .5): partial step size for solving ODE. .5 = midpoint method
      extra_args (`Dict[str, Any]`, *optional*, defaults to `{}`): kwargs to pass to `model#__call__()`
      pbar (`tqdm`, *optional*, defaults to `None`): progress bar to update after each model call
      simple_phi_calc (`bool`, *optional*, defaults to `False`): True = calculate phi_i,j(-h) via simplified formulae specific to j={1,2}. False = Use general solution that works for any j. Mathematically equivalent, but could be numeric differences.
      momentum (`float`, *optional*, defaults to 0.0): blend factor mixing the previous step's update direction into the current one; 0 disables blending
      vel, vel_2 (*optional*): momentum state (previous update directions) carried over from the previous step; None on the first step
      time (*optional*): timescale fed to the momentum blend (caller passes sigma normalized by sigma_max)
    """

    # Blend the current update `diff` with the previous one `velocity`.
    # With momentum == 0 this reduces to returning `diff` unchanged.
    def momentum_func(diff, velocity, timescale=1.0, offset=-momentum / 2.0): # Diff is current diff, vel is previous diff
        if velocity is None:
            momentum_vel = diff
        else:
            momentum_vel = momentum * (timescale + offset) * velocity + (1 - momentum * (timescale + offset)) * diff
        return momentum_vel

    # Work in lambda = -log(sigma) space (exponential-integrator coordinates).
    lam_next, lam = (s.log().neg() for s in (sigma_next, sigma))

    # type hints aren't strictly true regarding float vs FloatTensor.
    # everything gets promoted to `FloatTensor` after interacting with `sigma: FloatTensor`.
    # I will use float to indicate any variables which are scalars.
    h: float = lam_next - lam
    a2_1, b1, b2 = _de_second_order(h=h, c2=c2, simple_phi_calc=simple_phi_calc)

    # First model evaluation at the current noise level.
    denoised: FloatTensor = model(x, sigma.repeat(x.size(0)), **extra_args)

    c2_h: float = c2*h

    # Intermediate stage: partial step of size c2*h (momentum-blended update).
    diff_2 = momentum_func(a2_1*h*denoised, vel_2, time)
    vel_2 = diff_2
    x_2: FloatTensor = math.exp(-c2_h)*x + diff_2
    lam_2: float = lam + c2_h
    sigma_2: float = lam_2.neg().exp()

    # Second model evaluation at the intermediate noise level.
    denoised2: FloatTensor = model(x_2, sigma_2.repeat(x_2.size(0)), **extra_args)
    if pbar is not None:
        pbar.update()

    # Full step: combine both denoiser outputs with tableau weights b1, b2.
    diff = momentum_func(h*(b1*denoised + b2*denoised2), vel, time)
    vel = diff

    x_next: FloatTensor = math.exp(-h)*x + diff

    return StepOutput(
        x_next=x_next,
        denoised=denoised,
        denoised2=denoised2,
        vel=vel,
        vel_2=vel_2,
    )
207
+
208
+
209
@torch.no_grad()
def sample_refined_exp_s(
    model: FloatTensor,
    x: FloatTensor,
    sigmas: FloatTensor,
    denoise_to_zero: bool = True,
    extra_args: Dict[str, Any] = {},
    callback: Optional[RefinedExpCallback] = None,
    disable: Optional[bool] = None,
    ita: FloatTensor = torch.zeros((1,)),
    c2 = .5,
    noise_sampler: NoiseSampler = default_noise_sampler,
    simple_phi_calc = False,
    momentum = 0.0,
):
    """
    Refined Exponential Solver (S).
    Algorithm 2 "RES Single-Step Sampler" with Algorithm 1 second-order step
    https://arxiv.org/abs/2308.02157

    Parameters:
      model (`DenoiserModel`): a k-diffusion wrapped denoiser model (e.g. a subclass of DiscreteEpsDDPMDenoiser)
      x (`FloatTensor`): noised latents (or RGB I suppose), e.g. torch.randn((B, C, H, W)) * sigma[0]
      sigmas (`FloatTensor`): sigmas (ideally an exponential schedule!) e.g. get_sigmas_exponential(n=25, sigma_min=model.sigma_min, sigma_max=model.sigma_max)
      denoise_to_zero (`bool`, *optional*, defaults to `True`): whether to finish with a first-order step down to 0 (rather than stopping at sigma_min). True = fully denoise image. False = match Algorithm 2 in paper
      extra_args (`Dict[str, Any]`, *optional*, defaults to `{}`): kwargs to pass to `model#__call__()`
      callback (`RefinedExpCallback`, *optional*, defaults to `None`): you can supply this callback to see the intermediate denoising results, e.g. to preview each step of the denoising process
      disable (`bool`, *optional*, defaults to `False`): whether to hide `tqdm`'s progress bar animation from being printed
      ita (`FloatTensor`, *optional*, defaults to 0.): degree of stochasticity, η, for each timestep. tensor shape must be broadcastable to 1-dimensional tensor with length `len(sigmas) if denoise_to_zero else len(sigmas)-1`. each element should be from 0 to 1.
        - if used: batch noise doesn't match non-batch
      c2 (`float`, *optional*, defaults to .5): partial step size for solving ODE. .5 = midpoint method
      noise_sampler (`NoiseSampler`, *optional*, defaults to `default_noise_sampler`): method used for adding noise.
        NOTE(review): this parameter is currently unused in the body — noise is drawn
        from a seeded CPU generator below; confirm whether it should be honored.
      simple_phi_calc (`bool`, *optional*, defaults to `False`): True = calculate phi_i,j(-h) via simplified formulae specific to j={1,2}. False = Use general solution that works for any j. Mathematically equivalent, but could be numeric differences.
    """
    #assert sigmas[-1] == 0
    device = x.device
    ita = ita.to(device)
    sigmas = sigmas.to(device)

    # sigma_min is computed but not used below; sigma_max normalizes the momentum timescale.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()

    # Deterministic noise: derive a CPU-generator seed from the first latent value.
    # NOTE(review): assumes x is at least 4-D (e.g. (B, C, H, W)) — confirm for other layouts.
    seed = (int(x[0,0,0,0].item()) * 1234567890) % 65536
    generator = torch.Generator(device='cpu').manual_seed(seed)

    vel, vel_2 = None, None  # momentum state threaded through the steps
    with tqdm(disable=disable, total=len(sigmas)-(1 if denoise_to_zero else 2)) as pbar:
        for i, (sigma, sigma_next) in enumerate(pairwise(sigmas[:-1].split(1))):
            time = sigmas[i] / sigma_max
            # NOTE(review): `sigma` is always bound by the loop target before the body
            # runs, so this guard appears dead — confirm before removing.
            if 'sigma' not in locals():
                sigma = sigmas[i]
            # Stochastic churn: raise the noise level from sigma to sigma_hat.
            eps = torch.randn(x.shape, generator=generator).to(x)
            sigma_hat = sigma * (1 + ita)
            x_hat = x + (sigma_hat ** 2 - sigma ** 2).sqrt() * eps
            x_next, denoised, denoised2, vel, vel_2 = _refined_exp_sosu_step(
                model,
                x_hat,
                sigma_hat,
                sigma_next,
                c2=c2,
                extra_args=extra_args,
                pbar=pbar,
                simple_phi_calc=simple_phi_calc,
                momentum = momentum,
                vel = vel,
                vel_2 = vel_2,
                time = time
            )
            if callback is not None:
                payload = RefinedExpCallbackPayload(
                    x=x,
                    i=i,
                    sigma=sigma,
                    sigma_hat=sigma_hat,
                    denoised=denoised,
                    denoised2=denoised2,
                )
                callback(payload)
            x = x_next
    # Final first-order denoise step.
    # NOTE(review): `sigma`, `denoised`, `denoised2` here are the values left over from the
    # last loop iteration (not sigmas[-1]) — confirm this is the intended noise level.
    if denoise_to_zero:
        eps = torch.randn(x.shape, generator=generator).to(x)
        sigma_hat = sigma * (1 + ita)
        x_hat = x + (sigma_hat ** 2 - sigma ** 2).sqrt() * eps
        x_next: FloatTensor = model(x_hat, sigma.to(x_hat.device).repeat(x_hat.size(0)), **extra_args)
        pbar.update()

        if callback is not None:
            payload = RefinedExpCallbackPayload(
                x=x,
                i=i,
                sigma=sigma,
                sigma_hat=sigma_hat,
                denoised=denoised,
                denoised2=denoised2,
            )
            callback(payload)
        x = x_next
    return x
306
+
307
# Many thanks to Kat + Birch-San for this wonderful sampler implementation! https://github.com/Birch-san/sdxl-play/commits/res/
def sample_res_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler_type="gaussian", noise_sampler=None, denoise_to_zero=True, simple_phi_calc=False, c2=0.5, ita=torch.Tensor((0.0,)), momentum=0.0):
    """Entry point for the RES (single-step) sampler: forwards to sample_refined_exp_s.

    noise_sampler_type is accepted for sampler-registry signature compatibility but unused here.
    """
    # Normalize extra_args here: sample_refined_exp_s defaults it to {}, but forwarding an
    # explicit None would reach `**extra_args` in the inner step and raise a TypeError.
    extra_args = {} if extra_args is None else extra_args
    return sample_refined_exp_s(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, noise_sampler=noise_sampler, denoise_to_zero=denoise_to_zero, simple_phi_calc=simple_phi_calc, c2=c2, ita=ita, momentum=momentum)
310
+
311
+
312
## modified from ReForge, original implementation ComfyUI
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfgpp=False):
    """Second-order multistep RES sampler (https://arxiv.org/pdf/2308.02157).

    Uses a plain Euler step on the first iteration (no denoiser history yet) and on the
    final step to sigma 0; otherwise combines the current and previous denoiser outputs.
    With cfgpp=True, the model's last unconditional noise prediction replaces the ODE
    derivative (CFG++ style update).
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # Conversions between sigma space and t = -log(sigma) space, plus φ1/φ2 helpers.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    phi1_fn = lambda t: torch.expm1(t) / t
    phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
    old_denoised = None  # previous step's model output (the multistep history)

    sigmas = sigmas.to(x.device)

    # Deterministic churn noise: CPU generator seeded from the first latent value.
    # NOTE(review): assumes x is at least 4-D — confirm for other layouts.
    if s_churn > 0.0:
        seed = (int(x[0,0,0,0].item()) * 1234567890) % 65536
        generator = torch.Generator(device='cpu').manual_seed(seed)
    else:
        # Safe to leave None: with s_churn == 0 gamma stays 0 and the generator is never used.
        generator = None

    if cfgpp:
        # NOTE(review): reaches into Forge/reForge wrapper internals, presumably so that
        # `model.last_noise_uncond` gets populated — verify against the host application.
        model.need_last_noise_uncond = True
        model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True

    for i in trange(len(sigmas) - 1, disable=disable):
        # Optional stochastic churn: raise the noise level from sigmas[i] to sigma_hat.
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]
        if gamma > 0:
            eps = torch.randn(x.shape, generator=generator).to(x) * s_noise
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)

        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
        if sigmas[i + 1] == 0 or old_denoised is None:
            # Euler method
            if cfgpp:
                d = model.last_noise_uncond
                x = denoised + d * sigmas[i + 1]
            else:
                d = to_d(x, sigma_hat, denoised)
                dt = sigmas[i + 1] - sigma_hat
                x = x + d * dt
        else:
            # Second order multistep method in https://arxiv.org/pdf/2308.02157
            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
            h = t_next - t
            # c2 locates the previous step relative to h (typically negative for a decreasing sigma schedule).
            c2 = (t_prev - t) / h
            phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
            # nan_to_num guards the degenerate h -> 0 / c2 -> 0 cases.
            b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
            b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
            if cfgpp:
                d = model.last_noise_uncond
                x = denoised + d * sigma_hat

            x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
        old_denoised = denoised
    return x
374
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
    """res_multistep without the CFG++ correction."""
    return res_multistep(
        model, x, sigmas,
        extra_args=extra_args, callback=callback, disable=disable,
        s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise,
        noise_sampler=noise_sampler, cfgpp=False,
    )
377
@torch.no_grad()
def sample_res_multistep_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
    """res_multistep with the CFG++ correction enabled."""
    return res_multistep(
        model, x, sigmas,
        extra_args=extra_args, callback=callback, disable=disable,
        s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise,
        noise_sampler=noise_sampler, cfgpp=True,
    )