| """Diffusion sampling classes.""" | |
| from math import atan, cos, pi, sin, sqrt | |
| from typing import Any, Callable, List, Optional, Tuple, Type | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange, reduce | |
| from torch import Tensor | |
| from .utils import exists, default | |
class Distribution:
    """Base interface for noise-level (sigma) sampling distributions.

    Subclasses implement ``__call__`` to draw ``num_samples`` scalar noise
    levels on the requested device.
    """

    def __call__(self, num_samples: int, device: torch.device):
        # Abstract: concrete distributions must override this.
        raise NotImplementedError()
class LogNormalDistribution(Distribution):
    """Draws noise levels whose logarithm is normally distributed.

    Args:
        mean: Mean of the underlying normal (log-space).
        std: Standard deviation of the underlying normal (log-space).
    """

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        # Sample from N(mean, std^2), then exponentiate: exp(N) is log-normal.
        gaussian = torch.randn((num_samples,), device=device)
        return (self.mean + self.std * gaussian).exp()
class UniformDistribution(Distribution):
    """Draws noise levels uniformly from the interval [0, 1)."""

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ):
        samples = torch.rand(num_samples, device=device)
        return samples
def to_batch(
    batch_size: int,
    device: torch.device,
    x: Optional[float] = None,
    xs: Optional[Tensor] = None,
) -> Tensor:
    """Broadcast a scalar noise level to a batch, or pass a batch through.

    Exactly one of ``x`` (scalar) and ``xs`` (per-item tensor) must be given.

    Args:
        batch_size: Length of the returned 1-D tensor when broadcasting ``x``.
        device: Device the broadcast tensor is created on.
        x: Scalar value replicated across the batch.
        xs: Pre-built batch of values, returned unchanged.

    Returns:
        A tensor of shape ``(batch_size,)``.
    """
    assert (x is not None) ^ (xs is not None), "Either x or xs must be provided"
    if x is not None:
        # Allocate directly on the target device instead of building the tensor
        # on the default device and copying it over with `.to(device)`.
        xs = torch.full(size=(batch_size,), fill_value=x, device=device)
    assert xs is not None
    return xs
class Diffusion(nn.Module):
    """Abstract base class for diffusion training objectives.

    Subclasses implement ``denoise_fn`` (used by samplers at inference time)
    and ``forward`` (returns the training loss for a clean input).
    """

    # Short string identifier for the diffusion variant (e.g. "k").
    alias: str = ""

    def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor:
        """Denoise ``x_noisy`` at noise level ``sigmas`` (per-batch) or ``sigma`` (scalar)."""
        raise NotImplementedError("Diffusion class missing denoise_fn")

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Compute the training loss for clean input ``x`` (optionally with fixed ``noise``)."""
        raise NotImplementedError("Diffusion class missing forward function")
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022).

    Wraps ``net`` with the sigma-dependent preconditioning of the paper and
    trains it with a weighted denoising MSE objective.
    """

    alias = "k"

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution: Distribution,
        sigma_data: float,
        dynamic_threshold: float = 0.0,
    ):
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return the (c_skip, c_out, c_in, c_noise) preconditioners for ``sigmas``."""
        sd = self.sigma_data
        # Noise conditioning stays flat (b,); the rest broadcast over (b, 1, 1).
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        total_var = sigmas ** 2 + sd ** 2
        c_skip = (sd ** 2) / total_var
        c_out = sigmas * sd * total_var ** -0.5
        c_in = total_var ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor:
        """Preconditioned denoiser: scale input, run the network, mix with the skip path."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        net_out = self.net(c_in * x_noisy, c_noise, **kwargs)
        return c_skip * x_noisy + c_out * net_out

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        """Per-sample loss weighting lambda(sigma) from Karras et al. (2022)."""
        # Equals 1 / c_out(sigma)^2, which normalizes the effective loss scale.
        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Sample noise levels, corrupt ``x``, denoise, and return the weighted MSE loss."""
        batch_size, device = x.shape[0], x.device
        # One sigma per batch element, drawn from the training distribution.
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
        noise = default(noise, lambda: torch.randn_like(x))
        # Corrupt the clean signal at the sampled noise levels.
        x_noisy = x + sigmas_padded * noise
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
        # Per-sample MSE, averaged over all non-batch dimensions, then weighted.
        per_sample = reduce(F.mse_loss(x_denoised, x, reduction="none"), "b ... -> b", "mean")
        return (per_sample * self.loss_weight(sigmas)).mean()
class Schedule(nn.Module):
    """Interface for noise-level (sigma) schedules used at sampling time."""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        # Abstract: concrete schedules must override this.
        raise NotImplementedError()
class KarrasSchedule(Schedule):
    """Noise schedule of Karras et al. (2022), Eq. (5).

    Produces ``num_steps`` decreasing sigmas from ``sigma_max`` down to
    ``sigma_min`` by interpolating linearly in ``sigma ** (1/rho)`` space,
    with a trailing 0.0 appended so samplers can step onto clean data.

    Args:
        sigma_min: Smallest (final) nonzero noise level.
        sigma_max: Largest (initial) noise level.
        rho: Warping exponent; 7.0 is the paper's recommended value.
    """

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return a tensor of shape ``(num_steps + 1,)`` ending in 0.0."""
        rho_inv = 1.0 / self.rho
        # torch.linspace computes sigma_max^(1/rho) + t * (sigma_min^(1/rho)
        # - sigma_max^(1/rho)) for t in [0, 1], matching the original formula,
        # and also handles num_steps == 1 (returns [sigma_max]) where the
        # previous `steps / (num_steps - 1)` form divided by zero (NaN).
        sigmas = torch.linspace(
            self.sigma_max ** rho_inv,
            self.sigma_min ** rho_inv,
            num_steps,
            device=device,
            dtype=torch.float32,
        ) ** self.rho
        # Append sigma = 0 so the final sampler step lands on clean data.
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas
class Sampler(nn.Module):
    """Interface for diffusion samplers that integrate along a sigma schedule."""

    # Diffusion subclasses this sampler is compatible with.
    diffusion_types: List[Type[Diffusion]] = []

    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        """Run the sampler: start from pure ``noise`` and denoise via ``fn`` along ``sigmas``."""
        raise NotImplementedError()
