File size: 6,203 Bytes
97923d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import math
from collections.abc import Callable
from typing import Any, Optional

import torch
from tqdm import trange

from lib_es.compat import to_d
from modules import errors

import lib_es.const as consts
from lib_es.utils import sampler_metadata


def compute_optimal_gamma(steps: int, adaptive: bool = True) -> float:
    """
    Compute the optimal gamma parameter for gradient estimation based on step count.

    Gamma scales logarithmically from 1.5 (at <= 10 steps) to 2.6 (at >= 100 steps);
    outside that range the value is clamped to the endpoint.

    Args:
        steps: Number of sampling steps
        adaptive: Whether to use adaptive gamma based on step count

    Returns:
        Optimal gamma value
    """
    if not adaptive:
        return consts.GE_DEFAULT_GAMMA

    # Calibrated endpoints of the step-count -> gamma mapping
    min_steps, max_steps = 10, 100
    min_gamma, max_gamma = 1.5, 2.6

    # Clamp at the edges of the calibrated range
    if steps <= min_steps:
        return min_gamma
    if steps >= max_steps:
        return max_gamma

    # Logarithmic interpolation: log(steps/min_steps) / log(max_steps/min_steps)
    # maps steps to [0, 1], increasing logarithmically with steps.
    # math.log on plain floats replaces the original torch.tensor round-trip,
    # which allocated throwaway tensors and computed in float32.
    log_factor = math.log(steps / min_steps) / math.log(max_steps / min_steps)

    return min_gamma + log_factor * (max_gamma - min_gamma)


def validate_schedule(sigmas: torch.Tensor, eta: float = 0.1, nu: float = 2.0) -> bool:
    """
    Validate whether a noise schedule satisfies the admissibility criteria from the paper.

    Args:
        sigmas: Tensor of noise levels in descending order
        eta: Error parameter
        nu: Accuracy parameter for distance estimates

    Returns:
        True if schedule is admissible, False otherwise

    Raises:
        ValueError: If fewer than 2 noise levels are provided.
    """
    n = len(sigmas) - 1
    if n < 1:
        # nu ** (-1 / n) below would divide by zero; fail with a clear message instead.
        raise ValueError("Need at least 2 noise levels to validate a schedule")

    is_admissible = True
    issues = []

    # Check if sigmas are strictly decreasing
    if not torch.all(sigmas[:-1] > sigmas[1:]):
        is_admissible = False
        issues.append("Sigmas must be strictly decreasing")

    # Calculate the maximum allowable per-step relative decrease (beta)
    c = 1 - nu ** (-1 / n)
    beta_max = c / (eta + c)

    # Check every one of the n transitions. (The previous range(n - 1) skipped
    # the final transition sigmas[n-1] -> sigmas[n].)
    for i in range(n):
        ratio = sigmas[i + 1] / sigmas[i]
        beta = 1 - ratio
        if beta > beta_max:
            is_admissible = False
            issues.append(f"Step {i} has beta {beta:.4f} > beta_max {beta_max:.4f}")

    if not is_admissible:
        errors.display(ValueError(f"Noise schedule is not admissible: {', '.join(issues)}"))
        errors.print_error_explanation("Noise schedule validation failed.\n\tIssues:" + ",\n\t\t".join(issues))

    return is_admissible


@torch.no_grad()
@sampler_metadata("Gradient Estimation", {"scheduler": "sgm_uniform"})
def sample_gradient_estimation(
    model,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    extra_args: Optional[dict[str, Any]] = None,
    callback: Optional[Callable] = None,
    disable: Optional[bool] = None,
    validate_sigmas: bool = False,
    eta: float = 0.1,
    nu: float = 2.0,
) -> torch.Tensor:
    """
    Gradient-estimation sampler from "Interpreting and Improving Diffusion Models from an Optimization Perspective".

    A second-order refinement of DDIM: each update blends the current denoising
    gradient with the previous one to reduce gradient-estimation error. The
    method views denoising as (approximate) projection onto the data manifold,
    so sampling becomes gradient descent on the squared Euclidean distance.

    Args:
        model: The diffusion model
        x: Input tensor
        sigmas: Noise schedule (should be in descending order)
        extra_args: Extra arguments to pass to the model
        callback: Callback function
        disable: Whether to disable the progress bar
        validate_sigmas: Whether to validate the noise schedule
        eta: Error parameter for schedule validation (default 0.1)
        nu: Accuracy parameter for schedule validation (default 2.0)

    Returns:
        Denoised tensor

    Raises:
        ValueError: If fewer than 2 timesteps are supplied.

    References:
        Paper: https://openreview.net/pdf?id=o2ND9v0CeK
    """
    if sigmas.shape[0] < 2:
        raise ValueError("Need at least 2 timesteps for gradient estimation")

    if extra_args is None:
        extra_args = {}

    # Per-sample broadcast vector for the sigma passed to the model
    sigma_ones = x.new_ones([x.shape[0]])
    num_steps = len(sigmas) - 1
    prev_grad = None

    # Optional admissibility check of the schedule (warn-only; result unused)
    if validate_sigmas:
        validate_schedule(sigmas, eta, nu)

    # Resolve gamma from model properties, or derive it from the step count
    adaptive_gamma: bool = getattr(model.p, consts.GE_USE_ADAPTIVE_STEPS, True)
    if adaptive_gamma:
        # Optimal gamma for this step count, shifted by the configured offset
        gamma_offset = getattr(model.p, consts.GE_GAMMA_OFFSET, consts.GE_DEFAULT_GAMMA_OFFSET)
        base_gamma = compute_optimal_gamma(num_steps, adaptive_gamma) + gamma_offset
    else:
        base_gamma = getattr(model.p, consts.GE_GAMMA, consts.GE_DEFAULT_GAMMA)

    # Optional per-timestep gamma ramp: start 20% above base, end 20% below —
    # early steps benefit more from aggressive gradient correction.
    per_step_gamma = getattr(model.p, consts.GE_USE_TIMESTEP_ADAPTIVE_GAMMA, False)
    gamma_schedule = torch.linspace(base_gamma * 1.2, base_gamma * 0.8, num_steps) if per_step_gamma else None

    # Main sampling loop
    for i in trange(num_steps, disable=disable):
        sigma = sigmas[i]
        denoised = model(x, sigma * sigma_ones, **extra_args)
        grad = to_d(x, sigma, denoised)

        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})

        dt = sigmas[i + 1] - sigma

        if prev_grad is None:
            # First step has no history: plain Euler update
            x = x + grad * dt
        else:
            # Blend current and previous gradients for the corrected direction
            gamma = gamma_schedule[i].item() if per_step_gamma else base_gamma
            blended = gamma * grad + (1 - gamma) * prev_grad
            x = x + blended * dt

        prev_grad = grad

    return x