|
|
from collections.abc import Callable
import math
from typing import Any, Optional

import torch
from tqdm import trange

from k_diffusion.sampling import to_d
from modules import errors

import lib_es.const as consts
from lib_es.utils import sampler_metadata
|
|
def compute_optimal_gamma(steps: int, adaptive: bool = True) -> float:
    """
    Compute the optimal gamma parameter for gradient estimation based on the step count.

    Args:
        steps: Number of sampling steps
        adaptive: Whether to scale gamma with the step count

    Returns:
        Optimal gamma value
    """
    if not adaptive:
        return consts.GE_DEFAULT_GAMMA

    # Interpolation bounds: gamma ramps from 1.5 at 10 steps up to 2.6 at 100 steps.
    min_steps, max_steps = 10, 100
    min_gamma, max_gamma = 1.5, 2.6

    if steps <= min_steps:
        return min_gamma
    if steps >= max_steps:
        return max_gamma

    # Logarithmic interpolation between the bounds; this is scalar math, so no
    # tensors are needed.
    log_factor = math.log(steps / min_steps) / math.log(max_steps / min_steps)
    gamma = min_gamma + log_factor * (max_gamma - min_gamma)
    return float(gamma)
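
# A quick sanity check of the adaptive gamma above (illustrative values only,
# given the 10..100 step and 1.5..2.6 gamma bounds in compute_optimal_gamma):
#   compute_optimal_gamma(10)  -> 1.5    (clamped at the lower bound)
#   compute_optimal_gamma(30)  -> ~2.02  (1.5 + log(3) / log(10) * 1.1)
#   compute_optimal_gamma(100) -> 2.6    (clamped at the upper bound)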
|
|
def validate_schedule(sigmas: torch.Tensor, eta: float = 0.1, nu: float = 2.0) -> bool:
    """
    Validate whether a noise schedule satisfies the admissibility criteria from the paper.

    Args:
        sigmas: Tensor of noise levels in descending order
        eta: Error parameter
        nu: Accuracy parameter for distance estimates

    Returns:
        True if the schedule is admissible, False otherwise
    """
    n = len(sigmas) - 1
    is_admissible = True
    issues = []

    # The schedule must be strictly decreasing.
    if not torch.all(sigmas[:-1] > sigmas[1:]):
        is_admissible = False
        issues.append("Sigmas must be strictly decreasing")

    # Largest admissible relative decrease per step: beta_max = c / (eta + c),
    # with c = 1 - nu^(-1/n).
    c = 1 - nu ** (-1 / n)
    beta_max = c / (eta + c)

    # Each step's relative decrease beta_i = 1 - sigma_{i+1} / sigma_i must not
    # exceed beta_max. The final transition is skipped: schedules typically end
    # at sigma = 0, where beta would always be 1.
    for i in range(n - 1):
        ratio = sigmas[i + 1] / sigmas[i]
        beta = 1 - ratio
        if beta > beta_max:
            is_admissible = False
            issues.append(f"Step {i} has beta {beta:.4f} > beta_max {beta_max:.4f}")

    if not is_admissible:
        errors.display(ValueError(f"Noise schedule is not admissible: {', '.join(issues)}"), "validating noise schedule")
        errors.print_error_explanation("Noise schedule validation failed.\n\tIssues:\n\t\t" + ",\n\t\t".join(issues))

    return is_admissible
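
# Worked example of the admissibility bound above (illustrative, with the
# default eta=0.1, nu=2.0 and a 20-step schedule, so n=20):
#   c        = 1 - 2 ** (-1 / 20) ≈ 0.0341
#   beta_max = c / (0.1 + c)      ≈ 0.254
# i.e. each sigma may shrink by at most ~25% per step for the schedule to pass.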
|
|
@torch.no_grad()
@sampler_metadata("Gradient Estimation", {"scheduler": "sgm_uniform"})
def sample_gradient_estimation(
    model,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    extra_args: Optional[dict[str, Any]] = None,
    callback: Optional[Callable] = None,
    disable: Optional[bool] = None,
    validate_sigmas: bool = False,
    eta: float = 0.1,
    nu: float = 2.0,
) -> torch.Tensor:
    """
    Gradient-estimation sampler as described in "Interpreting and Improving Diffusion Models from an Optimization Perspective".

    This sampler implements a second-order method that improves upon DDIM by combining the current and previous
    gradients to reduce gradient-estimation error. It builds on the insight that denoising is approximately a
    projection onto the data manifold, so diffusion sampling amounts to gradient descent on the squared Euclidean
    distance to that manifold.

    Args:
        model: The diffusion model
        x: Input tensor
        sigmas: Noise schedule (must be in descending order)
        extra_args: Extra arguments to pass to the model
        callback: Callback function invoked once per step
        disable: Whether to disable the progress bar
        validate_sigmas: Whether to validate the noise schedule
        eta: Error parameter for schedule validation (default 0.1)
        nu: Accuracy parameter for schedule validation (default 2.0)

    Returns:
        Denoised tensor

    References:
        Paper: https://openreview.net/pdf?id=o2ND9v0CeK
    """
    if sigmas.shape[0] < 2:
        raise ValueError("Need at least 2 timesteps for gradient estimation")

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_d = None
    steps = len(sigmas) - 1

    if validate_sigmas:
        validate_schedule(sigmas, eta, nu)

    use_adaptive_steps: bool = getattr(model.p, consts.GE_USE_ADAPTIVE_STEPS, True)
    if use_adaptive_steps:
        # Scale gamma with the step count, then apply the user-configured offset.
        ge_gamma = compute_optimal_gamma(steps, use_adaptive_steps) + getattr(
            model.p, consts.GE_GAMMA_OFFSET, consts.GE_DEFAULT_GAMMA_OFFSET
        )
    else:
        ge_gamma = getattr(model.p, consts.GE_GAMMA, consts.GE_DEFAULT_GAMMA)

    timestep_adaptive_gamma = getattr(model.p, consts.GE_USE_TIMESTEP_ADAPTIVE_GAMMA, False)
    if timestep_adaptive_gamma:
        # Anneal gamma linearly from 1.2x down to 0.8x of the base value over the
        # run: stronger extrapolation early, gentler correction late.
        gammas = torch.linspace(ge_gamma * 1.2, ge_gamma * 0.8, steps)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)

        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})

        dt = sigmas[i + 1] - sigmas[i]

        if i == 0:
            # No previous gradient yet, so take a plain Euler step.
            x = x + d * dt
        else:
            current_gamma = gammas[i].item() if timestep_adaptive_gamma else ge_gamma
            # Extrapolate with the previous gradient: gamma * d + (1 - gamma) * old_d
            # reduces the estimation error of the current direction.
            d_bar = current_gamma * d + (1 - current_gamma) * old_d
            x = x + d_bar * dt

        old_d = d

    return x
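
# Minimal usage sketch (illustrative, not part of this module's API; assumes a
# k-diffusion style denoiser whose `model.p` exposes the GE_* options read above,
# plus a descending sigma schedule; `denoiser` and the schedule are placeholders):
#
#   sigmas = torch.tensor([14.6, 7.0, 3.1, 1.4, 0.6, 0.25, 0.09, 0.03, 0.0])
#   x = torch.randn(1, 4, 64, 64) * sigmas[0]
#   sample = sample_gradient_estimation(denoiser, x, sigmas)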
|
|
|