# Hugging Face upload-page residue (kept for provenance):
# uploader "Seemanth" — commit message "Upload Chiluka TTS model" — commit f28049f (verified)
"""Diffusion sampling classes."""
from math import atan, cos, pi, sin, sqrt
from typing import Any, Callable, List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from torch import Tensor
from .utils import exists, default
class Distribution:
    """Abstract base for noise-level (sigma) sampling distributions.

    Subclasses implement ``__call__`` to draw ``num_samples`` sigma values
    on ``device`` (see :class:`LogNormalDistribution` and
    :class:`UniformDistribution`).
    """

    def __call__(self, num_samples: int, device: torch.device):
        # Subclasses must return a 1-D tensor of `num_samples` sigma values.
        raise NotImplementedError()
class LogNormalDistribution(Distribution):
    """Sigma distribution whose logarithm is normally distributed."""

    def __init__(self, mean: float, std: float):
        # Mean and std of the underlying normal, i.e. of log(sigma).
        self.mean = mean
        self.std = std

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")) -> Tensor:
        """Draw `num_samples` log-normal sigmas on `device`."""
        gaussian = torch.randn((num_samples,), device=device)
        return torch.exp(self.mean + self.std * gaussian)
class UniformDistribution(Distribution):
    """Sigma distribution drawing values uniformly from [0, 1)."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        samples = torch.rand(num_samples, device=device)
        return samples
def to_batch(batch_size: int, device: torch.device, x: Optional[float] = None, xs: Optional[Tensor] = None) -> Tensor:
    """Return a per-batch-item tensor of values on `device`.

    Exactly one of `x` and `xs` must be provided:
      * `x` — a scalar, broadcast to a tensor of shape `(batch_size,)`.
      * `xs` — an already-batched tensor, returned unchanged.

    Raises:
        AssertionError: if neither or both of `x` / `xs` are given.
    """
    assert exists(x) ^ exists(xs), "Either x or xs must be provided"
    if exists(x):
        # Allocate directly on the target device instead of building on the
        # CPU and copying with `.to(device)`.
        xs = torch.full(size=(batch_size,), fill_value=x, device=device)
    assert exists(xs)
    return xs
class Diffusion(nn.Module):
    """Abstract base for diffusion training objectives.

    Subclasses (e.g. ``KDiffusion``) implement ``denoise_fn`` for sampling
    and ``forward`` for computing a training loss.
    """

    # Short identifier for the diffusion variant (e.g. "k" for KDiffusion).
    alias: str = ""

    def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor:
        # Subclasses return the denoised estimate of `x_noisy` given either
        # per-item `sigmas` or a single scalar `sigma`.
        raise NotImplementedError("Diffusion class missing denoise_fn")

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        # Subclasses return a training loss for the clean batch `x`;
        # `noise` optionally overrides the randomly drawn noise.
        raise NotImplementedError("Diffusion class missing forward function")
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022).

    Wraps `net` with the EDM preconditioning (c_skip / c_out / c_in / c_noise)
    and trains it with the sigma-weighted denoising MSE objective.
    """

    alias = "k"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution, sigma_data: float, dynamic_threshold: float = 0.0):
        """
        Args:
            net: denoiser network, called as ``net(x, noise_cond, **kwargs)``.
            sigma_distribution: distribution of training noise levels.
            sigma_data: estimated std of the data distribution (EDM).
            dynamic_threshold: stored for samplers; not used in this class's
                visible code — TODO confirm consumers.
        """
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return EDM preconditioning weights (c_skip, c_out, c_in, c_noise)."""
        sigma_data = self.sigma_data
        # Noise conditioning: log(sigma) / 4, per Karras et al. 2022.
        c_noise = torch.log(sigmas) * 0.25
        # Broadcast per-item sigmas over (channel, length) dims.
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
        c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor:
        """Denoise `x_noisy` at noise level `sigmas` (per item) or `sigma` (scalar).

        Exactly one of `sigmas` / `sigma` must be provided (see `to_batch`).
        """
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        # Preconditioned denoiser: skip connection plus scaled network output.
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred
        return x_denoised

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        # EDM loss weighting lambda(sigma): equalizes the effective loss
        # magnitude across noise levels.
        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Compute the mean sigma-weighted denoising MSE loss for batch `x`.

        Args:
            x: clean input batch; assumed shape (batch, channels, length) —
                TODO confirm, implied by the "b -> b 1 1" broadcast.
            noise: optional pre-drawn noise with the same shape as `x`
                (fixed annotation: was `Tensor = None`).
        """
        batch_size, device = x.shape[0], x.device
        # Sample one noise level per batch item and corrupt the input.
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
        # Per-item MSE, then weighted by lambda(sigma) and averaged.
        losses = F.mse_loss(x_denoised, x, reduction="none")
        losses = reduce(losses, "b ... -> b", "mean")
        losses = losses * self.loss_weight(sigmas)
        return losses.mean()
