"""Diffusion sampling classes.""" from math import atan, cos, pi, sin, sqrt from typing import Any, Callable, List, Optional, Tuple, Type import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, reduce from torch import Tensor from .utils import exists, default class Distribution: def __call__(self, num_samples: int, device: torch.device): raise NotImplementedError() class LogNormalDistribution(Distribution): def __init__(self, mean: float, std: float): self.mean = mean self.std = std def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")) -> Tensor: normal = self.mean + self.std * torch.randn((num_samples,), device=device) return normal.exp() class UniformDistribution(Distribution): def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")): return torch.rand(num_samples, device=device) def to_batch(batch_size: int, device: torch.device, x: Optional[float] = None, xs: Optional[Tensor] = None) -> Tensor: assert exists(x) ^ exists(xs), "Either x or xs must be provided" if exists(x): xs = torch.full(size=(batch_size,), fill_value=x).to(device) assert exists(xs) return xs class Diffusion(nn.Module): alias: str = "" def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor: raise NotImplementedError("Diffusion class missing denoise_fn") def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: raise NotImplementedError("Diffusion class missing forward function") class KDiffusion(Diffusion): """Elucidated Diffusion (Karras et al. 2022)""" alias = "k" def __init__(self, net: nn.Module, *, sigma_distribution: Distribution, sigma_data: float, dynamic_threshold: float = 0.0): super().__init__() self.net = net self.sigma_data = sigma_data self.sigma_distribution = sigma_distribution self.dynamic_threshold = dynamic_threshold def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]: sigma_data = self.sigma_data c_noise = torch.log(sigmas) * 0.25 sigmas = rearrange(sigmas, "b -> b 1 1") c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2) c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5 c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5 return c_skip, c_out, c_in, c_noise def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor: batch_size, device = x_noisy.shape[0], x_noisy.device sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device) c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas) x_pred = self.net(c_in * x_noisy, c_noise, **kwargs) x_denoised = c_skip * x_noisy + c_out * x_pred return x_denoised def loss_weight(self, sigmas: Tensor) -> Tensor: return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2 def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor: batch_size, device = x.shape[0], x.device sigmas = self.sigma_distribution(num_samples=batch_size, device=device) sigmas_padded = rearrange(sigmas, "b -> b 1 1") noise = default(noise, lambda: torch.randn_like(x)) x_noisy = x + sigmas_padded * noise x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs) losses = F.mse_loss(x_denoised, x, reduction="none") losses = reduce(losses, "b ... -> b", "mean") losses = losses * self.loss_weight(sigmas) return losses.mean() class Schedule(nn.Module): def forward(self, num_steps: int, device: torch.device) -> Tensor: raise NotImplementedError() class KarrasSchedule(Schedule): def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0): super().__init__() self.sigma_min = sigma_min self.sigma_max = sigma_max self.rho = rho def forward(self, num_steps: int, device: Any) -> Tensor: rho_inv = 1.0 / self.rho steps = torch.arange(num_steps, device=device, dtype=torch.float32) sigmas = ( self.sigma_max ** rho_inv + (steps / (num_steps - 1)) * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv) ) ** self.rho sigmas = F.pad(sigmas, pad=(0, 1), value=0.0) return sigmas class Sampler(nn.Module): diffusion_types: List[Type[Diffusion]] = [] def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor: raise NotImplementedError() class ADPM2Sampler(Sampler): diffusion_types = [KDiffusion] def __init__(self, rho: float = 1.0): super().__init__() self.rho = rho def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]: r = self.rho sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2) sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r return sigma_up, sigma_down, sigma_mid def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor: sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next) d = (x - fn(x, sigma=sigma)) / sigma x_mid = x + d * (sigma_mid - sigma) d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid x = x + d_mid * (sigma_down - sigma) x_next = x + torch.randn_like(x) * sigma_up return x_next def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor: x = sigmas[0] * noise for i in range(num_steps - 1): x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) return x class DiffusionSampler(nn.Module): def __init__(self, diffusion: Diffusion, *, sampler: Sampler, sigma_schedule: Schedule, num_steps: Optional[int] = None, clamp: bool = True): super().__init__() self.denoise_fn = diffusion.denoise_fn self.sampler = sampler self.sigma_schedule = sigma_schedule self.num_steps = num_steps self.clamp = clamp def forward(self, noise: Tensor, num_steps: Optional[int] = None, **kwargs) -> Tensor: device = noise.device num_steps = default(num_steps, self.num_steps) assert exists(num_steps), "Parameter `num_steps` must be provided" sigmas = self.sigma_schedule(num_steps, device) fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps) x = x.clamp(-1.0, 1.0) if self.clamp else x return x