# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from statistics import NormalDist
from typing import Callable, Tuple
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
class EDMScaling:
    """Compute the four EDM preconditioning coefficients for a noise level.

    Args:
        sigma_data: Standard deviation constant used in every coefficient
            except ``c_noise``. Defaults to 0.5.
    """

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(c_skip, c_out, c_in, c_noise)`` evaluated element-wise at ``sigma``.

        Args:
            sigma: Noise levels. ``log`` is taken, so entries are expected to be
                strictly positive for a finite ``c_noise``.

        Returns:
            A 4-tuple of tensors, each with the same shape as ``sigma``.
        """
        data_std = self.sigma_data
        # sigma^2 + sigma_data^2 and its square root are shared by three coefficients.
        total_var = sigma**2 + data_std**2
        total_std = total_var**0.5

        c_skip = data_std**2 / total_var
        c_out = sigma * data_std / total_std
        c_in = 1 / total_std
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise
class EDMSDE:
    """Draw training noise levels whose logarithm is normally distributed.

    Args:
        p_mean: Mean of the normal distribution over ``log(sigma)``.
        p_std: Standard deviation of the normal distribution over ``log(sigma)``.
        sigma_max: Stored on the instance; not used by the methods of this class.
        sigma_min: Stored on the instance; not used by the methods of this class.
    """

    def __init__(
        self,
        p_mean: float = -1.2,
        p_std: float = 1.2,
        sigma_max: float = 80.0,
        sigma_min: float = 0.002,
    ):
        self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std)
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        # Source of uniform samples; any object exposing ``uniform(size=...)``
        # (e.g. ``np.random.RandomState``) can be assigned for reproducibility.
        self._generator = np.random

    def sample_t(self, batch_size: int, device: "str | torch.device" = "cuda") -> torch.Tensor:
        """Sample ``batch_size`` noise levels via inverse-CDF sampling.

        Uniform draws are mapped through the inverse CDF of the configured
        normal distribution to get ``log(sigma)``, which is then exponentiated.

        Args:
            batch_size: Number of noise levels to draw.
            device: Device on which the result is created. Defaults to
                ``"cuda"``, the device this method previously hard-coded.

        Returns:
            A 1-D tensor of ``batch_size`` strictly positive noise levels.
        """
        cdf_vals = self._generator.uniform(size=(batch_size))
        log_sigma = torch.tensor(
            [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals],
            device=device,
        )
        return torch.exp(log_sigma)

    def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x0, sigma)`` unchanged as the marginal's mean and std."""
        return x0, sigma
class EDMSampler(nn.Module):
    """Sampler from "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM).

    Each step optionally adds noise ("churn"), takes an Euler step towards the
    next noise level, and then applies a 2nd-order (Heun) correction on every
    step except the final one to ``sigma = 0``.

    Reference implementation:
    https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/generate.py#L25
    """

    @torch.no_grad()
    def forward(
        self,
        x0_fn: Callable,
        x_sigma_max: torch.Tensor,
        num_steps: int = 35,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        S_churn: float = 0,
        S_min: float = 0,
        S_max: float = float("inf"),
        S_noise: float = 1,
    ) -> torch.Tensor:
        """Run the EDM sampling loop and return the final sample.

        Args:
            x0_fn: Denoiser called as ``x0_fn(x, sigma)``, where ``x`` has the dtype
                of ``x_sigma_max`` and ``sigma`` is a 1-D tensor with one entry per
                batch element. Must return the denoised estimate of ``x``.
            x_sigma_max: Sample the loop starts from; its first dimension is taken
                as the batch size. Presumably noise at scale ``sigma_max`` (confirm
                against the caller).
            num_steps: Number of noise levels in the schedule before the final 0.
            sigma_min: Lowest non-zero noise level of the schedule.
            sigma_max: Highest noise level of the schedule.
            rho: Exponent shaping the spacing of the noise levels.
            S_churn: Total amount of churn; the per-step noise increase is
                ``min(S_churn / num_steps, sqrt(2) - 1)``.
            S_min: Churn is applied only when the current noise level is >= ``S_min``.
            S_max: Churn is applied only when the current noise level is <= ``S_max``.
            S_noise: Scale of the noise injected by churn.

        Returns:
            The sample after the last step, cast back to the dtype of ``x_sigma_max``.
        """
        in_dtype = x_sigma_max.dtype
        device = x_sigma_max.device
        # Broadcast helper: turns a scalar noise level into one value per batch element.
        _ones = torch.ones(x_sigma_max.shape[0], dtype=in_dtype, device=device)

        # Time step discretization (computed in float64 for accuracy).
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
        # Guard the divisor so num_steps == 1 gives the schedule [sigma_max, 0]
        # instead of NaN from 0 / 0.
        denom = max(num_steps - 1, 1)
        t_steps = (
            sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

        # Main sampling loop.
        x_next = x_sigma_max.to(torch.float64)
        for i, (t_cur, t_next) in enumerate(
            tqdm(zip(t_steps[:-1], t_steps[1:], strict=False), total=len(t_steps) - 1)
        ):  # 0, ..., N-1
            x_cur = x_next

            # Increase noise temporarily.
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur)

            # Euler step.
            denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Apply 2nd order correction. Skipped on the last step, where t_next == 0
            # would divide by zero.
            if i < num_steps - 1:
                # Heun's method needs the slope at the Euler-predicted point
                # (x_next, t_next), not at (x_hat, t_hat) again.
                denoised = x0_fn(x_next.to(in_dtype), t_next.to(in_dtype) * _ones).to(torch.float64)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        return x_next.to(in_dtype)
|