class Schedule(nn.Module):
    """Abstract base for sampling noise schedules.

    Subclasses return the sequence of sigmas a sampler steps through.
    """

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        # Subclasses return a 1-D tensor of sigmas for `num_steps` steps.
        raise NotImplementedError()
class KarrasSchedule(Schedule):
    """Karras et al. 2022 sigma schedule, from sigma_max down to sigma_min."""

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # rho controls how densely steps cluster near sigma_min.
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return `num_steps` sigmas followed by a trailing zero."""
        rho_inv = 1.0 / self.rho
        max_inv = self.sigma_max ** rho_inv
        min_inv = self.sigma_min ** rho_inv
        # Interpolate linearly in sigma^(1/rho) space, then invert.
        t = torch.arange(num_steps, device=device, dtype=torch.float32)
        sigmas = (max_inv + (t / (num_steps - 1)) * (min_inv - max_inv)) ** self.rho
        # Append sigma = 0 as the final target noise level.
        return F.pad(sigmas, pad=(0, 1), value=0.0)
class Sampler(nn.Module):
    """Abstract base for diffusion samplers."""

    # Diffusion classes this sampler is compatible with.
    diffusion_types: List[Type[Diffusion]] = []

    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        # Subclasses run `num_steps` of sampling, starting from `noise` and
        # calling the denoiser `fn` along the `sigmas` schedule.
        raise NotImplementedError()
class ADPM2Sampler(Sampler):
    """Ancestral DPM-2 style sampler: midpoint update plus noise re-injection."""

    diffusion_types = [KDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        # Exponent used when locating the midpoint sigma.
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Split one step into ancestral (up), deterministic (down) and midpoint sigmas."""
        r = self.rho
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """Advance `x` from `sigma` to `sigma_next` with one midpoint update."""
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative at the current point, from the denoiser estimate.
        slope = (x - fn(x, sigma=sigma)) / sigma
        x_mid = x + slope * (sigma_mid - sigma)
        # Re-evaluate the derivative at the midpoint for the full step.
        slope_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        x = x + slope_mid * (sigma_down - sigma)
        # Ancestral part: add fresh noise scaled by sigma_up.
        return x + torch.randn_like(x) * sigma_up

    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        """Sample by stepping through consecutive sigma pairs of the schedule."""
        x = sigmas[0] * noise
        for sigma, sigma_next in zip(sigmas[: num_steps - 1], sigmas[1:num_steps]):
            x = self.step(x, fn=fn, sigma=sigma, sigma_next=sigma_next)
        return x
class DiffusionSampler(nn.Module):
    """Draws samples from a diffusion model via a sampler and sigma schedule."""

    def __init__(self, diffusion: Diffusion, *, sampler: Sampler, sigma_schedule: Schedule, num_steps: Optional[int] = None, clamp: bool = True):
        super().__init__()
        # Only the denoising function of the diffusion object is needed.
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

    def forward(self, noise: Tensor, num_steps: Optional[int] = None, **kwargs) -> Tensor:
        """Run the sampler from `noise` down the sigma schedule.

        `num_steps` falls back to the value given at construction time.
        """
        steps = default(num_steps, self.num_steps)
        assert exists(steps), "Parameter `num_steps` must be provided"
        schedule = self.sigma_schedule(steps, noise.device)

        def denoise(*args, **call_kwargs):
            # Forward the fixed conditioning kwargs alongside per-call ones.
            return self.denoise_fn(*args, **{**call_kwargs, **kwargs})

        sample = self.sampler(noise, fn=denoise, sigmas=schedule, num_steps=steps)
        if self.clamp:
            sample = sample.clamp(-1.0, 1.0)
        return sample