dikdimon commited on
Commit
02f9345
·
verified ·
1 Parent(s): 079899a

Upload sd_samplers_kdiffusion.py using SD-Hub

Browse files
Files changed (1) hide show
  1. sd_samplers_kdiffusion.py +275 -0
sd_samplers_kdiffusion.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import inspect
3
+ import k_diffusion.sampling
4
+ from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
5
+ from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
6
+ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
7
+
8
+ from modules.shared import opts
9
+ import modules.shared as shared
10
+
11
+ samplers_k_diffusion = [
12
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
13
+ ('DPM++ 2M Karras Sharp v1', 'sample_dpmpp_2m_v1', ['k_dpmpp_2m_ka_v1'], {'scheduler': 'karras'}),
14
+ ('DPM++ 2M Test', 'sample_dpmpp_2m_test', ['k_dpmpp_2m'], {}),
15
+ ('DPM++ 2M Karras Test', 'sample_dpmpp_2m_test', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
16
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
17
+ ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
18
+ ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
19
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
20
+ ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
21
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
22
+ ('Euler', 'sample_euler', ['k_euler'], {}),
23
+ ('LMS', 'sample_lms', ['k_lms'], {}),
24
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
25
+ ('Heun++', 'sample_heunpp2', ['heunpp2'], {}),
26
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
27
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
28
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
29
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
30
+ ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
31
+ ]
32
+
33
+
34
+ samplers_data_k_diffusion = [
35
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
36
+ for label, funcname, aliases, options in samplers_k_diffusion
37
+ if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
38
+ ]
39
+
40
+ from tqdm.auto import trange
41
+
42
+ @torch.no_grad()
43
+ def sample_dpmpp_2m_alt(model, x, sigmas, extra_args=None, callback=None, disable=None):
44
+ """DPM-Solver++(2M)."""
45
+ extra_args = {} if extra_args is None else extra_args
46
+ s_in = x.new_ones([x.shape[0]])
47
+ sigma_fn = lambda t: t.neg().exp()
48
+ t_fn = lambda sigma: sigma.log().neg()
49
+ old_denoised = None
50
+
51
+ for i in trange(len(sigmas) - 1, disable=disable):
52
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
53
+ if callback is not None:
54
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
55
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
56
+ h = t_next - t
57
+ if old_denoised is None or sigmas[i + 1] == 0:
58
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
59
+ else:
60
+ h_last = t - t_fn(sigmas[i - 1])
61
+ r = h_last / h
62
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
63
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
64
+ sigma_progress = i / len(sigmas)
65
+ adjustment_factor = 1 + (0.15 * (sigma_progress * sigma_progress))
66
+ old_denoised = denoised * adjustment_factor
67
+ return x
68
+
69
+ k_diffusion.sampling.sample_dpmpp_2m_alt = sample_dpmpp_2m_alt
70
+
71
+ samplers_data_k_diffusion.insert(9, sd_samplers_common.SamplerData('DPM++ 2M alt', lambda model: KDiffusionSampler('sample_dpmpp_2m_alt', model), ['k_dpmpp_2m_alt'], {}))
72
+ samplers_data_k_diffusion.insert(10, sd_samplers_common.SamplerData('DPM++ 2M alt Karras', lambda model: KDiffusionSampler('sample_dpmpp_2m_alt', model), ['k_dpmpp_2m_alt_ka'], {'scheduler': 'karras'}))
73
+
74
+
75
+ sampler_extra_params = {
76
+ 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
77
+ 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
78
+ 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
79
+ 'sample_dpm_fast': ['s_noise'],
80
+ 'sample_dpm_2_ancestral': ['s_noise'],
81
+ 'sample_dpmpp_2s_ancestral': ['s_noise'],
82
+ 'sample_dpmpp_sde': ['s_noise'],
83
+ 'sample_dpmpp_2m_sde': ['s_noise'],
84
+ 'sample_dpmpp_3m_sde': ['s_noise'],
85
+ }
86
+
87
+ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
88
+ k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
89
+
90
+
91
+ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
92
+ @property
93
+ def inner_model(self):
94
+ if self.model_wrap is None:
95
+ denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
96
+
97
+ if denoiser_constructor is not None:
98
+ self.model_wrap = denoiser_constructor()
99
+ else:
100
+ denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
101
+ self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
102
+
103
+ return self.model_wrap
104
+
105
+
106
+ class KDiffusionSampler(sd_samplers_common.Sampler):
107
+ def __init__(self, funcname, sd_model, options=None):
108
+ super().__init__(funcname)
109
+
110
+ self.extra_params = sampler_extra_params.get(funcname, [])
111
+
112
+ self.options = options or {}
113
+ self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
114
+
115
+ self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
116
+ self.model_wrap = self.model_wrap_cfg.inner_model
117
+
118
+ def get_sigmas(self, p, steps):
119
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
120
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
121
+ discard_next_to_last_sigma = True
122
+ p.extra_generation_params["Discard penultimate sigma"] = True
123
+
124
+ steps += 1 if discard_next_to_last_sigma else 0
125
+
126
+ scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
127
+ if scheduler_name == 'Automatic':
128
+ scheduler_name = self.config.options.get('scheduler', None)
129
+
130
+ scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
131
+
132
+ m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
133
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
134
+
135
+ if p.sampler_noise_scheduler_override:
136
+ sigmas = p.sampler_noise_scheduler_override(steps)
137
+ elif scheduler is None or scheduler.function is None:
138
+ sigmas = self.model_wrap.get_sigmas(steps)
139
+ else:
140
+ sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
141
+
142
+ if scheduler.label != 'Automatic' and not p.is_hr_pass:
143
+ p.extra_generation_params["Schedule type"] = scheduler.label
144
+ elif scheduler.label != p.extra_generation_params.get("Schedule type"):
145
+ p.extra_generation_params["Hires schedule type"] = scheduler.label
146
+
147
+ if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
148
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
149
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
150
+
151
+ if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max:
152
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
153
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
154
+
155
+ if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho:
156
+ sigmas_kwargs['rho'] = opts.rho
157
+ p.extra_generation_params["Schedule rho"] = opts.rho
158
+
159
+ if scheduler.need_inner_model:
160
+ sigmas_kwargs['inner_model'] = self.model_wrap
161
+
162
+ if scheduler.label == 'Beta':
163
+ p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
164
+ p.extra_generation_params["Beta schedule beta"] = opts.beta_dist_beta
165
+
166
+ sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu)
167
+
168
+ if discard_next_to_last_sigma:
169
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
170
+
171
+ return sigmas.cpu()
172
+
173
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
174
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
175
+
176
+ sigmas = self.get_sigmas(p, steps)
177
+ sigma_sched = sigmas[steps - t_enc - 1:]
178
+
179
+ if hasattr(shared.sd_model, 'add_noise_to_latent'):
180
+ xi = shared.sd_model.add_noise_to_latent(x, noise, sigma_sched[0])
181
+ else:
182
+ xi = x + noise * sigma_sched[0]
183
+
184
+ if opts.img2img_extra_noise > 0:
185
+ p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
186
+ extra_noise_params = ExtraNoiseParams(noise, x, xi)
187
+ extra_noise_callback(extra_noise_params)
188
+ noise = extra_noise_params.noise
189
+ xi += noise * opts.img2img_extra_noise
190
+
191
+ extra_params_kwargs = self.initialize(p)
192
+ parameters = inspect.signature(self.func).parameters
193
+
194
+ if 'sigma_min' in parameters:
195
+ ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
196
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
197
+ if 'sigma_max' in parameters:
198
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
199
+ if 'n' in parameters:
200
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
201
+ if 'sigma_sched' in parameters:
202
+ extra_params_kwargs['sigma_sched'] = sigma_sched
203
+ if 'sigmas' in parameters:
204
+ extra_params_kwargs['sigmas'] = sigma_sched
205
+
206
+ if self.config.options.get('brownian_noise', False):
207
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
208
+ extra_params_kwargs['noise_sampler'] = noise_sampler
209
+
210
+ if self.config.options.get('solver_type', None) == 'heun':
211
+ extra_params_kwargs['solver_type'] = 'heun'
212
+
213
+ self.model_wrap_cfg.init_latent = x
214
+ self.last_latent = x
215
+ self.sampler_extra_args = {
216
+ 'cond': conditioning,
217
+ 'image_cond': image_conditioning,
218
+ 'uncond': unconditional_conditioning,
219
+ 'cond_scale': p.cfg_scale,
220
+ 's_min_uncond': self.s_min_uncond
221
+ }
222
+
223
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
224
+
225
+ self.add_infotext(p)
226
+
227
+ return samples
228
+
229
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
230
+ steps = steps or p.steps
231
+
232
+ sigmas = self.get_sigmas(p, steps)
233
+
234
+ if opts.sgm_noise_multiplier:
235
+ p.extra_generation_params["SGM noise multiplier"] = True
236
+ x = x * torch.sqrt(1.0 + sigmas[0] ** 2.0)
237
+ else:
238
+ x = x * sigmas[0]
239
+
240
+ extra_params_kwargs = self.initialize(p)
241
+ parameters = inspect.signature(self.func).parameters
242
+
243
+ if 'n' in parameters:
244
+ extra_params_kwargs['n'] = steps
245
+
246
+ if 'sigma_min' in parameters:
247
+ extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
248
+ extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
249
+
250
+ if 'sigmas' in parameters:
251
+ extra_params_kwargs['sigmas'] = sigmas
252
+
253
+ if self.config.options.get('brownian_noise', False):
254
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
255
+ extra_params_kwargs['noise_sampler'] = noise_sampler
256
+
257
+ if self.config.options.get('solver_type', None) == 'heun':
258
+ extra_params_kwargs['solver_type'] = 'heun'
259
+
260
+ self.last_latent = x
261
+ self.sampler_extra_args = {
262
+ 'cond': conditioning,
263
+ 'image_cond': image_conditioning,
264
+ 'uncond': unconditional_conditioning,
265
+ 'cond_scale': p.cfg_scale,
266
+ 's_min_uncond': self.s_min_uncond
267
+ }
268
+
269
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
270
+
271
+ self.add_infotext(p)
272
+
273
+ return samples
274
+
275
+