from __future__ import annotations

import numpy as np
import torch

from .ddpm import DDPMPredictionType
from .scheduler import Scheduler

DDIMPredictionType = DDPMPredictionType


class DDIMScheduler(Scheduler):
""" |
|
|
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising |
|
|
diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion |
|
|
Implicit Models" https://arxiv.org/abs/2010.02502 |
|
|
|
|
|
Args: |
|
|
num_train_timesteps: number of diffusion steps used to train the model. |
|
|
schedule: member of NoiseSchedules, name of noise schedule function in component store |
|
|
clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. |
|
|
set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. |
|
|
For the final step there is no previous alpha. When this option is `True` the previous alpha product is |
|
|
fixed to `1`, otherwise it uses the value of alpha at step 0. |
|
|
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and |
|
|
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in |
|
|
stable diffusion. |
|
|
prediction_type: member of DDPMPredictionType |
|
|
clip_sample_min: minimum clipping value when clip_sample equals True |
|
|
clip_sample_max: maximum clipping value when clip_sample equals True |
|
|
schedule_args: arguments to pass to the schedule function |
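
    Example (a minimal sampling-loop sketch; ``network`` is a hypothetical noise-prediction model, not part of
    this module):

    .. code-block:: python

        scheduler = DDIMScheduler(num_train_timesteps=1000)
        scheduler.set_timesteps(num_inference_steps=50)
        sample = torch.randn(1, 1, 64, 64)  # start from pure noise
        for t in scheduler.timesteps:
            model_output = network(sample, t)  # predicted noise (epsilon)
            sample, _ = scheduler.step(model_output, int(t), sample)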
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        schedule: str = "linear_beta",
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = DDIMPredictionType.EPSILON,
        clip_sample_min: float = -1.0,
        clip_sample_max: float = 1.0,
        **schedule_args,
    ) -> None:
        super().__init__(num_train_timesteps, schedule, **schedule_args)

        if prediction_type not in DDIMPredictionType.__members__.values():
            raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")

        self.prediction_type = prediction_type
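
        # Each DDIM step looks up the alpha cumulative product at the previous timestep. For the
        # final step there is no previous value, so `set_alpha_to_one` chooses between fixing it
        # to 1.0 and reusing the alpha product at step 0.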
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
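
        # standard deviation of the initial noise distribution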
        self.init_noise_sigma = 1.0

        self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))

        self.clip_sample = clip_sample
        self.clip_sample_values = [clip_sample_min, clip_sample_max]
        self.steps_offset = steps_offset
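
        # default the number of inference timesteps to the number of train steps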
        self.num_inference_steps: int
        self.set_timesteps(self.num_train_timesteps)

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
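
        Example (with the default 1000 training timesteps):

        .. code-block:: python

            scheduler = DDIMScheduler()
            scheduler.set_timesteps(num_inference_steps=4)
            print(scheduler.timesteps)  # tensor([750, 500, 250,   0])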
        """
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps}, as the UNet model trained with this scheduler can only handle"
                f" a maximum of {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        if self.steps_offset >= step_ratio:
            raise ValueError(
                f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to"
                f" `num_train_timesteps // num_inference_steps`: {step_ratio}, as this would cause timesteps to"
                f" exceed the maximum train timestep."
            )
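
        # create evenly spaced integer timesteps by scaling with the train/inference step ratio,
        # then reverse them so sampling runs from high noise to low noise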
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.steps_offset

    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
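        # posterior variance of the DDIM step, i.e. (sigma_t / eta)^2 from Eq. (16) of
        # https://arxiv.org/abs/2010.02502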
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        generator: torch.Generator | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from the learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of the sample being created by the diffusion process.
            eta: weight of the noise added in the diffusion step.
            generator: random number generator.

        Returns:
            pred_prev_sample: Predicted previous sample
            pred_original_sample: Predicted original sample
        """
        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
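
        # 2. compute alphas and betas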
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t
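
        # 3. compute the predicted original sample ("predicted x_0" of formula (12))
        # according to the configured prediction type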
        pred_original_sample = sample
        pred_epsilon = model_output

        if self.prediction_type == DDIMPredictionType.EPSILON:
            pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
            pred_epsilon = model_output
        elif self.prediction_type == DDIMPredictionType.SAMPLE:
            pred_original_sample = model_output
            pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
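
        # 4. clip "predicted x_0"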
        if self.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
            )
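
        # 5. compute the variance sigma_t(eta), see formula (16):
        # sigma_t = eta * sqrt((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * sqrt(1 - alpha_prod_t / alpha_prod_t_prev)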
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance**0.5
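
        # 6. compute the "direction pointing to x_t" of formula (12)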
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
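
        # 7. compute x_{t-1} without random noise, formula (12)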
        pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction

        if eta > 0:
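            # torch.randn_like does not accept a generator, so draw the noise explicitly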
            device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu")
            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device)
            variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise

            pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample, pred_original_sample

    def reversed_step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from the learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of the sample being created by the diffusion process.

        Returns:
            pred_post_sample: Predicted next sample
            pred_original_sample: Predicted original sample
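
        Example (a minimal DDIM-inversion sketch; ``network`` is a hypothetical noise-prediction model and
        ``image`` a hypothetical clean input, neither part of this module):

        .. code-block:: python

            scheduler = DDIMScheduler(num_train_timesteps=1000)
            scheduler.set_timesteps(num_inference_steps=50)
            latent = image  # progressively noised towards the prior
            for t in scheduler.timesteps.flip(0):
                model_output = network(latent, t)
                latent, _ = scheduler.reversed_step(model_output, int(t), latent)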
        """
        prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
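
        # 2. compute alphas and betas at timestep t+1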
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t
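
        # 3. compute the predicted original sample ("predicted x_0" of formula (12))
        # according to the configured prediction type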
        pred_original_sample = sample
        pred_epsilon = model_output

        if self.prediction_type == DDIMPredictionType.EPSILON:
            pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
            pred_epsilon = model_output
        elif self.prediction_type == DDIMPredictionType.SAMPLE:
            pred_original_sample = model_output
            pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
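
        # 4. clip "predicted x_0"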
        if self.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
            )
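
        # 5. compute the "direction pointing to x_t" of formula (12), with sigma_t = 0 (deterministic DDIM)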
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * pred_epsilon
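
        # 6. compute x_{t+1}, formula (12)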
        pred_post_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction

        return pred_post_sample, pred_original_sample