Wan_Backup / custom_nodes /ComfyUI-FSampler /comfy_copy /k_diffusion_sampling.py
Eji-Sensei14's picture
Upload folder using huggingface_hub
c6535db verified
# Scheduler functions copied from ComfyUI k_diffusion
# Original source: ComfyUI/comfy/k_diffusion/sampling.py
import math
import torch
from .utils import append_zero
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Build the noise schedule of Karras et al. (2022).

    The ``n`` sigmas are obtained by interpolating linearly between
    ``sigma_max ** (1 / rho)`` and ``sigma_min ** (1 / rho)`` and raising the
    result back to the power ``rho``; ``append_zero`` is then applied.
    """
    inv_rho = 1 / rho
    start = sigma_max ** inv_rho
    end = sigma_min ** inv_rho
    t = torch.linspace(0, 1, n, device=device)
    sigmas = (start + t * (end - start)) ** rho
    return append_zero(sigmas).to(device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Build a schedule of ``n`` sigmas spaced uniformly in log-sigma.

    Runs from ``sigma_max`` down to ``sigma_min``; ``append_zero`` is applied
    to the result.
    """
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(torch.exp(log_sigmas))
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Build a schedule whose log-sigma is a polynomial (power ``rho``) of the step ramp.

    A ramp falling from 1 to 0 is raised to ``rho`` and mapped onto
    ``[log(sigma_min), log(sigma_max)]``; ``rho=1`` reproduces the exponential
    schedule. ``append_zero`` is applied to the result.
    """
    log_lo = math.log(sigma_min)
    log_span = math.log(sigma_max) - log_lo
    ramp = torch.linspace(1, 0, n, device=device).pow(rho)
    return append_zero(torch.exp(ramp * log_span + log_lo))
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Build the continuous variance-preserving (VP) noise schedule.

    Time runs from 1 down to ``eps_s`` over ``n`` points and
    ``sigma(t) = sqrt(expm1(beta_d * t**2 / 2 + beta_min * t))``.
    The sigma range is fixed by ``beta_d``/``beta_min``/``eps_s`` alone.
    ``append_zero`` is applied to the result.
    """
    t = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * t ** 2 / 2 + beta_min * t
    return append_zero(torch.special.expm1(exponent).sqrt())
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
    """Build the Laplace-distribution noise schedule proposed by Tiankai et al. (2024).

    Log-sigma follows the inverse CDF of a Laplace(mu, beta) distribution over
    ``n`` evenly spaced points in [0, 1]. The sigmas are clamped to
    ``[sigma_min, sigma_max]`` and ``append_zero`` is applied.
    """
    eps = 1e-5  # keeps the log argument strictly positive at the ramp endpoints
    offset = 0.5 - torch.linspace(0, 1, n, device=device)
    log_sigmas = mu - beta * torch.sign(offset) * torch.log(1 - 2 * torch.abs(offset) + eps)
    sigmas = torch.exp(log_sigmas).clamp(min=sigma_min, max=sigma_max)
    return append_zero(sigmas)
# Schedules taken directly from k-diffusion.
_CORE_SCHEDULERS = (
    "karras",           # Karras et al. (2022)
    "exponential",      # uniform spacing in log-sigma
    "polyexponential",  # polynomial in log-sigma (rho)
    "vp",               # continuous variance-preserving schedule
    "laplace",          # Laplace-distribution schedule
)
# Further ComfyUI scheduler names; see each get_sigmas_<name> for how
# closely it follows ComfyUI's own implementation.
_COMFY_SCHEDULERS = (
    "normal",            # uniform spacing in sigma
    "simple",            # same as "normal"
    "sgm_uniform",       # delegates to "exponential"
    "ddim_uniform",      # delegates to "normal"
    "beta",              # see get_sigmas_beta
    "linear_quadratic",  # polyexponential with rho=2
    "kl_optimal",        # see get_sigmas_kl_optimal
)
# Every name accepted by get_sigmas(); also listed in its error message.
STANDARD_SCHEDULERS = [*_CORE_SCHEDULERS, *_COMFY_SCHEDULERS]
def get_sigmas_normal(n, sigma_min, sigma_max, device='cpu'):
    """Build ``n`` evenly spaced sigmas from ``sigma_max`` down to ``sigma_min``.

    Used as the fallback for the 'normal' and 'simple' scheduler names.
    ``append_zero`` is applied and the result is moved to ``device``.
    """
    hi, lo = float(sigma_max), float(sigma_min)
    sigmas = torch.linspace(hi, lo, n, device=device)
    return append_zero(sigmas).to(device)
