| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from statistics import NormalDist |
| | from typing import Callable, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from torch import nn |
| | from tqdm import tqdm |
| |
|
| |
|
| | class EDMScaling: |
| | def __init__(self, sigma_data: float = 0.5): |
| | self.sigma_data = sigma_data |
| |
|
| | def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) |
| | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 |
| | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 |
| | c_noise = 0.25 * sigma.log() |
| | return c_skip, c_out, c_in, c_noise |
| |
|
| |
|
| | class EDMSDE: |
| | def __init__( |
| | self, |
| | p_mean: float = -1.2, |
| | p_std: float = 1.2, |
| | sigma_max: float = 80.0, |
| | sigma_min: float = 0.002, |
| | ): |
| | self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) |
| | self.sigma_max = sigma_max |
| | self.sigma_min = sigma_min |
| | self._generator = np.random |
| |
|
| | def sample_t(self, batch_size: int) -> torch.Tensor: |
| | cdf_vals = self._generator.uniform(size=(batch_size)) |
| | samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] |
| | log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") |
| | return torch.exp(log_sigma) |
| |
|
| | def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
| | return x0, sigma |
| |
|
| |
|
| | class EDMSampler(nn.Module): |
| | """ |
| | Elucidating the Design Space of Diffusion-Based Generative Models (EDM) |
| | # https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/generate.py#L25 |
| | |
| | Attributes: |
| | None |
| | |
| | Methods: |
| | forward(x0_fn: Callable, x_sigma_max: torch.Tensor, num_steps: int = 35, sigma_min: float = 0.002, |
| | sigma_max: float = 80, rho: float = 7, S_churn: float = 0, S_min: float = 0, |
| | S_max: float = float("inf"), S_noise: float = 1) -> torch.Tensor: |
| | Performs the forward pass for the EDM sampling process. |
| | |
| | Parameters: |
| | x0_fn (Callable): A function that takes in a tensor and returns a denoised tensor. |
| | x_sigma_max (torch.Tensor): The initial noise level tensor. |
| | num_steps (int, optional): The number of sampling steps. Default is 35. |
| | sigma_min (float, optional): The minimum noise level. Default is 0.002. |
| | sigma_max (float, optional): The maximum noise level. Default is 80. |
| | rho (float, optional): The rho parameter for time step discretization. Default is 7. |
| | S_churn (float, optional): The churn parameter for noise increase. Default is 0. |
| | S_min (float, optional): The minimum value for the churn parameter. Default is 0. |
| | S_max (float, optional): The maximum value for the churn parameter. Default is float("inf"). |
| | S_noise (float, optional): The noise scale for the churn parameter. Default is 1. |
| | |
| | Returns: |
| | torch.Tensor: The sampled tensor after the EDM process. |
| | """ |
| |
|
| | @torch.no_grad() |
| | def forward( |
| | self, |
| | x0_fn: Callable, |
| | x_sigma_max: torch.Tensor, |
| | num_steps: int = 35, |
| | sigma_min: float = 0.002, |
| | sigma_max: float = 80, |
| | rho: float = 7, |
| | S_churn: float = 0, |
| | S_min: float = 0, |
| | S_max: float = float("inf"), |
| | S_noise: float = 1, |
| | ) -> torch.Tensor: |
| | |
| | in_dtype = x_sigma_max.dtype |
| | _ones = torch.ones(x_sigma_max.shape[0], dtype=in_dtype, device=x_sigma_max.device) |
| | step_indices = torch.arange(num_steps, dtype=torch.float64, device=x_sigma_max.device) |
| | t_steps = ( |
| | sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) |
| | ) ** rho |
| | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) |
| |
|
| | |
| | x_next = x_sigma_max.to(torch.float64) |
| | for i, (t_cur, t_next) in enumerate( |
| | tqdm(zip(t_steps[:-1], t_steps[1:], strict=False), total=len(t_steps) - 1) |
| | ): |
| | x_cur = x_next |
| |
|
| | |
| | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 |
| | t_hat = t_cur + gamma * t_cur |
| | x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur) |
| |
|
| | |
| | denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64) |
| | d_cur = (x_hat - denoised) / t_hat |
| | x_next = x_hat + (t_next - t_hat) * d_cur |
| |
|
| | |
| | if i < num_steps - 1: |
| | denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64) |
| | d_prime = (x_next - denoised) / t_next |
| | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) |
| |
|
| | return x_next.to(in_dtype) |
| |
|