|
|
from math import pi
|
|
|
from typing import Any, Optional, Tuple
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from einops import rearrange, repeat
|
|
|
from torch import Tensor
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from .utils import default
|
|
|
|
|
|
""" Distributions """
|
|
|
|
|
|
|
|
|
class Distribution:
    """Common interface for the distributions defined in this module.

    A distribution is a callable object: invoking it with a sample count and
    a device is expected to yield that many samples. This base class only
    fixes the call signature; concrete subclasses provide the sampling.
    """

    def __call__(self, num_samples: int, device: torch.device):
        """Draw ``num_samples`` samples on ``device``.

        Raises:
            NotImplementedError: always, since subclasses must override this.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
class UniformDistribution(Distribution):
    """Uniform distribution obtained by rescaling ``torch.rand`` samples.

    Args:
        vmin: Value mapped to a raw sample of 0.
        vmax: Value approached as the raw sample tends to 1.
    """

    def __init__(self, vmin: float = 0.0, vmax: float = 1.0):
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        """Return a 1-D tensor of ``num_samples`` values created on ``device``.

        Raw samples from ``torch.rand`` (uniform on [0, 1)) are scaled by
        ``vmax - vmin`` and then shifted by ``vmin``.
        """
        width = self.vmax - self.vmin
        unit_samples = torch.rand(num_samples, device=device)
        return width * unit_samples + self.vmin
|
|
|
|
|
|
|
|
|
""" Diffusion Methods """
|
|
|
|
|
|
|
|
|
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Return a view of ``x`` with ``ndim`` extra size-1 axes on the right.

    Args:
        x: Tensor to pad.
        ndim: Number of trailing singleton dimensions to append. A value of
            zero (or less) appends nothing.

    Returns:
        A view of ``x`` with shape ``(*x.shape, 1, ..., 1)``.
    """
    trailing_ones = (1,) * ndim
    padded_shape = (*x.shape, *trailing_ones)
    return x.view(padded_shape)
|
|
|
|
|
|
|
|
|
def clip(x: Tensor, dynamic_threshold: float = 0.0):
    """Clamp ``x`` using either a fixed or a per-item dynamic bound.

    Args:
        x: Tensor whose first axis indexes batch items.
        dynamic_threshold: When exactly ``0.0``, values are clamped to
            ``[-1, 1]``. Otherwise it is used as the quantile ``q`` passed to
            ``torch.quantile`` to derive one bound per batch item.

    Returns:
        The clamped tensor. In the dynamic case the result is additionally
        divided by the per-item bound.
    """
    # Static thresholding: plain clamp to the unit range.
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)

    # Dynamic thresholding: one quantile of the absolute values per batch item.
    magnitudes = rearrange(x, "b ... -> b (...)").abs()
    bound = torch.quantile(magnitudes, dynamic_threshold, dim=-1)
    # Never let the bound drop below 1.0.
    bound = bound.clamp(min=1.0)
    # Right-pad the bound with singleton axes so it broadcasts against x.
    bound = pad_dims(bound, ndim=x.ndim - bound.ndim)
    return x.clamp(-bound, bound) / bound
|
|
|
|
|
|
|
|
|
def extend_dim(x: Tensor, dim: int):
    """Return a view of ``x`` right-padded with size-1 axes up to ``dim`` dims.

    Args:
        x: Tensor to extend.
        dim: Target number of dimensions. If ``x`` already has at least this
            many dimensions, its shape is left unchanged.

    Returns:
        A view of ``x`` with ``dim - x.ndim`` trailing singleton dimensions.
    """
    missing = dim - x.ndim
    extended_shape = tuple(x.shape) + (1,) * missing
    return x.view(extended_shape)
|
|
|
|
|
|
|
|
|
class Diffusion(nn.Module):
    """Base class (interface) shared by the diffusion methods in this module.

    It adds nothing on top of ``nn.Module``; it exists so diffusion methods
    have a common parent type.
    """
