"""Diffusion sampling classes."""
from math import atan, cos, pi, sin, sqrt
from typing import Any, Callable, List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from torch import Tensor
from .utils import exists, default
class Distribution:
    """Interface for distributions over noise levels (sigmas)."""

    def __call__(self, num_samples: int, device: torch.device):
        # Subclasses must return a tensor of `num_samples` sigma values.
        raise NotImplementedError()
class LogNormalDistribution(Distribution):
    """Distribution where log(sigma) ~ Normal(mean, std)."""

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")) -> Tensor:
        # Draw standard normals, shift/scale into log-space, then exponentiate.
        gaussian = torch.randn((num_samples,), device=device)
        return torch.exp(self.mean + self.std * gaussian)
class UniformDistribution(Distribution):
    """Distribution uniform on [0, 1)."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        return torch.rand((num_samples,), device=device)
def to_batch(
    batch_size: int,
    device: torch.device,
    x: Optional[float] = None,
    xs: Optional[Tensor] = None,
) -> Tensor:
    """Broadcast a scalar `x` into a batch tensor, or pass through a batch `xs`.

    Exactly one of `x` (scalar) and `xs` (precomputed batch tensor) must be
    provided. Returns a tensor of shape (batch_size,) filled with `x` on
    `device` when `x` is given; returns `xs` unchanged otherwise.
    """
    assert (x is not None) ^ (xs is not None), "Either x or xs must be provided"
    if x is not None:
        # Allocate directly on the target device instead of creating the
        # tensor on CPU and then copying it with `.to(device)`.
        xs = torch.full(size=(batch_size,), fill_value=x, device=device)
    assert xs is not None
    return xs
class Diffusion(nn.Module):
    """Abstract base class for diffusion objectives.

    Subclasses implement `denoise_fn` (a single denoising call at given noise
    levels) and `forward` (training loss computation).
    """

    # Short identifier for this diffusion type (e.g. "k").
    alias: str = ""

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Denoise `x_noisy` at per-item `sigmas` or shared scalar `sigma`."""
        raise NotImplementedError("Diffusion class missing denoise_fn")

    # Fixed annotation: the default is None, so the parameter is Optional.
    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Compute the training loss for a clean batch `x`."""
        raise NotImplementedError("Diffusion class missing forward function")
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022)

    EDM-style training objective: the raw network is wrapped with the
    preconditioning scales (c_skip, c_out, c_in, c_noise) derived from the
    noise level `sigmas` and the data standard deviation `sigma_data`.
    """
    alias = "k"
    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution, sigma_data: float, dynamic_threshold: float = 0.0):
        # net: denoising network, called as net(c_in * x_noisy, c_noise, **kwargs).
        # sigma_distribution: distribution of training noise levels.
        # sigma_data: estimated standard deviation of the data.
        # dynamic_threshold: stored but not read anywhere in this block —
        # presumably consumed elsewhere; TODO confirm.
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        self.dynamic_threshold = dynamic_threshold
    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return EDM preconditioning weights (c_skip, c_out, c_in, c_noise).

        `sigmas` has shape (b,); all but c_noise are reshaped to (b, 1, 1) so
        they broadcast over the two trailing dimensions of the input batch.
        """
        sigma_data = self.sigma_data
        # Noise-level conditioning passed to the network: log(sigma) / 4.
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
        c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
        return c_skip, c_out, c_in, c_noise
    def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor:
        """Denoise `x_noisy` at per-item `sigmas` or a shared scalar `sigma`."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        # Normalize either input form to a (batch_size,) tensor of sigmas.
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        # Preconditioned denoiser: D(x) = c_skip * x + c_out * F(c_in * x).
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred
        return x_denoised
    def loss_weight(self, sigmas: Tensor) -> Tensor:
        # EDM loss weighting lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.
        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Return the scalar training loss for a clean batch `x`.

        If `noise` is not given, standard Gaussian noise shaped like `x` is drawn.
        """
        batch_size, device = x.shape[0], x.device
        # One noise level per batch item.
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
        # Per-item MSE averaged over all non-batch dims, then weighted by lambda(sigma).
        losses = F.mse_loss(x_denoised, x, reduction="none")
        losses = reduce(losses, "b ... -> b", "mean")
        losses = losses * self.loss_weight(sigmas)
        return losses.mean()
class Schedule(nn.Module):
    """Interface for noise-level schedules used at sampling time."""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        # Subclasses return a tensor of sigmas for `num_steps` sampling steps.
        raise NotImplementedError()
class KarrasSchedule(Schedule):
    """Noise schedule from Karras et al. 2022.

    Produces `num_steps` sigmas decreasing from `sigma_max` to `sigma_min`,
    spaced uniformly in sigma^(1/rho), with a final 0.0 appended.
    """

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return a tensor of `num_steps` + 1 sigmas; the last entry is 0.0."""
        rho_inv = 1.0 / self.rho
        if num_steps == 1:
            # Guard: the interpolation below divides by (num_steps - 1) and
            # would produce NaN. A single step starts at sigma_max.
            sigmas = torch.full((1,), self.sigma_max, device=device, dtype=torch.float32)
        else:
            steps = torch.arange(num_steps, device=device, dtype=torch.float32)
            # Interpolate linearly in sigma^(1/rho) space, then undo the power.
            sigmas = (
                self.sigma_max ** rho_inv
                + (steps / (num_steps - 1))
                * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
            ) ** self.rho
        # Append sigma = 0 so samplers can step to the clean signal.
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas
class Sampler(nn.Module):
    """Interface for diffusion samplers."""

    # Diffusion classes this sampler is compatible with.
    diffusion_types: List[Type[Diffusion]] = []

    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        # Subclasses turn pure `noise` into samples using the denoiser `fn`.
        raise NotImplementedError()