class ADPM2Sampler(Sampler):
    """Ancestral DPM-Solver-2 sampler (midpoint method with noise re-injection).

    Each step performs two model evaluations — one at the current sigma and
    one at a midpoint sigma — then adds fresh noise scaled by ``sigma_up``
    for the ancestral component.
    """

    diffusion_types = [KDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        # Exponent controlling where the midpoint sigma lies between steps.
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Split the transition sigma -> sigma_next into (up, down, mid) levels."""
        r = self.rho
        # Ancestral split: sigma_up is re-injected noise, sigma_down is the
        # deterministic target, chosen so that down^2 + up^2 == sigma_next^2.
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        # Midpoint taken in sigma^(1/r) space.
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Advance ``x`` from noise level ``sigma`` to ``sigma_next`` (two model calls)."""
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Slope (derivative) at the current point.
        denoised = fn(x, sigma=sigma)
        slope = (x - denoised) / sigma
        # Half step to the midpoint, then re-evaluate the slope there.
        x_mid = x + slope * (sigma_mid - sigma)
        denoised_mid = fn(x_mid, sigma=sigma_mid)
        slope_mid = (x_mid - denoised_mid) / sigma_mid
        # Full step using the midpoint slope, then ancestral noise injection.
        x_down = x + slope_mid * (sigma_down - sigma)
        return x_down + torch.randn_like(x) * sigma_up

    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        """Integrate from pure noise down the schedule; returns the final sample."""
        # Scale unit-variance noise to the initial noise level.
        x = sigmas[0] * noise
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])
        return x
class DiffusionSampler(nn.Module):
    """Bundles a diffusion model, a sampler and a sigma schedule for inference.

    Args:
        diffusion: Trained diffusion object providing ``denoise_fn``.
        sampler: Sampling algorithm (e.g. ``ADPM2Sampler``).
        sigma_schedule: Noise schedule (e.g. ``KarrasSchedule``).
        num_steps: Default number of sampling steps; may be overridden per call.
        clamp: If True, clamp the generated sample into [-1, 1].
    """

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        sampler: Sampler,
        sigma_schedule: Schedule,
        num_steps: Optional[int] = None,
        clamp: bool = True,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

    def forward(self, noise: Tensor, num_steps: Optional[int] = None, **kwargs) -> Tensor:
        """Generate a sample from ``noise``; ``num_steps`` overrides the stored default."""
        device = noise.device
        num_steps = default(num_steps, self.num_steps)
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        sigmas = self.sigma_schedule(num_steps, device)

        # Close over extra kwargs (e.g. conditioning) so the sampler only has
        # to pass (x, sigma); per-call kwargs take precedence over sampler ones.
        def fn(*args, **call_kwargs):
            return self.denoise_fn(*args, **{**call_kwargs, **kwargs})

        x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        return x.clamp(-1.0, 1.0) if self.clamp else x