File size: 5,725 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from statistics import NormalDist
from typing import Callable, Tuple

import numpy as np
import torch
from torch import nn
from tqdm import tqdm


class EDMScaling:
    """Compute the four EDM scaling coefficients for a given noise level.

    For a noise level ``sigma`` and a data standard deviation ``sigma_data``
    the coefficients are::

        c_skip  = sigma_data^2 / (sigma^2 + sigma_data^2)
        c_out   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
        c_in    = 1 / sqrt(sigma^2 + sigma_data^2)
        c_noise = 0.25 * log(sigma)

    Args:
        sigma_data: Standard deviation used in the formulas above.
    """

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(c_skip, c_out, c_in, c_noise)`` evaluated element-wise at ``sigma``.

        Args:
            sigma: Tensor of noise levels; every output has the same shape.

        Returns:
            The tuple ``(c_skip, c_out, c_in, c_noise)``.
        """
        data_std = self.sigma_data
        # Combined variance of data and noise; shared by three coefficients.
        total_var = sigma**2 + data_std**2
        total_std = total_var**0.5

        skip_coeff = data_std**2 / total_var
        out_coeff = sigma * data_std / total_std
        in_coeff = 1 / total_std
        noise_coeff = 0.25 * sigma.log()
        return skip_coeff, out_coeff, in_coeff, noise_coeff


class EDMSDE:
    """Noise-level sampler whose ``log(sigma)`` follows a normal distribution.

    Args:
        p_mean: Mean of the normal distribution over ``log(sigma)``.
        p_std: Standard deviation of the normal distribution over ``log(sigma)``.
        sigma_max: Stored on the instance as ``self.sigma_max``; not used by
            the methods of this class.
        sigma_min: Stored on the instance as ``self.sigma_min``; not used by
            the methods of this class.
    """

    def __init__(
        self,
        p_mean: float = -1.2,
        p_std: float = 1.2,
        sigma_max: float = 80.0,
        sigma_min: float = 0.002,
    ):
        self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std)
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        # Source of uniform randomness; defaults to NumPy's global RNG so that
        # ``np.random.seed`` controls the samples.
        self._generator = np.random

    def sample_t(self, batch_size: int, device: "str | torch.device" = "cuda") -> torch.Tensor:
        """Draw ``batch_size`` noise levels ``sigma`` with ``log(sigma) ~ N(p_mean, p_std^2)``.

        Sampling is done by inverse-CDF: uniform draws are mapped through the
        normal quantile function and then exponentiated.

        Args:
            batch_size: Number of noise levels to draw.
            device: Device on which the returned tensor is created. Defaults
                to ``"cuda"`` (the previously hard-coded value); pass
                ``"cpu"`` or another device to sample elsewhere.

        Returns:
            A 1-D tensor of ``batch_size`` strictly positive noise levels.
        """
        # NOTE: ``uniform`` draws from [0, 1); ``inv_cdf`` rejects exactly 0.0,
        # which is possible but vanishingly unlikely.
        cdf_vals = self._generator.uniform(size=batch_size)
        samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals]
        log_sigma = torch.tensor(samples_interval_gaussian, device=device)
        return torch.exp(log_sigma)

    def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x0, sigma)`` unchanged.

        Args:
            x0: Clean sample tensor.
            sigma: Noise-level tensor.

        Returns:
            The pair ``(x0, sigma)``, i.e. the inputs passed through as-is.
        """
        return x0, sigma


class EDMSampler(nn.Module):
    """
    Sampler from "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM).
    # https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/generate.py#L25

    Each step optionally raises the noise level ("churn"), takes an Euler step
    towards the next noise level and, on every step except the last, refines it
    with a 2nd-order correction that averages the slope at the start of the
    step with the slope at the Euler-predicted point.

    Attributes:
        None

    Methods:
        forward(x0_fn: Callable, x_sigma_max: torch.Tensor, num_steps: int = 35, sigma_min: float = 0.002,
                sigma_max: float = 80, rho: float = 7, S_churn: float = 0, S_min: float = 0,
                S_max: float = float("inf"), S_noise: float = 1) -> torch.Tensor:
            Runs the full sampling loop and returns the final sample.
    """

    @torch.no_grad()
    def forward(
        self,
        x0_fn: Callable,
        x_sigma_max: torch.Tensor,
        num_steps: int = 35,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        S_churn: float = 0,
        S_min: float = 0,
        S_max: float = float("inf"),
        S_noise: float = 1,
    ) -> torch.Tensor:
        """Run the EDM sampling loop starting from ``x_sigma_max``.

        Parameters:
            x0_fn (Callable): Denoiser. Called as ``x0_fn(x, sigma)`` where ``x`` has the
                shape and dtype of ``x_sigma_max`` and ``sigma`` is a 1-D tensor of length
                ``x_sigma_max.shape[0]`` holding the current noise level; must return the
                denoised estimate of ``x``.
            x_sigma_max (torch.Tensor): The initial noisy sample, i.e. the state at noise
                level ``sigma_max``. Its first dimension is treated as the batch dimension.
            num_steps (int, optional): The number of sampling steps. Default is 35.
            sigma_min (float, optional): The lowest non-zero noise level of the schedule.
                Default is 0.002.
            sigma_max (float, optional): The highest noise level of the schedule. Default is 80.
            rho (float, optional): Exponent controlling the spacing of the noise schedule.
                Default is 7.
            S_churn (float, optional): Total amount of churn; the per-step relative noise
                increase is ``min(S_churn / num_steps, sqrt(2) - 1)``. Default is 0.
            S_min (float, optional): Churn is only applied when the current noise level is
                at least this value. Default is 0.
            S_max (float, optional): Churn is only applied when the current noise level is
                at most this value. Default is float("inf").
            S_noise (float, optional): Scale of the fresh noise added during churn. Default is 1.

        Returns:
            torch.Tensor: The final sample, cast back to the dtype of ``x_sigma_max``.
        """
        in_dtype = x_sigma_max.dtype
        device = x_sigma_max.device
        # Broadcast helper: turns the scalar noise level into one value per batch element.
        _ones = torch.ones(x_sigma_max.shape[0], dtype=in_dtype, device=device)

        # Time step discretization.
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
        # With a single step the schedule is just [sigma_max]; guard the divisor so the
        # interpolation does not evaluate 0 / 0 and yield a NaN noise level.
        last_index = max(num_steps - 1, 1)
        t_steps = (
            sigma_max ** (1 / rho) + step_indices / last_index * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

        # Main sampling loop; all arithmetic is done in float64 and cast back at the end.
        x_next = x_sigma_max.to(torch.float64)
        for i, (t_cur, t_next) in enumerate(
            tqdm(zip(t_steps[:-1], t_steps[1:], strict=False), total=len(t_steps) - 1)
        ):  # 0, ..., N-1
            x_cur = x_next

            # Increase noise temporarily.
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur)

            # Euler step.
            denoised = x0_fn(x_hat.to(in_dtype), t_hat.to(in_dtype) * _ones).to(torch.float64)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Apply 2nd order correction. Skipped on the final step, where t_next == 0
            # and the slope at the end point would require dividing by zero.
            if i < num_steps - 1:
                # The slope must be evaluated at the Euler-predicted point (x_next, t_next);
                # re-evaluating at (x_hat, t_hat) would merely repeat the call above.
                denoised = x0_fn(x_next.to(in_dtype), t_next.to(in_dtype) * _ones).to(torch.float64)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        return x_next.to(in_dtype)