import math
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union

import torch
from fairseq2.logging import get_log_writer
from fairseq2.typing import CPU
from torch import Tensor

logger = get_log_writer(__name__)


def sigmoid(x: float) -> float:
    """The logistic function, the inverse of :func:`logit`."""
    return 1 / (1 + math.exp(-x))


def logit(x: float) -> float:
    """The log-odds function; ``x`` must lie in the open interval (0, 1)."""
    return math.log(x / (1 - x))


@dataclass
class DDIMSchedulerOutput:
    """
    Output of the scheduler's `step` function.

    Args:
        prev_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
        pred_original_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: Tensor
    pred_original_sample: Tensor


@dataclass
class DDIMSchedulerConfig:
    num_diffusion_train_steps: int = 1000
    """The number of diffusion steps used to train the model."""

    beta_start: float = 0.0001
    """The starting `beta` value of inference."""

    beta_end: float = 0.02
    """The final `beta` value. In DDPM (https://arxiv.org/pdf/2006.11239), $\beta_t$ increases
    linearly from $\beta_1$ (`beta_start`) = 1e-4 to $\beta_T$ (`beta_end`) = 0.02.
    These constants were chosen to be small relative to data scaled to [-1, 1],
    ensuring that the reverse and forward processes have approximately
    the same functional form while keeping the signal-to-noise ratio at $x_T$ as small as possible.
    Another common choice in HF:diffusers is `beta_start=0.00085, beta_end=0.012`.
    Note that `beta_start` and `beta_end` are irrelevant for `squaredcos_cap_v2` (and `sigmoid`).
    """

    beta_schedule: Literal[
        "linear",
        "scaled_linear",
        "squaredcos_cap_v2",
        "sigmoid",
    ] = "squaredcos_cap_v2"
    """The beta schedule, a mapping from a beta range to a sequence of betas
    for stepping the model (length=`num_diffusion_train_steps`).
    Choose from:
    - `linear`: Linearly spaced betas between `beta_start` and `beta_end`.
      Referred to as `sqrt_linear` in stable-diffusion.
    - `scaled_linear`: Squared values after linearly spacing from sqrt(beta_start) to sqrt(beta_end).
      Referred to as `linear` in stable-diffusion.
    - `squaredcos_cap_v2`: Creates a beta schedule that discretizes
      $\bar{\alpha}(t) = \cos^2\big((t/T + s) / (1 + s) \cdot \pi/2\big)$; HF:diffusers sets `s` to 0.008.
      For the intuition behind how a cosine schedule compares to a linear schedule,
      see Figure 3 of https://arxiv.org/pdf/2102.09672
    - `sigmoid`: Our sigmoid schedule (see Equation 14 of the LCM paper).
    """

    scaled_linear_exponent: float = 2.0
    """Exponent for the scaled linear beta schedule. The default is quadratic (`scaled_linear_exponent=2`)."""

    sigmoid_schedule_alpha: float = 1.5
    sigmoid_schedule_beta: float = 0.0
    """The alpha and beta hyper-parameters of the sigmoid beta schedule."""

    clip_sample: bool = False
    """Clip the predicted sample for numerical stability."""

    clip_sample_range: float = 1.0
    """The maximum magnitude for sample clipping. Valid only when `clip_sample=True`."""

    set_alpha_to_one: bool = True
    """Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
    there is no previous alpha. When this option is `True`, the previous alpha product is fixed to `1`;
    otherwise it uses the alpha value at step 0."""

    prediction_type: Literal["sample", "epsilon", "v_prediction"] = "sample"
    """If `sample`, the model predicts the clean ground-truth embeddings.
    If `epsilon`, the model predicts the added noise of the diffusion process.
    If `v_prediction`, the model predicts an interpolation of the ground-truth clean
    embeddings and the added noise, as introduced in Section 2.4 of the Imagen Video paper
    (https://imagen.research.google/video/paper.pdf).
    """

    thresholding: bool = False
    """Whether to use the "dynamic thresholding" method.
    This is unsuitable for latent-space diffusion models such as Stable Diffusion."""

    dynamic_thresholding_ratio: float = 0.995
    """The ratio for the dynamic thresholding method. Valid only when `thresholding=True`."""

    sample_max_value: float = 1.0
    """The threshold value for dynamic thresholding. Valid only when `thresholding=True`."""

    rescale_betas_zero_snr: bool = True
    """Whether to rescale the betas to have zero terminal SNR. This enables the
    model to generate very bright and dark samples instead of limiting it to samples
    with medium brightness. Loosely related to
    [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506)."""

    timestep_spacing: Literal["linspace", "leading", "trailing"] = "trailing"
    """The way the timesteps should be scaled. Refer to Table 2 of
    https://arxiv.org/abs/2305.08891 for more information."""
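

# A minimal illustrative sketch (a hypothetical helper, not part of the
# original module): it shows how each `beta_schedule` option in
# `DDIMSchedulerConfig` maps to a concrete beta sequence by instantiating the
# `DDIMScheduler` defined below. The `rescale_betas_zero_snr=False` choice and
# the logged statistics are for illustration only.
def _compare_beta_schedules() -> None:
    for schedule in ("linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"):
        config = DDIMSchedulerConfig(
            beta_schedule=schedule, rescale_betas_zero_snr=False
        )
        scheduler = DDIMScheduler(config)
        # Every schedule yields `num_diffusion_train_steps` betas in (0, 1);
        # they differ in how fast the cumulative alpha product decays.
        logger.info(
            f"{schedule}: beta_0={scheduler.betas[0].item():.2e},"
            f" beta_T={scheduler.betas[-1].item():.2e},"
            f" alpha_bar_T={scheduler.alphas_cumprod[-1].item():.2e}"
        )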


class DDIMScheduler:
    """Denoising Diffusion Implicit Models (DDIM) scheduler (https://arxiv.org/abs/2010.02502)."""

    def __init__(self, config: DDIMSchedulerConfig):
        self.config = config

        self.num_diffusion_train_steps = self.config.num_diffusion_train_steps

        self.prediction_type = self.config.prediction_type

        beta_schedule = self.config.beta_schedule

        if beta_schedule == "linear":
            self.betas = torch.linspace(
                self.config.beta_start,
                self.config.beta_end,
                self.num_diffusion_train_steps,
                dtype=torch.float32,
            )
        elif beta_schedule == "scaled_linear":
            exponent = self.config.scaled_linear_exponent
            self.betas = (
                torch.linspace(
                    self.config.beta_start ** (1 / exponent),
                    self.config.beta_end ** (1 / exponent),
                    self.num_diffusion_train_steps,
                    dtype=torch.float32,
                )
                ** exponent
            )
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(
                self.num_diffusion_train_steps,
                alpha_transform_type="cosine",
            )
        elif beta_schedule == "sigmoid":
            self.betas = betas_for_alpha_bar(
                self.num_diffusion_train_steps,
                alpha_transform_type="sigmoid",
                sigmoid_alpha=self.config.sigmoid_schedule_alpha,
                sigmoid_beta=self.config.sigmoid_schedule_beta,
            )
        else:
            raise NotImplementedError(
                f"We do not recognize beta_schedule={beta_schedule}"
            )

        if self.config.rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Every step uses the alpha product at that step and at the previous
        # one; for the final step there is no previous product, so it is
        # either pinned to 1 (`set_alpha_to_one`) or taken from step 0.
        self.final_alpha_cumprod = (
            torch.tensor(1.0)
            if self.config.set_alpha_to_one
            else self.alphas_cumprod[0]
        )

        # The standard deviation of the initial noise distribution.
        self.init_noise_sigma = 1.0

        # Set at inference time by `set_timesteps`.
        self.num_inference_steps: Optional[int] = None

    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (
            1 - alpha_prod_t / alpha_prod_t_prev
        )
        return variance

    def get_variances(self) -> Tensor:
        alpha_prod_t = self.alphas_cumprod
        alpha_prod_t_prev = torch.cat(
            (torch.tensor([self.final_alpha_cumprod]), alpha_prod_t[:-1])
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (
            1 - alpha_prod_t / alpha_prod_t_prev
        )
        return variance

    def get_snrs(self) -> Tensor:
        alphas_cumprod = self.alphas_cumprod
        snr = alphas_cumprod / (1 - alphas_cumprod)
        return snr

    def _threshold_sample(self, sample: Tensor) -> Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain
        percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t),
        and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1)
        inwards, thereby actively preventing pixels from saturation at each step.
        We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment,
        especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            # Upcast for quantile computation on half-precision inputs.
            sample = sample.float()

        # Flatten the sample to compute the quantile along each batch element.
        sample = sample.reshape(batch_size, -1)

        abs_sample = sample.abs()

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        # When clamped to min=1, this is equivalent to standard clipping to [-1, 1].
        s = torch.clamp(s, min=1, max=self.config.sample_max_value)
        s = s.unsqueeze(1)  # (batch_size, 1) so that it broadcasts over dim 1.
        sample = torch.clamp(sample, -s, s) / s  # Threshold to [-s, s], then normalize.

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved.
        """

        if num_inference_steps > self.config.num_diffusion_train_steps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `num_diffusion_train_steps`: {self.num_diffusion_train_steps} as the model trained"
                f" with this scheduler can only handle maximal {self.num_diffusion_train_steps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # The three spacing strategies below follow Table 2 of
        # https://arxiv.org/abs/2305.08891.

        if self.config.timestep_spacing == "linspace":
            # Evenly spaced timesteps over [0, T - 1], in descending order.
            timesteps = torch.linspace(
                0,
                self.config.num_diffusion_train_steps - 1,
                self.num_inference_steps,
                device=device,
            )
            timesteps = torch.flip(timesteps, dims=(0,)).round().long()

        elif self.config.timestep_spacing == "leading":
            # Integer strides starting from timestep 0; the first inference
            # step does not start at T - 1.
            leading_step_ratio = (
                self.num_diffusion_train_steps // self.num_inference_steps
            )
            timesteps = torch.arange(
                start=0,
                end=self.num_diffusion_train_steps,
                step=leading_step_ratio,
                device=device,
                dtype=torch.long,
            )
            timesteps = torch.flip(timesteps, dims=(0,))

        elif self.config.timestep_spacing == "trailing":
            # Strides ending exactly at timestep T - 1, as recommended by
            # https://arxiv.org/abs/2305.08891.
            trailing_step_ratio: float = (
                self.num_diffusion_train_steps / self.num_inference_steps
            )
            # Compute in floating point and round before casting, matching
            # `np.round(np.arange(T, 0, -step_ratio))` in HF:diffusers.
            timesteps = (
                torch.arange(
                    self.config.num_diffusion_train_steps,
                    0,
                    -trailing_step_ratio,
                    device=device,
                )
                .round()
                .long()
            )
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of"
                " 'linspace', 'leading' or 'trailing'."
            )

        self.timesteps = timesteps
        logger.debug(
            f"With `{self.config.timestep_spacing}`, setting inference timesteps to {self.timesteps}"
        )

    def step(
        self,
        model_output: Tensor,
        timestep: int,
        sample: Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator: Optional[torch.Generator] = None,
        variance_noise: Optional[Tensor] = None,
        prediction_type: Optional[str] = None,
        epsilon_scaling: Optional[float] = None,
    ) -> DDIMSchedulerOutput:
        """
        INFERENCE ONLY.
        Predict the sample from the previous timestep by reversing the SDE.
        This function propagates the diffusion process from the learned model outputs.

        Args:
            model_output (`Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in the diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`Tensor`, *optional*):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            prediction_type (`str`, *optional*):
                If provided, we step with a different prediction type than the one in the config.
            epsilon_scaling (`float`, *optional*):
                If not `None`, the predicted epsilon is scaled down by the provided
                factor, as introduced in https://arxiv.org/pdf/2308.15321

        Returns:
            DDIMSchedulerOutput
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Notation (see formulas (12) and (16) of https://arxiv.org/abs/2010.02502):
        # - pred_original_sample -> "predicted x_0"
        # - pred_sample_direction -> "direction pointing to x_t"
        # - prev_sample -> "x_{t-1}"

        # 1. Get the previous timestep value (t-1).
        prev_timestep = (
            timestep - self.config.num_diffusion_train_steps // self.num_inference_steps
        )

        # 2. Compute alphas and betas.
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0
            else self.final_alpha_cumprod
        )

        beta_prod_t = 1 - alpha_prod_t

        # 3. Compute the predicted original sample x_0 and the predicted
        # epsilon from the model output.
        prediction_type = prediction_type or self.prediction_type
        if prediction_type == "epsilon":
            pred_original_sample = (
                sample - beta_prod_t ** (0.5) * model_output
            ) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (
                sample - alpha_prod_t ** (0.5) * pred_original_sample
            ) / beta_prod_t ** (0.5)
        elif prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (
                beta_prod_t**0.5
            ) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (
                beta_prod_t**0.5
            ) * sample
        else:
            raise ValueError(
                f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # Scale down the predicted epsilon (https://arxiv.org/pdf/2308.15321).
        if epsilon_scaling is not None:
            pred_epsilon = pred_epsilon / epsilon_scaling

        # 4. Clip or threshold "predicted x_0".
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. Compute the variance sigma_t(eta); see formula (16) of
        # https://arxiv.org/abs/2010.02502.
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)
        if use_clipped_model_output:
            # Recompute pred_epsilon from the (possibly clipped) x_0 prediction.
            pred_epsilon = (
                sample - alpha_prod_t ** (0.5) * pred_original_sample
            ) / beta_prod_t ** (0.5)

        # 6. Compute the "direction pointing to x_t" of formula (12).
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
            0.5
        ) * pred_epsilon

        # 7. Compute x_{t-1} without the random noise of formula (12).
        prev_sample = (
            alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
        )

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape,
                    generator=generator,
                    device=model_output.device,
                    dtype=model_output.dtype,
                )
            variance = std_dev_t * variance_noise
            prev_sample = prev_sample + variance

        return DDIMSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    def add_noise(
        self,
        original_samples: Tensor,
        noise: Tensor,
        timesteps: Tensor,
    ) -> Tensor:
        """TRAINING ONLY.
        The forward noising process q(x_t | x_0) used during training."""
        # Make sure alphas_cumprod and timesteps have the same device and
        # dtype as original_samples.
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device).to(torch.int32)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = (
            sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        )
        return noisy_samples

    def get_velocity(self, sample: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor:
        """Compute the `v_prediction` training target
        v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample."""
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device).to(torch.int32)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def get_epsilon(
        self, model_output: Tensor, sample: Tensor, timestep: int
    ) -> Tensor:
        """Given model inputs (`sample`) and outputs (`model_output`),
        predict the noise residual according to the scheduler's
        prediction type."""
        pred_type = self.prediction_type

        alpha_prod_t = self.alphas_cumprod[timestep]

        beta_prod_t = 1 - alpha_prod_t

        if pred_type == "epsilon":
            return model_output
        elif pred_type == "sample":
            return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (
                0.5
            )
        elif pred_type == "v_prediction":
            return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"The scheduler's prediction type {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
            )
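

# A minimal usage sketch (a hypothetical helper, not part of the original
# module) of the inference API: `set_timesteps` followed by a `step` loop.
# `model` is a stand-in for any network whose output matches the configured
# `prediction_type`; the shape and `num_inference_steps=50` are illustrative.
def _example_denoising_loop(
    model, shape: Tuple[int, ...], num_inference_steps: int = 50
) -> Tensor:
    scheduler = DDIMScheduler(DDIMSchedulerConfig())
    scheduler.set_timesteps(num_inference_steps)

    # Start from pure Gaussian noise; for DDIM, `init_noise_sigma` is 1.0.
    sample = torch.randn(shape) * scheduler.init_noise_sigma

    for timestep in scheduler.timesteps:
        model_output = model(sample, timestep)
        # `step` returns both x_{t-1}, which is fed back into the loop, and
        # the current estimate of x_0, useful for previews or for guidance.
        output = scheduler.step(model_output, int(timestep), sample)
        sample = output.prev_sample

    return sample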


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch element individually. If CPU generators are passed,
    the tensor is always created on the CPU and then moved to `device`.
    """
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = (
            generator.device.type
            if not isinstance(generator, list)
            else generator[0].device.type
        )
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = CPU
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
            )

    # Treat a generator list of length 1 like a single generator.
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        shape = (1,) + shape[1:]
        latents_list = [
            torch.randn(
                shape,
                generator=generator[i],
                device=rand_device,
                dtype=dtype,
                layout=layout,
            )
            for i in range(batch_size)
        ]
        latents = torch.cat(latents_list, dim=0).to(device)
    else:
        latents = torch.randn(
            shape, generator=generator, device=rand_device, dtype=dtype, layout=layout
        ).to(device)

    return latents
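

# A small usage sketch (hypothetical, for illustration only) of the
# per-sample seeding that `randn_tensor` supports: with a list of CPU
# generators, row i of the batch is drawn only from generators[i], so the
# noise stays reproducible per sample regardless of batch composition.
def _example_seeded_noise(batch_size: int = 4) -> Tensor:
    generators = [
        torch.Generator(device="cpu").manual_seed(seed)
        for seed in range(batch_size)
    ]
    return randn_tensor((batch_size, 8), generator=generators)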


def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp", "sigmoid"] = "cosine",
    sigmoid_alpha: float = 1.5,
    sigmoid_beta: float = 0,
) -> Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    It defines a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta)
    up to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine`, `exp`, or `sigmoid`.
        sigmoid_alpha, sigmoid_beta (`float`): additional hyper-parameters for the sigmoid schedule.

    Returns:
        betas (`Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "sigmoid":

        def alpha_bar_fn(t):
            epsilon = 1e-32
            return sigmoid(
                sigmoid_beta
                - sigmoid_alpha
                * logit(torch.clamp(torch.tensor(t), min=epsilon, max=1 - epsilon))
            )

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
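

# A quick consistency sketch (a hypothetical check, not part of the original
# module): apart from the final step, where beta is clipped at `max_beta`,
# the cumulative product of (1 - beta) telescopes to
# alpha_bar(t) / alpha_bar(0), so it closely tracks the continuous schedule
# that `betas_for_alpha_bar` discretizes. Checked here for the cosine case.
def _check_cosine_alpha_bar(num_steps: int = 1000) -> None:
    betas = betas_for_alpha_bar(num_steps, alpha_transform_type="cosine")
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def alpha_bar(t: float) -> float:
        # Same transform as the `cosine` branch above (s = 0.008).
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    i = num_steps // 2
    expected = alpha_bar((i + 1) / num_steps) / alpha_bar(0.0)
    assert abs(alphas_cumprod[i].item() - expected) < 1e-4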


def rescale_zero_terminal_snr(betas: Tensor) -> Tensor:
    """
    Rescale betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `Tensor`: rescaled betas with zero terminal SNR
    """
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store the old extreme values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so that the last timestep is exactly zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so that the first timestep is back at its old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt back to betas.
    alphas_bar = alphas_bar_sqrt**2
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert the cumulative product.
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
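

# A small verification sketch (a hypothetical check, not part of the original
# module): after rescaling, the terminal cumulative alpha product is exactly
# zero, i.e. SNR(x_T) = 0, so the last timestep carries no signal, while the
# first step is left (almost) untouched. See Algorithm 1 of
# https://arxiv.org/abs/2305.08891. The linear betas below are illustrative.
def _check_zero_terminal_snr(num_steps: int = 1000) -> None:
    betas = torch.linspace(0.0001, 0.02, num_steps)
    rescaled = rescale_zero_terminal_snr(betas)
    alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
    assert alphas_cumprod[-1].item() == 0.0
    assert torch.isclose(alphas_cumprod[0], 1.0 - betas[0], atol=1e-6)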