def get_sigmas_simple(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for the 'simple' scheduler: identical to ``get_sigmas_normal``."""
    return get_sigmas_normal(n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
def get_sigmas_sgm_uniform(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for 'sgm_uniform': delegates to the log-sigma-uniform ``get_sigmas_exponential``."""
    return get_sigmas_exponential(n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
def get_sigmas_ddim_uniform(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for 'ddim_uniform': delegates to the sigma-uniform ``get_sigmas_normal``."""
    return get_sigmas_normal(n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
def get_sigmas_beta(n, sigma_min, sigma_max, device='cpu'):
    """Approximate the 'beta' schedule with a VP-shaped curve fitted to the sigma bounds.

    ComfyUI's real 'beta' scheduler picks model timesteps via a Beta
    distribution, which needs model-sampling data this module does not have.
    As a fallback, borrow the *shape* of ``get_sigmas_vp`` (default parameters).
    Rescale it linearly in log-sigma so the first sigma equals ``sigma_max`` and
    the last non-zero sigma equals ``sigma_min``.

    Returning the raw VP curve (as before) ignored both bounds. It produced
    sigmas from roughly 152 down to about 0.0105 regardless of the model's
    range.

    Args:
        n: number of sigmas before the terminal zero.
        sigma_min, sigma_max: target bounds; both must be positive (log-sigma
            rescaling, as in the other log-sigma schedules of this module).
        device: torch device for the result.

    Returns:
        The rescaled sigmas with ``append_zero`` applied.

    Raises:
        ValueError: from ``math.log`` if a bound is not positive (only when n >= 2).
    """
    # Raw VP curve with its trailing zero dropped; only its shape is used.
    # NOTE(review): assumes append_zero appends exactly one element, as its
    # use throughout this module (and upstream k-diffusion) suggests.
    vp = get_sigmas_vp(n, device=device)[:-1]
    if vp.numel() < 2:
        # 0 or 1 step: nothing to rescale. Start at sigma_max like the other schedules do.
        return append_zero(torch.full_like(vp, float(sigma_max)))
    log_vp = vp.log()
    # Position along the (strictly decreasing) curve in log-sigma: 1 at the first step, 0 at the last.
    frac = (log_vp - log_vp[-1]) / (log_vp[0] - log_vp[-1])
    log_min, log_max = math.log(sigma_min), math.log(sigma_max)
    sigmas = torch.exp(log_min + frac * (log_max - log_min))
    return append_zero(sigmas)
def get_sigmas_linear_quadratic(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for 'linear_quadratic': polyexponential schedule with ``rho=2``.

    Log-sigma is quadratic in the step ramp, so the spacing tightens toward
    ``sigma_min``.
    """
    return get_sigmas_polyexponential(
        n, sigma_min, sigma_max,
        rho=2.0,
        device=device,
    )
def get_sigmas_kl_optimal(n, sigma_min, sigma_max, device='cpu'):
    """Build the KL-optimal noise schedule ("Align Your Steps", Sabour et al., 2024).

    The steps are uniform in ``arctan(sigma)``::

        sigma_i = tan((1 - t_i) * atan(sigma_max) + t_i * atan(sigma_min)),
        t_i = i / (n - 1)

    This is the formula of ComfyUI's ``kl_optimal`` scheduler. It depends only
    on the sigma bounds, so it replaces the former Karras alias.
    ``torch.linspace`` gives ``t = [0]`` for ``n == 1``. That yields
    ``[sigma_max, 0]`` instead of the NaN an ``arange(n) / (n - 1)`` form
    produces.

    Args:
        n: number of sigmas before the terminal zero.
        sigma_min, sigma_max: schedule bounds.
        device: torch device for the result.

    Returns:
        A 1-D tensor of length ``n + 1``: the decreasing sigmas followed by 0.
    """
    t = torch.linspace(0, 1, n, device=device)
    sigmas = torch.zeros(n + 1, device=device)
    sigmas[:-1] = torch.tan(t * math.atan(sigma_min) + (1 - t) * math.atan(sigma_max))
    return sigmas
def get_sigmas(scheduler_name, steps, sigma_min=0.03, sigma_max=14.6, device='cpu'):
    """Build the sigma schedule registered under ``scheduler_name``.

    Args:
        scheduler_name: one of ``STANDARD_SCHEDULERS``.
        steps: forwarded as ``n`` to the selected ``get_sigmas_<name>`` builder.
        sigma_min, sigma_max: schedule bounds, forwarded to every builder
            except "vp", whose range is set by its own parameters.
        device: torch device for the returned tensor.

    Returns:
        The tensor produced by the selected builder.

    Raises:
        ValueError: if ``scheduler_name`` matches no known scheduler.
    """
    # Ordered (name, builder) pairs, compared with == in this order.
    builders = (
        ("karras", get_sigmas_karras),
        ("exponential", get_sigmas_exponential),
        ("polyexponential", get_sigmas_polyexponential),
        ("vp", get_sigmas_vp),
        ("laplace", get_sigmas_laplace),
        ("normal", get_sigmas_normal),
        ("simple", get_sigmas_simple),
        ("sgm_uniform", get_sigmas_sgm_uniform),
        ("ddim_uniform", get_sigmas_ddim_uniform),
        ("beta", get_sigmas_beta),
        ("linear_quadratic", get_sigmas_linear_quadratic),
        ("kl_optimal", get_sigmas_kl_optimal),
    )
    for name, builder in builders:
        if scheduler_name == name:
            if builder is get_sigmas_vp:
                # VP takes no sigma bounds.
                return builder(steps, device=device)
            return builder(steps, sigma_min, sigma_max, device=device)
    raise ValueError(f"Scheduler '{scheduler_name}' not implemented in comfy_copy. "
                     f"Available schedulers: {STANDARD_SCHEDULERS}")