from __future__ import annotations

import numpy as np
import torch

from monai.utils import StrEnum

from .scheduler import Scheduler
|
|
class DDPMVarianceType(StrEnum):
    """
    Valid names for the DDPM Scheduler's `variance_type` argument. Options for the variance used when adding noise
    to the denoised sample.
    """

    FIXED_SMALL = "fixed_small"
    FIXED_LARGE = "fixed_large"
    LEARNED = "learned"
    LEARNED_RANGE = "learned_range"
|
|
class DDPMPredictionType(StrEnum):
    """
    Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument.

    epsilon: predicting the noise of the diffusion process
    sample: directly predicting the noisy sample
    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
    """

    EPSILON = "epsilon"
    SAMPLE = "sample"
    V_PREDICTION = "v_prediction"
|
|
class DDPMScheduler(Scheduler):
    """
    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
    Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models"
    https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
        schedule: member of NoiseSchedules, name of a noise schedule function in the component store.
        variance_type: member of DDPMVarianceType, controlling the variance used when adding noise to the sample.
        clip_sample: option to clip the predicted sample between clip_sample_min and clip_sample_max for
            numerical stability.
        prediction_type: member of DDPMPredictionType, the quantity the model is trained to predict.
        clip_sample_min: minimum clipping value when clip_sample equals True.
        clip_sample_max: maximum clipping value when clip_sample equals True.
        schedule_args: arguments to pass to the schedule function.
    """
|
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        schedule: str = "linear_beta",
        variance_type: str = DDPMVarianceType.FIXED_SMALL,
        clip_sample: bool = True,
        prediction_type: str = DDPMPredictionType.EPSILON,
        clip_sample_min: float = -1.0,
        clip_sample_max: float = 1.0,
        **schedule_args,
    ) -> None:
        super().__init__(num_train_timesteps, schedule, **schedule_args)

        if variance_type not in DDPMVarianceType.__members__.values():
            raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`")

        if prediction_type not in DDPMPredictionType.__members__.values():
            raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")

        self.clip_sample = clip_sample
        self.clip_sample_values = [clip_sample_min, clip_sample_max]
        self.variance_type = variance_type
        self.prediction_type = prediction_type
|
    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
        """
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
|
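        # create evenly spaced integer timesteps by scaling with the step ratio,
        # reversed so that sampling runs from the highest timestep down to 0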
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
|
    def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        """
        Compute the mean of the posterior at timestep t.

        Args:
            timestep: current timestep.
            x_0: the noise-free input.
            x_t: the input noised to timestep t.

        Returns:
            Returns the mean.
        """
|
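        # mean of the posterior q(x_{t-1} | x_t, x_0): a weighted combination of x_0 and x_t
        # (Eq. 7 of Ho et al., https://arxiv.org/abs/2006.11239)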
        alpha_t = self.alphas[timestep]
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one

        x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
        x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)

        mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t

        return mean
|
    def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor:
        """
        Compute the variance of the posterior at timestep t.

        Args:
            timestep: current timestep.
            predicted_variance: variance predicted by the model.

        Returns:
            Returns the variance.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
|
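        # posterior variance for t > 0: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        # (Eq. 7 of Ho et al., https://arxiv.org/abs/2006.11239), then adjusted according to `variance_type`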
        variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]

        if self.variance_type == DDPMVarianceType.FIXED_SMALL:
            variance = torch.clamp(variance, min=1e-20)
        elif self.variance_type == DDPMVarianceType.FIXED_LARGE:
            variance = self.betas[timestep]
        elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None:
            return predicted_variance
        elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None:
            min_log = variance
            max_log = self.betas[timestep]
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance
|
    def step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.
            generator: random number generator.

        Returns:
            pred_prev_sample: predicted previous sample x_{t-1}.
            pred_original_sample: predicted original sample x_0.
        """
|
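        # if the model also predicts the variance, its output has twice as many channels as the sample:
        # split it into the noise/sample prediction and the predicted variance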
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None
|
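        # compute alpha and beta products for the current and previous timestep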
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
|
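        # predict the original sample x_0 from the model output according to the configured prediction type;
        # for epsilon prediction this inverts the forward process:
        # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)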
        if self.prediction_type == DDPMPredictionType.EPSILON:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.prediction_type == DDPMPredictionType.SAMPLE:
            pred_original_sample = model_output
        elif self.prediction_type == DDPMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
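        # optionally clip the predicted x_0 to the configured range for numerical stability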
        if self.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
            )
|
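        # coefficients of pred_original_sample (x_0) and the current sample (x_t) in the posterior mean
        # (Eq. 7 of Ho et al., https://arxiv.org/abs/2006.11239)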
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
        current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
|
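        # mean of the predicted previous sample, i.e. the posterior mean evaluated at the predicted x_0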
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
|
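        # for t > 0, add noise scaled by the square root of the posterior variance;
        # the final step (t == 0) is deterministic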
        variance: torch.Tensor = torch.tensor(0)
        if timestep > 0:
            noise = torch.randn(
                model_output.size(),
                dtype=model_output.dtype,
                layout=model_output.layout,
                generator=generator,
                device=model_output.device,
            )
            variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample, pred_original_sample
|
|
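# A minimal usage sketch, not part of the original module: it assumes this file is importable as
# monai.networks.schedulers.ddpm and that `model` is a trained network predicting noise (epsilon);
# `model`, its call signature, and `image_shape` are illustrative placeholders.
#
#     import torch
#     from monai.networks.schedulers.ddpm import DDPMScheduler
#
#     scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta")
#     scheduler.set_timesteps(num_inference_steps=1000)
#     sample = torch.randn(image_shape)  # start the reverse process from pure noise
#     for t in scheduler.timesteps:
#         with torch.no_grad():
#             model_output = model(sample, torch.as_tensor((t,)))
#         sample, _ = scheduler.step(model_output, int(t), sample)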