Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from statistics import NormalDist
from typing import Callable, Tuple
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
class EDMScaling:
    """Compute the four EDM preconditioning coefficients for a noise level.

    ``sigma_data`` is the constant that every coefficient is expressed
    against (presumably the standard deviation of the training data, as in
    the EDM paper -- confirm against the training configuration).
    """

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(c_skip, c_out, c_in, c_noise)`` evaluated at ``sigma``.

        Args:
            sigma: Noise level(s). Every result is element-wise in ``sigma``
                and therefore has the same shape.

        Returns:
            A 4-tuple of tensors:
            ``c_skip  = sigma_data^2 / (sigma^2 + sigma_data^2)``,
            ``c_out   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)``,
            ``c_in    = 1 / sqrt(sigma^2 + sigma_data^2)``,
            ``c_noise = 0.25 * ln(sigma)``.
        """
        data_var = self.sigma_data**2
        # Variance of (data + noise); shared by three of the coefficients.
        total_var = sigma**2 + data_var
        total_std = total_var**0.5

        skip_scale = data_var / total_var
        out_scale = sigma * self.sigma_data / total_std
        in_scale = 1 / total_std
        noise_cond = 0.25 * sigma.log()
        return skip_scale, out_scale, in_scale, noise_cond
class EDMSDE:
def __init__(
self,
p_mean: float = -1.2,
p_std: float = 1.2,
sigma_max: float = 80.0,
sigma_min: float = 0.002,
):
self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std)
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self._generator = np.random
def sample_t(self, batch_size: int) -> torch.Tensor:
cdf_vals = self._generator.uniform(size=(batch_size))
samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals]
log_sigma = torch.tensor(samples_interval_gaussian, device="cuda")
return torch.exp(log_sigma)
def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return x0, sigma
class EDMSampler(nn.Module):
"""
Elucidating the Design Space of Diffusion-Based Generative Models (EDM)
# https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/generate.py#L25
Attributes:
None
Methods:
forward(x0_fn: Callable, x_sigma_max: torch.Tensor, num_steps: int = 35, sigma_min: float = 0.002,
sigma_max: float = 80, rho: float = 7, S_churn: float = 0, S_min: float = 0,
S_max: float = float("inf"), S_noise: float = 1) -> torch.Tensor:
Performs the forward pass for the EDM sampling process.
Parameters:
x0_fn (Callable): A function that takes in a tensor and returns a denoised tensor.
x_sigma_max (torch.Tensor): The initial noise level tensor.
num_steps (int, optional): The number of sampling steps. Default is 35.
sigma_min (float, optional): The minimum noise level. Default is 0.002.
sigma_max (float, optional): The maximum noise level. Default is 80.
rho (float, optional): The rho parameter for time step discretization. Default is 7.
S_churn (float, optional): The churn parameter for noise increase. Default is 0.
S_min (float, optional): The minimum value for the churn parameter. Default is 0.
S_max (float, optional): The maximum value for the churn parameter. Default is float("inf").
S_noise (float, optional): The noise scale for the churn parameter. Default is 1.
Returns:
torch.Tensor: The sampled tensor after the EDM process.
"""
@torch.no_grad()
def forward(
self,
x0_fn: Callable,
x_sigma_max: torch.Tensor,
num_steps: int = 35,
sigma_min: float = 0.002,
sigma_max: float = 80,
rho: float = 7,
S_churn: float = 0,
S_min: float = 0,
S_max: float = float("inf"),
S_noise: float = 1,
) -> torch.Tensor:
# Time step discretization.
in_dtype = x_sigma_max.dtype
_ones = torch.ones(x_sigma_max.shape[0], dtype=in_dtype, device=x_sigma_max.device)
step_indices = torch.arange(num_steps, dtype=torch.float64, device=x_sigma_max.device)
t_steps = (
sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
) ** rho
t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
# Main sampling loop.
x_next = x_sigma_max.to(torch.float64)
for i, (t_cur, t_next) in enumerate(
tqdm(zip(t_steps[:-1], t_steps[1:], strict=False), total=len(t_steps) - 1)
): # 0, ..., N-1
x_cur = x_next
# Increase noise temporarily.
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
t_hat = t_cur + gamma * t_cur
x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur)
# Euler step.
denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64)
d_cur = (x_hat - denoised) / t_hat
x_next = x_hat + (t_next - t_hat) * d_cur
# Apply 2nd order correction.
if i < num_steps - 1:
denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64)
d_prime = (x_next - denoised) / t_next
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
return x_next.to(in_dtype)