class ADPM2Sampler(Sampler):
    """Ancestral DPM2 sampler.

    Each step from `sigma` to `sigma_next` is split into a deterministic
    two-stage (midpoint) update down to `sigma_down`, followed by
    re-injection of fresh noise at scale `sigma_up`.
    """
    diffusion_types = [KDiffusion]
    def __init__(self, rho: float = 1.0):
        # rho controls how the midpoint sigma is averaged (power-mean exponent).
        super().__init__()
        self.rho = rho
    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Split one step into (sigma_up, sigma_down, sigma_mid).

        The split preserves total variance: sigma_up^2 + sigma_down^2 ==
        sigma_next^2 by construction. sigma_mid is the rho-power mean of
        sigma and sigma_down. NOTE(review): at run time these arguments are
        0-dim tensors (`sigmas[i]`); math.sqrt coerces its result to a
        Python float while the `**`-based sigma_mid stays a tensor.
        """
        r = self.rho
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid
    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Advance `x` one ancestral step from noise level `sigma` to `sigma_next`."""
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative dx/dsigma estimated from the denoised prediction at sigma.
        d = (x - fn(x, sigma=sigma)) / sigma
        # Midpoint evaluation of the derivative.
        x_mid = x + d * (sigma_mid - sigma)
        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        # Deterministic update to sigma_down using the midpoint derivative...
        x = x + d_mid * (sigma_down - sigma)
        # ...then add back sigma_up noise (the ancestral part).
        x_next = x + torch.randn_like(x) * sigma_up
        return x_next
    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        """Run the full schedule starting from unit Gaussian `noise`."""
        # Scale unit noise up to the initial noise level sigmas[0].
        x = sigmas[0] * noise
        # num_steps - 1 transitions between consecutive schedule entries.
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])
        return x
class DiffusionSampler(nn.Module):
    """Generates samples by driving a `Sampler` with a diffusion's denoiser."""

    def __init__(self, diffusion: Diffusion, *, sampler: Sampler, sigma_schedule: Schedule, num_steps: Optional[int] = None, clamp: bool = True):
        super().__init__()
        # Keep only the denoise function — the full diffusion object is not needed.
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

    def forward(self, noise: Tensor, num_steps: Optional[int] = None, **kwargs) -> Tensor:
        """Sample starting from `noise`; extra kwargs are forwarded to the denoiser."""
        num_steps = default(num_steps, self.num_steps)
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        sigmas = self.sigma_schedule(num_steps, noise.device)

        # Bind the caller's kwargs into every denoiser invocation; they take
        # precedence over per-call keyword arguments with the same name.
        def denoise(*args, **call_kwargs):
            return self.denoise_fn(*args, **{**call_kwargs, **kwargs})

        samples = self.sampler(noise, fn=denoise, sigmas=sigmas, num_steps=num_steps)
        if self.clamp:
            samples = samples.clamp(-1.0, 1.0)
        return samples