|
|
|
|
|
|
|
|
|
class VDiffusion(Diffusion):
    """Training wrapper that scores ``net`` against the target ``v``.

    For each batch item a noise level ``sigma`` is drawn, the input is mixed
    with Gaussian noise using ``alpha = cos(sigma * pi / 2)`` and
    ``beta = sin(sigma * pi / 2)``, and the network is asked to predict
    ``v = alpha * noise - beta * x``.

    Args:
        net: Network called as ``net(x_noisy, sigmas, **kwargs)``.
        sigma_distribution: Callable producing one noise level per batch item.
        loss_fn: Callable comparing the prediction with the target.
    """

    def __init__(
        self,
        net: nn.Module,
        sigma_distribution: Distribution = UniformDistribution(),
        loss_fn: Any = F.mse_loss,
    ):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution
        self.loss_fn = loss_fn

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of the angle ``sigmas * pi / 2``."""
        theta = sigmas * pi / 2
        return torch.cos(theta), torch.sin(theta)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Compute the loss between the predicted and the target ``v``.

        Args:
            x: Batch of clean inputs; the first axis is the batch axis.
            **kwargs: Forwarded unchanged to ``net``.

        Returns:
            The value of ``loss_fn(v_pred, v_target)``.
        """
        num_items = x.shape[0]

        # One noise level per batch item, plus a broadcastable copy of it.
        sigmas = self.sigma_distribution(num_samples=num_items, device=x.device)
        sigmas_broadcast = extend_dim(sigmas, dim=x.ndim)

        # Gaussian noise with the same shape, dtype and device as the input.
        eps = torch.randn_like(x)

        # Mix input and noise, and build the matching target.
        alphas, betas = self.get_alpha_beta(sigmas_broadcast)
        x_noisy = alphas * x + betas * eps
        v_target = alphas * eps - betas * x

        # The network receives the un-broadcast (per-item) noise levels.
        v_pred = self.net(x_noisy, sigmas, **kwargs)
        return self.loss_fn(v_pred, v_target)
|
|
|
|
|
|
|
|
|
class ARVDiffusion(Diffusion):
    """Diffusion loss on ``v`` with an independent noise level per split.

    The last axis of the input (of size ``length``) is divided into
    ``num_splits`` equal splits; every split of every batch item gets its own
    noise level. The noise levels are fed to the network as an extra channel.

    Args:
        net: Network called as ``net(channels, **kwargs)``.
        length: Required size of the last input axis.
        num_splits: Number of equal splits; must divide ``length``.
        loss_fn: Callable comparing the prediction with the target.
    """

    def __init__(self, net: nn.Module, length: int, num_splits: int, loss_fn: Any = F.mse_loss):
        super().__init__()
        assert length % num_splits == 0, "length must be divisible by num_splits"
        self.net = net
        self.length = length
        self.num_splits = num_splits
        self.split_length = length // num_splits
        self.loss_fn = loss_fn

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of the angle ``sigmas * pi / 2``."""
        theta = sigmas * pi / 2
        return torch.cos(theta), torch.sin(theta)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Return the loss on ``v`` using a different noise level per split.

        Args:
            x: 3-D input whose last axis has size ``length``.
            **kwargs: Forwarded unchanged to ``net``.

        Returns:
            The value of ``loss_fn(v_pred, v_target)``.
        """
        batch, _, steps = x.shape
        assert steps == self.length, "input length must match length"

        # One noise level per (item, split), expanded to cover each split.
        split_sigmas = torch.rand(
            (batch, 1, self.num_splits), device=x.device, dtype=x.dtype
        )
        sigmas = repeat(split_sigmas, "b 1 n -> b 1 (n l)", l=self.split_length)

        # Gaussian noise with the same shape, dtype and device as the input.
        eps = torch.randn_like(x)

        # Mix input and noise, and build the matching target.
        alphas, betas = self.get_alpha_beta(sigmas)
        x_noisy = alphas * x + betas * eps
        v_target = alphas * eps - betas * x

        # The noise levels travel with the noisy input as one extra channel.
        net_input = torch.cat([x_noisy, sigmas], dim=1)
        v_pred = self.net(net_input, **kwargs)
        return self.loss_fn(v_pred, v_target)
|
|
|
|
|
|
""" Schedules """
|
|
|
|
|
|
|
|
|
class Schedule(nn.Module):
    """Common interface for the sampling schedules defined in this module.

    A schedule is a module whose ``forward`` takes a number of steps and a
    device and returns a tensor. This base class only fixes the signature.
    """

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        """Build the schedule for ``num_steps`` steps on ``device``.

        Raises:
            NotImplementedError: always, since subclasses must override this.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
class LinearSchedule(Schedule):
    """Schedule of evenly spaced values running from ``start`` to ``end``.

    Args:
        start: First value of the schedule.
        end: Last value of the schedule.
    """

    def __init__(self, start: float = 1.0, end: float = 0.0):
        super().__init__()
        self.start = start
        self.end = end

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return ``num_steps`` evenly spaced values created on ``device``."""
        return torch.linspace(
            start=self.start, end=self.end, steps=num_steps, device=device
        )
