|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from monai.utils import ComponentStore, unsqueeze_right |
|
|
|
|
|
NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") |
|
|
|
|
|
|
|
|
@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") |
|
|
def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): |
|
|
""" |
|
|
Linear beta noise schedule function. |
|
|
|
|
|
Args: |
|
|
num_train_timesteps: number of timesteps |
|
|
beta_start: start of beta range, default 1e-4 |
|
|
beta_end: end of beta range, default 2e-2 |
|
|
|
|
|
Returns: |
|
|
betas: beta schedule tensor |
|
|
""" |
|
|
return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) |
|
|
|
|
|
|
|
|
@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") |
|
|
def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): |
|
|
""" |
|
|
Scaled linear beta noise schedule function. |
|
|
|
|
|
Args: |
|
|
num_train_timesteps: number of timesteps |
|
|
beta_start: start of beta range, default 1e-4 |
|
|
beta_end: end of beta range, default 2e-2 |
|
|
|
|
|
Returns: |
|
|
betas: beta schedule tensor |
|
|
""" |
|
|
return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 |
|
|
|
|
|
|
|
|
@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") |
|
|
def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): |
|
|
""" |
|
|
Sigmoid beta noise schedule function. |
|
|
|
|
|
Args: |
|
|
num_train_timesteps: number of timesteps |
|
|
beta_start: start of beta range, default 1e-4 |
|
|
beta_end: end of beta range, default 2e-2 |
|
|
sig_range: pos/neg range of sigmoid input, default 6 |
|
|
|
|
|
Returns: |
|
|
betas: beta schedule tensor |
|
|
""" |
|
|
betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) |
|
|
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start |
|
|
|
|
|
|
|
|
@NoiseSchedules.add_def("cosine", "Cosine schedule") |
|
|
def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): |
|
|
""" |
|
|
Cosine noise schedule, see https://arxiv.org/abs/2102.09672 |
|
|
|
|
|
Args: |
|
|
num_train_timesteps: number of timesteps |
|
|
s: smoothing factor, default 8e-3 (see referenced paper) |
|
|
|
|
|
Returns: |
|
|
(betas, alphas, alpha_cumprod) values |
|
|
""" |
|
|
x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) |
|
|
alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 |
|
|
alphas_cumprod /= alphas_cumprod[0].item() |
|
|
betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) |
|
|
betas = torch.clip(betas, 0.0, 0.999) |
|
|
alphas = 1.0 - betas |
|
|
alphas_cumprod = torch.cumprod(alphas, dim=0) |
|
|
return betas, alphas, alphas_cumprod |
|
|
|
|
|
|
|
|
class Scheduler(nn.Module): |
|
|
""" |
|
|
Base class for other schedulers based on a noise schedule function. |
|
|
|
|
|
This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here |
|
|
the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, |
|
|
which is the name of a component in NoiseSchedules. These components must all be callables which return either |
|
|
the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions |
|
|
can be provided by using the NoiseSchedules.add_def, for example: |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
from monai.networks.schedulers import NoiseSchedules, DDPMScheduler |
|
|
|
|
|
@NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") |
|
|
def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): |
|
|
return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) |
|
|
|
|
|
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") |
|
|
|
|
|
All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of |
|
|
timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through |
|
|
the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules |
|
|
to get a listing of stored objects with their docstring descriptions. |
|
|
|
|
|
Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule |
|
|
type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended |
|
|
to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are |
|
|
still used for some schedules but these are provided as keyword arguments now. |
|
|
|
|
|
Args: |
|
|
num_train_timesteps: number of diffusion steps used to train the model. |
|
|
schedule: member of NoiseSchedules, |
|
|
a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple |
|
|
schedule_args: arguments to pass to the schedule function |
|
|
""" |
|
|
|
|
|
def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: |
|
|
super().__init__() |
|
|
schedule_args["num_train_timesteps"] = num_train_timesteps |
|
|
noise_sched = NoiseSchedules[schedule](**schedule_args) |
|
|
|
|
|
|
|
|
if isinstance(noise_sched, tuple): |
|
|
self.betas, self.alphas, self.alphas_cumprod = noise_sched |
|
|
else: |
|
|
self.betas = noise_sched |
|
|
self.alphas = 1.0 - self.betas |
|
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
|
|
|
|
|
self.num_train_timesteps = num_train_timesteps |
|
|
self.one = torch.tensor(1.0) |
|
|
|
|
|
|
|
|
self.num_inference_steps: int | None = None |
|
|
self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) |
|
|
|
|
|
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Add noise to the original samples. |
|
|
|
|
|
Args: |
|
|
original_samples: original samples |
|
|
noise: noise to add to samples |
|
|
timesteps: timesteps tensor indicating the timestep to be computed for each sample. |
|
|
|
|
|
Returns: |
|
|
noisy_samples: sample with added noise |
|
|
""" |
|
|
|
|
|
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) |
|
|
timesteps = timesteps.to(original_samples.device) |
|
|
|
|
|
sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) |
|
|
sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( |
|
|
(1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim |
|
|
) |
|
|
|
|
|
noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise |
|
|
return noisy_samples |
|
|
|
|
|
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) |
|
|
timesteps = timesteps.to(sample.device) |
|
|
|
|
|
sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) |
|
|
sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( |
|
|
(1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim |
|
|
) |
|
|
|
|
|
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample |
|
|
return velocity |
|
|
|