# Provenance: uploaded by dikdimon as part of sd-forge-extra-samplers via SD-Hub (commit f2f3b8d, verified)
import math
from collections.abc import Callable
from typing import Any, Optional

import torch
from tqdm import trange
from k_diffusion.sampling import to_d

from modules import errors

import lib_es.const as consts
from lib_es.utils import sampler_metadata
def compute_optimal_gamma(steps: int, adaptive: bool = True) -> float:
    """
    Compute the gamma parameter for gradient estimation based on step count.

    Gamma grows logarithmically from 1.5 (at <= 10 steps) to 2.6
    (at >= 100 steps): longer schedules tolerate a more aggressive
    gradient correction.

    Args:
        steps: Number of sampling steps
        adaptive: Whether to use adaptive gamma based on step count;
            when False, the project default ``consts.GE_DEFAULT_GAMMA``
            is returned unchanged.

    Returns:
        Optimal gamma value (in [1.5, 2.6] when adaptive)
    """
    if not adaptive:
        return consts.GE_DEFAULT_GAMMA

    # Bounds of the supported step range and the matching gamma range.
    min_steps, max_steps = 10, 100
    min_gamma, max_gamma = 1.5, 2.6

    # Clamp outside the supported range.
    if steps <= min_steps:
        return min_gamma
    if steps >= max_steps:
        return max_gamma

    # log(steps/min_steps) / log(max_steps/min_steps) maps steps to [0, 1]
    # logarithmically. Use math.log rather than torch.log: this is pure
    # scalar arithmetic, so allocating tensors here only added overhead
    # and float32 rounding error.
    log_factor = math.log(steps / min_steps) / math.log(max_steps / min_steps)

    # Interpolate between the gamma bounds with the logarithmic factor.
    return min_gamma + log_factor * (max_gamma - min_gamma)
def validate_schedule(sigmas: torch.Tensor, eta: float = 0.1, nu: float = 2.0) -> bool:
    """
    Check whether a noise schedule satisfies the paper's admissibility criteria.

    Args:
        sigmas: Tensor of noise levels in descending order
        eta: Error parameter
        nu: Accuracy parameter for distance estimates

    Returns:
        True if schedule is admissible, False otherwise
    """
    num_steps = len(sigmas) - 1
    problems: list[str] = []

    # The schedule must be strictly decreasing everywhere.
    if not torch.all(sigmas[:-1] > sigmas[1:]):
        problems.append("Sigmas must be strictly decreasing")

    # Largest relative decrease (beta) per step that the criteria allow.
    c = 1 - nu ** (-1 / num_steps)
    beta_max = c / (eta + c)

    # Each step's relative decrease must stay within beta_max.
    # NOTE(review): the final transition is not inspected — presumably
    # because schedules commonly end at sigma = 0, where the ratio degenerates.
    for idx in range(num_steps - 1):
        beta = 1 - sigmas[idx + 1] / sigmas[idx]
        if beta > beta_max:
            problems.append(f"Step {idx} has beta {beta:.4f} > beta_max {beta_max:.4f}")

    if problems:
        errors.display(ValueError(f"Noise schedule is not admissible: {', '.join(problems)}"))
        errors.print_error_explanation("Noise schedule validation failed.\n\tIssues:" + ",\n\t\t".join(problems))
        return False
    return True
@torch.no_grad()
@sampler_metadata("Gradient Estimation", {"scheduler": "sgm_uniform"})
def sample_gradient_estimation(
    model,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    extra_args: Optional[dict[str, Any]] = None,
    callback: Optional[Callable] = None,
    disable: Optional[bool] = None,
    validate_sigmas: bool = False,
    eta: float = 0.1,
    nu: float = 2.0,
) -> torch.Tensor:
    """
    Gradient-estimation sampler from "Interpreting and Improving Diffusion
    Models from an Optimization Perspective".

    A second-order refinement of DDIM: after the first (Euler) step, each
    update blends the current denoising gradient with the previous one,
    reducing the gradient-estimation error. The method views denoising as
    approximate projection onto the data manifold, so sampling becomes
    gradient descent on the squared Euclidean distance function.

    Args:
        model: The diffusion model
        x: Input tensor
        sigmas: Noise schedule (should be in descending order)
        extra_args: Extra arguments to pass to the model
        callback: Callback function
        disable: Whether to disable the progress bar
        validate_sigmas: Whether to validate the noise schedule
        eta: Error parameter for schedule validation (default 0.1)
        nu: Accuracy parameter for schedule validation (default 2.0)

    Returns:
        Denoised tensor

    References:
        Paper: https://openreview.net/pdf?id=o2ND9v0CeK
    """
    # Need at least one transition to sample at all.
    if sigmas.shape[0] < 2:
        raise ValueError("Need at least 2 timesteps for gradient estimation")

    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    steps = len(sigmas) - 1
    prev_d = None  # previous step's gradient; None until the first step ran

    # Optionally verify the schedule's admissibility (report-only).
    if validate_sigmas:
        validate_schedule(sigmas, eta, nu)

    # Base gamma: either adaptive (scaled by step count, plus a user offset
    # read from the processing object) or a fixed user-configured value.
    use_adaptive_steps: bool = getattr(model.p, consts.GE_USE_ADAPTIVE_STEPS, True)
    if use_adaptive_steps:
        gamma_offset = getattr(model.p, consts.GE_GAMMA_OFFSET, consts.GE_DEFAULT_GAMMA_OFFSET)
        ge_gamma = compute_optimal_gamma(steps, use_adaptive_steps) + gamma_offset
    else:
        ge_gamma = getattr(model.p, consts.GE_GAMMA, consts.GE_DEFAULT_GAMMA)

    # Optional per-timestep gamma ramp: 1.2x at the start down to 0.8x at the
    # end — a heuristic, since early steps benefit more from aggressive
    # gradient correction.
    timestep_adaptive_gamma = getattr(model.p, consts.GE_USE_TIMESTEP_ADAPTIVE_GAMMA, False)
    if timestep_adaptive_gamma:
        gammas = torch.linspace(ge_gamma * 1.2, ge_gamma * 0.8, steps)

    # Main sampling loop.
    for i in trange(steps, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        dt = sigmas[i + 1] - sigmas[i]
        if prev_d is None:
            # First transition: plain Euler step — no gradient history yet.
            x = x + d * dt
        else:
            # Blend current and previous gradients to reduce estimation error.
            gamma_i = gammas[i].item() if timestep_adaptive_gamma else ge_gamma
            x = x + (gamma_i * d + (1 - gamma_i) * prev_d) * dt
        prev_d = d
    return x