BiliSakura committed on
Commit
044c88b
·
verified ·
1 Parent(s): efc1ede

Delete ADM-G-512/scheduler/scheduling_adm.py

Browse files
Files changed (1) hide show
  1. ADM-G-512/scheduler/scheduling_adm.py +0 -590
ADM-G-512/scheduler/scheduling_adm.py DELETED
@@ -1,590 +0,0 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
-
6
- import enum
7
- import math
8
- from dataclasses import dataclass
9
- from typing import Optional, Tuple, Union
10
-
11
- import numpy as np
12
- import torch
13
-
14
- from diffusers.configuration_utils import ConfigMixin, register_to_config
15
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
16
- from diffusers.utils import BaseOutput
17
-
18
- try:
19
- from diffusers.utils.torch_utils import randn_tensor
20
- except ImportError: # pragma: no cover
21
- def randn_tensor(shape, generator=None, device=None, dtype=None):
22
- return torch.randn(shape, generator=generator, device=device, dtype=dtype)
23
-
24
-
25
- # ---------------------------------------------------------------------------
26
- # Internal diffusion math (OpenAI ADM / improved-diffusion)
27
- # ---------------------------------------------------------------------------
28
-
29
-
30
def _randn_like(tensor: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Sample noise with the same shape, device and dtype as ``tensor``.

    The work is delegated to ``randn_tensor`` so that an optional ``generator``
    can be forwarded for reproducible draws.
    """
    return randn_tensor(
        tensor.shape,
        generator=generator,
        device=tensor.device,
        dtype=tensor.dtype,
    )
32
-
33
-
34
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """Look up per-timestep values from a NumPy array and broadcast them.

    Args:
        arr: 1-D NumPy array indexed by timestep.
        timesteps: Index tensor; the result is created on its device.
        broadcast_shape: Target shape; its leading dimension(s) must be
            compatible with ``timesteps``.

    Returns:
        A float32 tensor of shape ``broadcast_shape`` whose trailing dimensions
        repeat the gathered values.
    """
    table = torch.from_numpy(arr).to(device=timesteps.device)
    gathered = table[timesteps].float()
    # Append singleton axes until the rank matches the target, then expand.
    missing_dims = len(broadcast_shape) - len(gathered.shape)
    for _ in range(max(missing_dims, 0)):
        gathered = gathered.unsqueeze(-1)
    return gathered.expand(broadcast_shape)
39
-
40
-
41
def _get_named_beta_schedule(schedule_name: str, num_diffusion_timesteps: int):
    """Build the beta schedule identified by ``schedule_name``.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        num_diffusion_timesteps: Number of betas to produce.

    Returns:
        A NumPy array with one beta per diffusion timestep.

    Raises:
        NotImplementedError: If ``schedule_name`` is not recognised.
    """
    if schedule_name == "cosine":

        def _cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return _betas_for_alpha_bar(num_diffusion_timesteps, _cosine_alpha_bar)

    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # The endpoints are scaled by 1000 / T, so they equal 1e-4 and 2e-2
    # exactly when T == 1000.
    scale = 1000 / num_diffusion_timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
51
-
52
-
53
def _betas_for_alpha_bar(num_diffusion_timesteps: int, alpha_bar, max_beta: float = 0.999):
    """Discretise a cumulative-alpha function into per-step betas.

    Args:
        num_diffusion_timesteps: Number of betas to produce.
        alpha_bar: Callable taking ``t`` in ``[0, 1]`` and returning the
            cumulative product of ``(1 - beta)`` up to that point.
        max_beta: Upper bound applied to every beta.

    Returns:
        A NumPy array of ``num_diffusion_timesteps`` betas.
    """

    def _beta_at(step):
        interval_start = step / num_diffusion_timesteps
        interval_end = (step + 1) / num_diffusion_timesteps
        return min(1 - alpha_bar(interval_end) / alpha_bar(interval_start), max_beta)

    return np.array([_beta_at(step) for step in range(num_diffusion_timesteps)])
60
-
61
-
62
def _space_timesteps(num_timesteps: int, section_counts):
    """Choose which timesteps of the original chain to keep when respacing.

    Args:
        num_timesteps: Length of the original diffusion chain.
        section_counts: Either a list of ints, a comma-separated string of
            ints, or a string ``"ddimN"``. With ``"ddimN"`` a fixed integer
            stride is searched for that yields exactly ``N`` steps. Otherwise
            the chain is split into ``len(section_counts)`` nearly equal
            sections and each count says how many evenly spread steps to keep
            from its section.

    Returns:
        The set of retained timestep indices.

    Raises:
        ValueError: If no integer stride yields the requested DDIM step count,
            or if a section is asked for more steps than it contains.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for stride in range(1, num_timesteps):
                candidate = range(0, num_timesteps, stride)
                if len(candidate) == desired_count:
                    return set(candidate)
            # Report the count that was requested, not the chain length.
            raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
        section_counts = [int(x) for x in section_counts.split(",")]

    size_per, extra = divmod(num_timesteps, len(section_counts))
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        # The first `extra` sections absorb the remainder, one step each.
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
        frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
        cur_idx = 0.0
        for _ in range(section_count):
            all_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        start_idx += size
    return set(all_steps)
87
-
88
-
89
class _ModelMeanType(enum.Enum):
    """Identifies which quantity the model output represents."""

    PREVIOUS_X = 1  # model output is treated as x_{t-1}
    START_X = 2  # model output is treated as the predicted x_0
    EPSILON = 3  # model output is treated as the noise estimate
93
-
94
-
95
class _ModelVarType(enum.Enum):
    """Identifies how the reverse-process variance is obtained."""

    LEARNED = 1
    FIXED_SMALL = 2  # use the posterior variance
    FIXED_LARGE = 3  # use the betas (first entry taken from the posterior variance)
    LEARNED_RANGE = 4  # model output interpolates between the two fixed log-variances
100
-
101
-
102
class _GaussianDiffusion:
    """Gaussian diffusion helper: schedule coefficients plus DDPM / DDIM reverse steps.

    Per-timestep coefficients are precomputed as float64 NumPy arrays indexed by
    timestep and gathered into tensors on demand via ``_extract_into_tensor``.

    Args:
        betas: Per-timestep noise variances. At least two entries are required,
            because the clipped log-variance reuses ``posterior_variance[1]``.
        model_mean_type: ``_ModelMeanType`` describing what the model predicts.
        model_var_type: ``_ModelVarType`` describing how the variance is obtained.
        rescale_timesteps: If ``True``, timesteps passed to the model by
            ``p_mean_variance`` are scaled to the ``[0, 1000)`` range.
    """

    def __init__(self, *, betas, model_mean_type, model_var_type, rescale_timesteps: bool = False):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.rescale_timesteps = rescale_timesteps
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # posterior_variance[0] is 0, so its log is replaced by the log of entry 1.
        self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)

    # ------------------------------------------------------------------
    # Conversions between parameterisations
    # ------------------------------------------------------------------

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Recover the x_0 estimate from ``x_t`` and a noise estimate."""
        return _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract_into_tensor(
            self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
        ) * eps

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """Recover the noise estimate from ``x_t`` and an x_0 estimate."""
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        """Recover the x_0 estimate from ``x_t`` and an x_{t-1} estimate."""
        return _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - _extract_into_tensor(
            self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
        ) * x_t

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(
            self.posterior_mean_coef2, t, x_t.shape
        ) * x_t
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # ------------------------------------------------------------------
    # Reverse-process distribution
    # ------------------------------------------------------------------

    def p_mean_variance_from_output(
        self,
        model_output: torch.Tensor,
        x: torch.Tensor,
        t: torch.Tensor,
        clip_denoised: bool = True,
    ):
        """Turn a raw model output into the parameters of p(x_{t-1} | x_t).

        Args:
            model_output: Model output. For learned-variance types it carries the
                mean prediction and the variance values stacked along dim 1.
            x: Current sample ``x_t``; dim 1 is taken as the channel axis.
            t: Timestep indices into this diffusion's schedule.
            clip_denoised: Clamp the x_0 estimate to ``[-1, 1]``.

        Returns:
            Dict with keys ``"mean"``, ``"variance"``, ``"log_variance"`` and
            ``"pred_xstart"``.
        """
        _, c = x.shape[:2]

        if self.model_var_type in (_ModelVarType.LEARNED, _ModelVarType.LEARNED_RANGE):
            model_output, model_var_values = torch.split(model_output, c, dim=1)
            if self.model_var_type == _ModelVarType.LEARNED:
                # The extra channels are the log-variance itself.
                model_log_variance = model_var_values
            else:
                # The extra channels, in [-1, 1], interpolate between the two
                # fixed log-variances.
                min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = torch.exp(model_log_variance)
        else:
            if self.model_var_type == _ModelVarType.FIXED_LARGE:
                fixed_variance = np.append(self.posterior_variance[1], self.betas[1:])
                fixed_log_variance = np.log(fixed_variance)
            elif self.model_var_type == _ModelVarType.FIXED_SMALL:
                fixed_variance = self.posterior_variance
                fixed_log_variance = self.posterior_log_variance_clipped
            else:
                raise NotImplementedError(f"unsupported model_var_type: {self.model_var_type}")
            model_variance = _extract_into_tensor(fixed_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(fixed_log_variance, t, x.shape)

        if self.model_mean_type == _ModelMeanType.START_X:
            pred_xstart = model_output
        elif self.model_mean_type == _ModelMeanType.EPSILON:
            pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
        else:
            pred_xstart = self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
        return {"mean": model_mean, "variance": model_variance, "log_variance": model_log_variance, "pred_xstart": pred_xstart}

    def p_mean_variance(self, model, x, t, clip_denoised: bool = True, model_kwargs=None):
        """Run ``model`` at ``(x, t)`` and return the reverse-process parameters.

        This is the single place where the model is invoked, so subclasses can
        override it to change how timesteps are presented to the model.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        if self.rescale_timesteps:
            ts = t.float() * (1000.0 / self.num_timesteps)
        else:
            ts = t
        model_output = model(x, ts, **model_kwargs)
        return self.p_mean_variance_from_output(model_output, x, t, clip_denoised=clip_denoised)

    # ------------------------------------------------------------------
    # Classifier guidance
    # ------------------------------------------------------------------

    def condition_mean(self, cond_grad: torch.Tensor, p_mean_var: dict, x: torch.Tensor) -> torch.Tensor:
        """Apply classifier guidance to the reverse-process mean (Sohl-Dickstein et al., 2015)."""
        del x
        return p_mean_var["mean"].float() + p_mean_var["variance"] * cond_grad.float()

    def condition_score(self, cond_grad: torch.Tensor, p_mean_var: dict, x: torch.Tensor, t: torch.Tensor) -> dict:
        """Apply classifier guidance to the score (Song et al., 2020) for DDIM."""
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_grad
        out = dict(p_mean_var)
        out["pred_xstart"] = self._predict_xstart_from_eps(x_t=x, t=t, eps=eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    # ------------------------------------------------------------------
    # DDPM (ancestral) sampling
    # ------------------------------------------------------------------

    @staticmethod
    def _nonzero_mask(x, t):
        """Return a broadcastable mask that is 0 where ``t == 0`` and 1 elsewhere."""
        return (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))

    def _ancestral_step(self, out, x, t, generator):
        """Draw x_{t-1} from the Gaussian described by ``out``; no noise is added at t == 0."""
        noise = _randn_like(x, generator=generator)
        sample = out["mean"] + self._nonzero_mask(x, t) * torch.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_from_output(
        self,
        model_output: torch.Tensor,
        x: torch.Tensor,
        t: torch.Tensor,
        clip_denoised: bool = True,
        generator: Optional[torch.Generator] = None,
        cond_grad: Optional[torch.Tensor] = None,
    ):
        """DDPM step from a precomputed model output, with optional classifier guidance.

        Returns:
            Dict with keys ``"sample"`` and ``"pred_xstart"``.
        """
        out = self.p_mean_variance_from_output(model_output, x, t, clip_denoised=clip_denoised)
        if cond_grad is not None:
            out["mean"] = self.condition_mean(cond_grad, out, x)
        return self._ancestral_step(out, x, t, generator)

    def p_sample(self, model, x, t, clip_denoised=True, model_kwargs=None, generator: Optional[torch.Generator] = None):
        """DDPM step that calls ``model`` itself; returns ``"sample"`` and ``"pred_xstart"``."""
        out = self.p_mean_variance(model, x, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
        return self._ancestral_step(out, x, t, generator)

    def p_sample_loop(self, model, shape, noise=None, clip_denoised=True, model_kwargs=None, device=None, progress=False):
        """Run the full DDPM chain and return the final sample tensor."""
        final = None
        for sample in self.p_sample_loop_progressive(
            model, shape, noise=noise, clip_denoised=clip_denoised, model_kwargs=model_kwargs, device=device, progress=progress
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(self, model, shape, noise=None, clip_denoised=True, model_kwargs=None, device=None, progress=False):
        """Yield the output dict of every DDPM step, from the last timestep down to 0.

        If ``device`` is ``None`` it is taken from the model's first parameter. If
        ``noise`` is ``None`` the chain starts from fresh Gaussian noise of ``shape``.
        """
        if device is None:
            device = next(model.parameters()).device
        img = noise if noise is not None else torch.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]
        if progress:
            from tqdm.auto import tqdm

            indices = tqdm(indices)
        for i in indices:
            t = torch.tensor([i] * shape[0], device=device)
            with torch.no_grad():
                out = self.p_sample(model, img, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
            yield out
            img = out["sample"]

    # ------------------------------------------------------------------
    # DDIM sampling
    # ------------------------------------------------------------------

    def _ddim_step(self, out, x, t, eta, generator):
        """Compute the DDIM update from reverse-process parameters ``out``."""
        pred_xstart = out["pred_xstart"]
        # Re-derive eps so that clipping / guidance applied to pred_xstart is honoured.
        eps = self._predict_eps_from_xstart(x, t, pred_xstart)
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        # Noise is drawn even when eta == 0 so generator consumption does not depend on eta.
        noise = _randn_like(x, generator=generator)
        mean_pred = pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps
        sample = mean_pred + self._nonzero_mask(x, t) * sigma * noise
        return {"sample": sample, "pred_xstart": pred_xstart}

    def ddim_sample_from_output(
        self,
        model_output: torch.Tensor,
        x: torch.Tensor,
        t: torch.Tensor,
        clip_denoised: bool = True,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        cond_grad: Optional[torch.Tensor] = None,
    ):
        """DDIM step from a precomputed model output, with optional classifier guidance.

        Returns:
            Dict with keys ``"sample"`` and ``"pred_xstart"``.
        """
        out = self.p_mean_variance_from_output(model_output, x, t, clip_denoised=clip_denoised)
        if cond_grad is not None:
            out = self.condition_score(cond_grad, out, x, t)
        return self._ddim_step(out, x, t, eta, generator)

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        model_kwargs=None,
        eta=0.0,
        generator: Optional[torch.Generator] = None,
    ):
        """DDIM step that calls ``model`` itself.

        The model is invoked through ``p_mean_variance`` so that any subclass
        override of that method (for example timestep remapping) also applies here.
        """
        out = self.p_mean_variance(model, x, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
        return self._ddim_step(out, x, t, eta, generator)
297
-
298
-
299
class _WrappedModel:
    """Callable adapter that translates respaced timestep indices before calling ``model``.

    Args:
        model: The wrapped callable, invoked as ``model(x, timesteps, **kwargs)``.
        timestep_map: Sequence mapping each respaced index to an original timestep.
        rescale_timesteps: If ``True``, mapped timesteps are converted to float
            and multiplied by ``1000 / original_num_steps``.
        original_num_steps: Length of the original (un-respaced) chain.
    """

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model, self.timestep_map = model, timestep_map
        self.rescale_timesteps, self.original_num_steps = rescale_timesteps, original_num_steps

    def __call__(self, x, ts, **kwargs):
        # The lookup table is built on the same device / dtype as the incoming indices.
        lookup = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        mapped = lookup[ts]
        if not self.rescale_timesteps:
            return self.model(x, mapped, **kwargs)
        factor = 1000.0 / self.original_num_steps
        return self.model(x, mapped.float() * factor, **kwargs)
312
-
313
-
314
class _SpacedDiffusion(_GaussianDiffusion):
    """Gaussian diffusion restricted to a subset of the original timesteps.

    The betas are recomputed so that the cumulative alpha product at each
    retained step equals that of the original chain. ``timestep_map[i]`` gives
    the original timestep for respaced index ``i``.

    Args:
        use_timesteps: Iterable of original timestep indices to keep.
        **kwargs: Forwarded to ``_GaussianDiffusion``; must include ``betas``
            for the original chain.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])
        base_diffusion = _GaussianDiffusion(**kwargs)
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                # Pick beta so the running alpha product jumps straight to this step's value.
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def _model_output(self, model, x, t, model_kwargs):
        """Call the wrapped model with integer respaced indices.

        Index mapping and optional rescaling are both done by ``_WrappedModel``.
        The base-class rescaling must be skipped here: it would turn ``t`` into
        floats, which cannot index the timestep map, and would scale twice.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        return self._wrap_model(model)(x, t, **model_kwargs)

    def p_mean_variance(self, model, x, t, clip_denoised: bool = True, model_kwargs=None):
        """Evaluate ``model`` at the original timesteps and return the reverse-process parameters."""
        model_output = self._model_output(model, x, t, model_kwargs)
        return self.p_mean_variance_from_output(model_output, x, t, clip_denoised=clip_denoised)

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        model_kwargs=None,
        eta=0.0,
        generator=None,
    ):
        """DDIM step that evaluates ``model`` at the original (mapped) timesteps."""
        model_output = self._model_output(model, x, t, model_kwargs)
        return self.ddim_sample_from_output(
            model_output, x, t, clip_denoised=clip_denoised, eta=eta, generator=generator
        )

    def _wrap_model(self, model):
        """Wrap ``model`` in ``_WrappedModel`` unless it is already wrapped."""
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)
337
-
338
-
339
def _create_spaced_diffusion(
    *,
    steps: int = 1000,
    learn_sigma: bool = False,
    sigma_small: bool = False,
    noise_schedule: str = "linear",
    predict_xstart: bool = False,
    rescale_timesteps: bool = False,
    timestep_respacing: str = "",
) -> _SpacedDiffusion:
    """Assemble a ``_SpacedDiffusion`` from high-level options.

    Args:
        steps: Length of the original diffusion chain.
        learn_sigma: Use the ``LEARNED_RANGE`` variance type; takes precedence
            over ``sigma_small``.
        sigma_small: With fixed variance, choose ``FIXED_SMALL`` instead of
            ``FIXED_LARGE``.
        noise_schedule: Name passed to ``_get_named_beta_schedule``.
        predict_xstart: Model predicts x_0 (``START_X``) rather than noise
            (``EPSILON``).
        rescale_timesteps: Forwarded to the diffusion object.
        timestep_respacing: Respacing spec for ``_space_timesteps``; an empty
            value keeps all ``steps`` timesteps.
    """
    betas = _get_named_beta_schedule(noise_schedule, steps)

    respacing = timestep_respacing if timestep_respacing else [steps]

    if predict_xstart:
        mean_type = _ModelMeanType.START_X
    else:
        mean_type = _ModelMeanType.EPSILON

    if learn_sigma:
        var_type = _ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = _ModelVarType.FIXED_SMALL
    else:
        var_type = _ModelVarType.FIXED_LARGE

    return _SpacedDiffusion(
        use_timesteps=_space_timesteps(steps, respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        rescale_timesteps=rescale_timesteps,
    )
361
-
362
-
363
- # ---------------------------------------------------------------------------
364
- # Public Diffusers scheduler API
365
- # ---------------------------------------------------------------------------
366
-
367
-
368
@dataclass
class ADMSchedulerOutput(BaseOutput):
    """
    Result container returned by `ADMScheduler.step`.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The sample `(x_{t-1})` for the previous timestep; feed it to the model as the next input.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
            The denoised estimate `(x_{0})` derived from the current model output.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
382
-
383
-
384
class ADMScheduler(SchedulerMixin, ConfigMixin):
    """
    DDPM / DDIM scheduler for ADM (Ablated Diffusion Model) built on OpenAI-style spaced Gaussian diffusion.

    Typical use: call `set_timesteps` once, then for every returned timestep run the UNet and pass its output to
    `step`.
    """

    config_name = "scheduler_config.json"
    order = 1

    @register_to_config
    def __init__(
        self,
        steps: int = 1000,
        learn_sigma: bool = False,
        sigma_small: bool = False,
        noise_schedule: str = "linear",
        predict_xstart: bool = False,
        rescale_timesteps: bool = False,
        timestep_respacing: str = "",
    ):
        # Runtime state; populated by `set_timesteps`.
        self.timesteps = None
        self.num_inference_steps = None
        self._diffusion: Optional[_SpacedDiffusion] = None
        self._use_ddim = False
        self._eta = 0.0

    def _build_diffusion(self, timestep_respacing) -> _SpacedDiffusion:
        """Create a spaced diffusion object from the registered config and the given respacing spec."""
        cfg = self.config
        return _create_spaced_diffusion(
            steps=cfg.steps,
            learn_sigma=cfg.learn_sigma,
            sigma_small=cfg.sigma_small,
            noise_schedule=cfg.noise_schedule,
            predict_xstart=cfg.predict_xstart,
            rescale_timesteps=cfg.rescale_timesteps,
            timestep_respacing=timestep_respacing,
        )

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        No-op kept for interchangeability with schedulers that scale the model input per timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain (ignored).

        Returns:
            `torch.Tensor`:
                The input sample, unchanged.
        """
        del timestep
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        use_ddim: bool = False,
        timestep_respacing: Optional[str] = None,
    ) -> torch.Tensor:
        """
        Prepare the scheduler for inference by building the respaced diffusion chain.

        Args:
            num_inference_steps (`int`):
                Number of denoising steps to run.
            device (`str` or `torch.device`, *optional*):
                Device for the returned timesteps. If `None`, they are left where they were created.
            use_ddim (`bool`, *optional*, defaults to `False`):
                Sample with DDIM instead of DDPM.
            timestep_respacing (`str`, *optional*):
                Explicit respacing string. If `None`, `"ddim{num_inference_steps}"` is used for DDIM and
                `str(num_inference_steps)` otherwise.

        Returns:
            `torch.Tensor`:
                Respaced timestep indices in descending order.
        """
        if timestep_respacing is None:
            if use_ddim:
                timestep_respacing = f"ddim{num_inference_steps}"
            else:
                timestep_respacing = str(num_inference_steps)

        self._diffusion = self._build_diffusion(timestep_respacing)
        self._use_ddim = use_ddim
        self.num_inference_steps = num_inference_steps

        # Indices count down from the last respaced step to 0.
        descending = list(range(self._diffusion.num_timesteps))[::-1]
        timesteps = torch.tensor(descending, dtype=torch.long)
        if device is not None:
            timesteps = timesteps.to(device)
        self.timesteps = timesteps
        return self.timesteps

    def scale_timesteps_for_model(self, timestep: torch.Tensor) -> torch.Tensor:
        """
        Convert respaced scheduler indices into the timestep values to feed the ADM UNet.

        Args:
            timestep (`torch.Tensor`):
                Current scheduler timestep indices of shape `(batch_size,)`.

        Returns:
            `torch.Tensor`:
                Original-chain timesteps, rescaled to floats when `rescale_timesteps` is enabled.
        """
        diffusion = self._diffusion
        if diffusion is None:
            raise ValueError("Call `set_timesteps` before running the scheduler.")

        lookup = torch.tensor(diffusion.timestep_map, device=timestep.device, dtype=timestep.dtype)
        model_timesteps = lookup[timestep]
        if diffusion.rescale_timesteps:
            model_timesteps = model_timesteps.float() * (1000.0 / diffusion.original_num_steps)
        return model_timesteps

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        clip_denoised: bool = True,
        eta: Optional[float] = None,
        cond_grad: Optional[torch.Tensor] = None,
    ) -> Union[ADMSchedulerOutput, Tuple[torch.Tensor, ...]]:
        """
        Compute the previous-timestep sample from the current model output.

        Args:
            model_output (`torch.Tensor`):
                Raw output of the ADM UNet.
            timestep (`int` or `torch.Tensor`):
                Current index in the respaced diffusion chain. Ints and 0-dim tensors are promoted to shape `(1,)`.
            sample (`torch.Tensor`):
                Current noisy sample.
            generator (`torch.Generator`, *optional*):
                Random number generator for the sampling noise.
            return_dict (`bool`, *optional*, defaults to `True`):
                Return an [`ADMSchedulerOutput`] rather than a plain tuple.
            clip_denoised (`bool`, *optional*, defaults to `True`):
                Clamp the predicted `x_0` to `[-1, 1]`.
            eta (`float`, *optional*):
                DDIM stochasticity. Only consulted when `set_timesteps` was called with `use_ddim=True`.
            cond_grad (`torch.Tensor`, *optional*):
                Classifier guidance gradient for ADM-G (`classifier_scale * grad log p(y|x_t)`).

        Returns:
            [`ADMSchedulerOutput`] or `tuple`:
                An [`ADMSchedulerOutput`] when `return_dict` is `True`; otherwise a tuple whose first element is the
                previous sample and whose second is the predicted original sample.
        """
        if self._diffusion is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        # Normalise the timestep to a long tensor on the sample's device.
        if torch.is_tensor(timestep):
            if timestep.ndim == 0:
                timestep = timestep.reshape(1)
            timestep = timestep.to(device=sample.device, dtype=torch.long)
        else:
            timestep = torch.tensor([timestep], device=sample.device, dtype=torch.long)

        ddim_eta = eta if eta is not None else self._eta

        shared_kwargs = {
            "clip_denoised": clip_denoised,
            "generator": generator,
            "cond_grad": cond_grad,
        }
        if self._use_ddim:
            out = self._diffusion.ddim_sample_from_output(
                model_output, sample, timestep, eta=ddim_eta, **shared_kwargs
            )
        else:
            out = self._diffusion.p_sample_from_output(model_output, sample, timestep, **shared_kwargs)

        prev_sample = out["sample"]
        pred_original_sample = out.get("pred_xstart")

        if return_dict:
            return ADMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
        return (prev_sample, pred_original_sample)

    def create_runtime(self, num_inference_steps: Optional[int] = None, use_ddim: bool = False) -> _SpacedDiffusion:
        """
        Return a standalone spaced diffusion object for legacy loop-based sampling (`p_sample_loop`).

        When `num_inference_steps` is `None` the config's `timestep_respacing` is used. Prefer `set_timesteps` +
        `step` for Diffusers-style inference.
        """
        if num_inference_steps is None:
            timestep_respacing = self.config.timestep_respacing
        elif use_ddim:
            timestep_respacing = f"ddim{num_inference_steps}"
        else:
            timestep_respacing = str(num_inference_steps)
        return self._build_diffusion(timestep_respacing)