|
|
|
|
|
|
|
|
|
""" Samplers """
|
|
|
|
|
|
|
|
|
class Sampler(nn.Module):
    """Base class shared by the samplers in this module.

    It adds nothing on top of ``nn.Module``; it exists so samplers have a
    common parent type.
    """
|
|
|
|
|
|
|
|
|
class VSampler(Sampler):
    """Iterative sampler for networks that predict ``v``.

    At every step the network's prediction is converted into an estimate of
    the clean signal and of the noise, which are then re-mixed at the next
    noise level of the schedule.

    Args:
        net: Network called as ``net(x_noisy, sigmas, **kwargs)``.
        schedule: Module called as ``schedule(num_values, device=...)`` that
            returns the sequence of noise levels to walk through.
    """

    # Diffusion classes this sampler is declared to work with.
    diffusion_types = [VDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of the angle ``sigmas * pi / 2``."""
        theta = sigmas * pi / 2
        return torch.cos(theta), torch.sin(theta)

    @torch.no_grad()
    def forward(
        self,
        x_noisy: Tensor,
        num_steps: int,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        """Run ``num_steps`` sampling steps starting from ``x_noisy``.

        Args:
            x_noisy: Starting tensor; the first axis is the batch axis.
            num_steps: Number of network evaluations to perform.
            show_progress: Whether to display a tqdm progress bar.
            **kwargs: Forwarded unchanged to ``net`` at every step.

        Returns:
            The tensor obtained after the last step.
        """
        batch = x_noisy.shape[0]

        # num_steps transitions need num_steps + 1 noise levels.
        levels = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(levels, "i -> i b", b=batch)

        # Coefficients broadcastable against x_noisy, indexed by step.
        alphas, betas = self.get_alpha_beta(
            extend_dim(sigmas, dim=x_noisy.ndim + 1)
        )

        progress_bar = tqdm(range(num_steps), disable=not show_progress)
        for i in progress_bar:
            nxt = i + 1
            v_pred = self.net(x_noisy, sigmas[i], **kwargs)
            # Estimates of the clean signal and of the noise at the current level.
            x_pred = alphas[i] * x_noisy - betas[i] * v_pred
            noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
            # Re-mix both estimates at the next noise level.
            x_noisy = alphas[nxt] * x_pred + betas[nxt] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})")

        return x_noisy
|
|
|
|
|
|
|
|
|
class ARVSampler(Sampler):
    """Sampler that extends a sequence one split at a time.

    An initial window of ``length`` positions is denoised with a uniform
    noise level; afterwards a "ladder" of per-position noise levels is applied
    repeatedly to the last ``num_splits`` chunks while a fresh noise chunk is
    appended after each pass.

    Args:
        net: Network called as ``net(channels, **kwargs)``, where ``channels``
            is the current window concatenated with its noise levels on
            dim 1. It must own at least one parameter, since the sampler's
            device is read from the first one.
        in_channels: Number of channels of the generated tensors.
        length: Size of the window along the last axis.
        num_splits: Number of equal chunks per window; must divide ``length``.
    """

    def __init__(self, net: nn.Module, in_channels: int, length: int, num_splits: int):
        super().__init__()
        assert length % num_splits == 0, "length must be divisible by num_splits"
        self.length = length
        self.in_channels = in_channels
        self.num_splits = num_splits
        self.split_length = length // num_splits
        self.net = net

    @property
    def device(self):
        """Device of the first parameter of ``net``."""
        return next(self.net.parameters()).device

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of the angle ``sigmas * pi / 2``."""
        angle = sigmas * pi / 2
        alpha = torch.cos(angle)
        beta = torch.sin(angle)
        return alpha, beta

    def get_sigmas_ladder(self, num_items: int, num_steps_per_split: int) -> Tensor:
        """Build the per-step, per-position noise levels for one ladder pass.

        Args:
            num_items: Batch size ``b``.
            num_steps_per_split: Number of steps ``i`` spent on each rung.

        Returns:
            Tensor of shape ``(i + 1, b, 1, 2 * (num_splits // 2) * split_length)``.
            The first half of the last axis is all zeros; the second half
            holds the ladder.
        """
        b, n, l, i = num_items, self.num_splits, self.split_length, num_steps_per_split
        # Only half of the splits carry noise; the other half stays at zero.
        # NOTE(review): the returned width equals `length` only when
        # num_splits is even - confirm odd values are never used.
        n_half = n // 2
        sigmas = torch.linspace(1, 0, i * n_half, device=self.device)
        # Step index first; each rung's level is repeated over its l positions.
        sigmas = repeat(sigmas, "(n i) -> i b 1 (n l)", b=b, l=l, n=n_half)
        # Reverse the last axis so the lowest noise levels come first.
        sigmas = torch.flip(sigmas, dims=[-1])
        # Append one extra (zero-filled) step at the end of dim 0.
        sigmas = F.pad(sigmas, pad=[0, 0, 0, 0, 0, 0, 0, 1])
        # The extra step is the first step shifted right by one split.
        sigmas[-1, :, :, l:] = sigmas[0, :, :, :-l]
        return torch.cat([torch.zeros_like(sigmas), sigmas], dim=-1)

    def sample_loop(
        self, current: Tensor, sigmas: Tensor, show_progress: bool = False, **kwargs
    ) -> Tensor:
        """Walk ``current`` through the noise levels in ``sigmas``.

        Args:
            current: Window to denoise.
            sigmas: Noise levels indexed by step along dim 0; ``sigmas[i]`` is
                concatenated to ``current`` on dim 1 as network input.
            show_progress: Whether to display a tqdm progress bar.
            **kwargs: Forwarded unchanged to ``net`` at every step.

        Returns:
            The window after ``sigmas.shape[0] - 1`` steps.
        """
        num_steps = sigmas.shape[0] - 1
        alphas, betas = self.get_alpha_beta(sigmas)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)

        for i in progress_bar:
            channels = torch.cat([current, sigmas[i]], dim=1)
            v_pred = self.net(channels, **kwargs)
            # Estimates of the clean signal and of the noise at the current level.
            x_pred = alphas[i] * current - betas[i] * v_pred
            noise_pred = betas[i] * current + alphas[i] * v_pred
            # Re-mix both estimates at the next noise level.
            current = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0,0,0]:.2f})")

        return current

    def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor:
        """Sample an initial window of shape ``(num_items, in_channels, length)``.

        All positions share the same noise level at each step, going linearly
        from 1 to 0 over ``num_steps`` steps.
        """
        b, c, t = num_items, self.in_channels, self.length
        # Same sigma schedule over all positions of the window.
        sigmas = torch.linspace(1, 0, num_steps + 1, device=self.device)
        sigmas = repeat(sigmas, "i -> i b 1 t", b=b, t=t)
        noise = torch.randn((b, c, t), device=self.device) * sigmas[0]
        # Sample the start window.
        return self.sample_loop(current=noise, sigmas=sigmas, **kwargs)

    @torch.no_grad()
    def forward(
        self,
        num_items: int,
        num_chunks: int,
        num_steps: int,
        start: Optional[Tensor] = None,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        """Generate ``num_chunks`` chunks of ``split_length`` positions each.

        Args:
            num_items: Batch size.
            num_chunks: Number of chunks to return; at least ``num_splits``.
            num_steps: Steps used for the initial window; the ladder uses
                ``num_steps // num_splits`` steps per rung.
            start: Optional initial window with the same shape that
                ``sample_start`` produces, i.e.
                ``(num_items, in_channels, length)``. When omitted, it is
                sampled with ``sample_start``.
            show_progress: Whether to display a tqdm bar over ladder passes.
            **kwargs: Forwarded unchanged to ``net``.

        Returns:
            ``start`` itself when ``num_chunks == num_splits``; otherwise the
            first ``num_chunks`` chunks concatenated along the last axis.
        """
        assert_message = f"required at least {self.num_splits} chunks"
        assert num_chunks >= self.num_splits, assert_message

        # Sample the initial window only when the caller did not supply one
        # (previously a supplied `start` was silently overwritten).
        if start is None:
            start = self.sample_start(num_items=num_items, num_steps=num_steps, **kwargs)

        # Nothing to extend if only num_splits chunks are requested.
        if num_chunks == self.num_splits:
            return start

        # Noise levels for the autoregressive ladder.
        b, n = num_items, self.num_splits
        assert num_steps >= n, "num_steps must be greater than num_splits"
        sigmas = self.get_sigmas_ladder(
            num_items=b,
            num_steps_per_split=num_steps // self.num_splits,
        )
        alphas, betas = self.get_alpha_beta(sigmas)

        # Re-noise the start window to match the first ladder step, then
        # split it into the initial chunks.
        start_noise = alphas[0] * start + betas[0] * torch.randn_like(start)
        chunks = list(start_noise.chunk(chunks=n, dim=-1))

        # One ladder pass per requested chunk.
        num_shifts = num_chunks
        progress_bar = tqdm(range(num_shifts), disable=not show_progress)

        for j in progress_bar:
            # Lower the noise of the last n chunks by one ladder pass.
            updated = self.sample_loop(
                current=torch.cat(chunks[-n:], dim=-1), sigmas=sigmas, **kwargs
            )
            # Write the updated chunks back.
            chunks[-n:] = list(updated.chunk(chunks=n, dim=-1))
            # Append a fresh pure-noise chunk for the next pass.
            shape = (b, self.in_channels, self.split_length)
            chunks += [torch.randn(shape, device=self.device)]

        return torch.cat(chunks[:num_chunks], dim=-1)
|
|
|
|
|
|
|
|
|
""" Inpainters """
|
|
|
|
|
|
|
|
|
class Inpainter(nn.Module):
    """Base class shared by the inpainters in this module.

    It adds nothing on top of ``nn.Module``; it exists so inpainters have a
    common parent type.
    """
|
|
|
|
|
|
|
|
|
class VInpainter(Inpainter):
    """Inpainter for networks that predict ``v``.

    At every noise level the sample is denoised ``num_resamples`` times.
    After each pass the positions selected by ``mask`` are replaced with a
    freshly noised copy of ``source``, while the remaining positions keep the
    network-driven sample.

    Args:
        net: Network called as ``net(x_noisy, sigmas, **kwargs)``.
        schedule: Module called as ``schedule(num_values, device=...)`` that
            returns the sequence of noise levels to walk through.
    """

    # Diffusion classes this inpainter is declared to work with.
    diffusion_types = [VDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` of the angle ``sigmas * pi / 2``."""
        theta = sigmas * pi / 2
        return torch.cos(theta), torch.sin(theta)

    @torch.no_grad()
    def forward(
        self,
        source: Tensor,
        mask: Tensor,
        num_steps: int,
        num_resamples: int,
        show_progress: bool = False,
        x_noisy: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """Fill the positions of ``source`` that ``mask`` does not select.

        Args:
            source: Reference tensor supplying the known content.
            mask: Tensor combined as ``s_noisy * mask + x_noisy * ~mask``; it
                must support ``~`` (presumably a boolean tensor - verify
                against callers).
            num_steps: Number of noise-level transitions.
            num_resamples: Number of network passes per noise level.
            show_progress: Whether to display a tqdm progress bar.
            x_noisy: Optional starting tensor; when not provided, Gaussian
                noise shaped like ``source`` is used (resolved through the
                ``default`` helper).
            **kwargs: Forwarded unchanged to ``net`` at every pass.

        Returns:
            The tensor obtained after the last step.
        """
        x_noisy = default(x_noisy, lambda: torch.randn_like(source))
        batch = x_noisy.shape[0]

        # num_steps transitions need num_steps + 1 noise levels.
        levels = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(levels, "i -> i b", b=batch)

        # Coefficients broadcastable against x_noisy, indexed by step.
        alphas, betas = self.get_alpha_beta(
            extend_dim(sigmas, dim=x_noisy.ndim + 1)
        )

        progress_bar = tqdm(range(num_steps), disable=not show_progress)
        final_pass = num_resamples - 1

        for i in progress_bar:
            for r in range(num_resamples):
                v_pred = self.net(x_noisy, sigmas[i], **kwargs)
                # Estimates of the clean signal and of the noise at level i.
                x_pred = alphas[i] * x_noisy - betas[i] * v_pred
                noise_pred = betas[i] * x_noisy + alphas[i] * v_pred

                # Stay on level i while resampling; advance on the last pass.
                level = i + 1 if r == final_pass else i
                x_noisy = alphas[level] * x_pred + betas[level] * noise_pred
                # Bring the source to the same level with fresh noise.
                fresh_noise = torch.randn_like(source)
                s_noisy = alphas[level] * source + betas[level] * fresh_noise
                # Masked positions come from the source, the rest from the sample.
                x_noisy = s_noisy * mask + x_noisy * ~mask

            progress_bar.set_description(f"Inpainting (noise={sigmas[i+1,0]:.2f})")

        return x_noisy
|
|
|
|