# Scheduler functions copied from ComfyUI k_diffusion
# Original source: ComfyUI/comfy/k_diffusion/sampling.py
import math
import torch
from .utils import append_zero
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Build the noise schedule of Karras et al. (2022).

    Interpolates linearly, over ``n`` evenly spaced points, from
    ``sigma_max ** (1 / rho)`` to ``sigma_min ** (1 / rho)``, then raises
    every point back to the power ``rho``.

    Args:
        n: Number of interpolated sigmas (before ``append_zero``).
        sigma_min: Value of the last interpolated sigma.
        sigma_max: Value of the first interpolated sigma.
        rho: Exponent that bends the interpolation (7 by default).
        device: Torch device the tensors are created on.

    Returns:
        The interpolated sigmas passed through ``append_zero`` (from
        ``.utils``; presumably appends a terminal zero - confirm), then moved
        to ``device``.
    """
    t = torch.linspace(0, 1, n, device=device)
    inv_rho = 1 / rho
    start = sigma_max ** inv_rho
    stop = sigma_min ** inv_rho
    sigmas = (start + t * (stop - start)) ** rho
    return append_zero(sigmas).to(device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Build a schedule evenly spaced in log(sigma).

    Args:
        n: Number of sigmas (before ``append_zero``).
        sigma_min: Last sigma; must be > 0 (``math.log``).
        sigma_max: First sigma; must be > 0 (``math.log``).
        device: Torch device the tensor is created on.

    Returns:
        ``exp`` of ``n`` log-sigmas running from ``log(sigma_max)`` to
        ``log(sigma_min)``, passed through ``append_zero`` (from ``.utils``).
    """
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(torch.exp(log_sigmas))
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Build a schedule that is polynomial in log(sigma).

    With weights ``w = linspace(1, 0, n) ** rho``, each sigma is
    ``exp(log(sigma_min) + w * (log(sigma_max) - log(sigma_min)))``.
    ``rho = 1`` gives log-uniform spacing; larger ``rho`` bends the ramp.

    Args:
        n: Number of sigmas (before ``append_zero``).
        sigma_min: Last sigma; must be > 0 (``math.log``).
        sigma_max: First sigma; must be > 0 (``math.log``).
        rho: Exponent applied to the linear weights.
        device: Torch device the tensors are created on.

    Returns:
        The sigmas passed through ``append_zero`` (from ``.utils``).
    """
    weights = torch.linspace(1, 0, n, device=device) ** rho
    log_hi = math.log(sigma_max)
    log_lo = math.log(sigma_min)
    sigmas = torch.exp(weights * (log_hi - log_lo) + log_lo)
    return append_zero(sigmas)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Build a continuous variance-preserving (VP) schedule.

    For ``n`` values of ``t`` running from 1 down to ``eps_s``, computes
    ``sigma(t) = sqrt(expm1(beta_d * t**2 / 2 + beta_min * t))``. The sigma
    range follows from ``beta_d``/``beta_min``/``eps_s``; there is no
    explicit sigma_min/sigma_max.

    Args:
        n: Number of sigmas (before ``append_zero``).
        beta_d: Coefficient of the quadratic term in ``t``.
        beta_min: Coefficient of the linear term in ``t``.
        eps_s: Smallest ``t`` sampled.
        device: Torch device the tensors are created on.

    Returns:
        The sigmas passed through ``append_zero`` (from ``.utils``).
    """
    t = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_min * t + beta_d * t ** 2 / 2
    # expm1 keeps precision when the exponent is tiny (t close to eps_s).
    return append_zero(torch.special.expm1(exponent).sqrt())
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
    """Construct the noise schedule proposed by Tiankai et al. (2024).

    For ``n`` points ``x`` evenly spaced on [0, 1], computes the log-sigma

        lmb = mu - beta * sign(0.5 - x) * log(1 - 2 * |0.5 - x| + epsilon)

    and returns ``exp(lmb)`` clamped to ``[sigma_min, sigma_max]``.

    Args:
        n: Number of sigmas (before ``append_zero``).
        sigma_min: Lower clamp bound for every sigma.
        sigma_max: Upper clamp bound for every sigma.
        mu: Value of ``lmb`` at ``x = 0.5`` (where ``sign`` is 0).
        beta: Scale applied to the log term.
        device: Torch device the tensors are created on.

    Returns:
        The clamped sigmas passed through ``append_zero`` (from ``.utils``;
        presumably appends a terminal zero - confirm).
    """
    epsilon = 1e-5  # keeps the log argument positive at x == 0 and x == 1
    x = torch.linspace(0, 1, n, device=device)
    offset = 0.5 - x
    lmb = mu - beta * torch.sign(offset) * torch.log(1 - 2 * torch.abs(offset) + epsilon)
    # Direct call instead of an assigned lambda (PEP 8 E731) whose parameter shadowed ``x``.
    sigmas = torch.clamp(torch.exp(lmb), min=sigma_min, max=sigma_max)
    return append_zero(sigmas)
# Scheduler names supported by this fallback module. The first group has a
# dedicated implementation here; the second group maps other common ComfyUI
# scheduler names onto one of those implementations.
STANDARD_SCHEDULERS = [
    # --- dedicated implementations ---
    "karras",            # Karras et al. (2022)
    "exponential",       # evenly spaced in log(sigma)
    "polyexponential",   # polynomial ramp in log(sigma), exponent rho
    "vp",                # continuous variance-preserving schedule
    "laplace",           # Laplace-distribution schedule
    # --- fallback approximations of other ComfyUI schedulers ---
    "normal",            # evenly spaced in sigma
    "simple",            # same as "normal"
    "sgm_uniform",       # same as "exponential"
    "ddim_uniform",      # same as "normal"
    "beta",              # same as "vp"
    "linear_quadratic",  # "polyexponential" with rho=2
    "kl_optimal",        # same as "karras"
]
def get_sigmas_normal(n, sigma_min, sigma_max, device='cpu'):
    """Build a schedule evenly spaced in sigma, from sigma_max to sigma_min.

    Fallback used for the 'normal' and 'simple' scheduler names.

    Args:
        n: Number of sigmas (before ``append_zero``).
        sigma_min: Last sigma (converted with ``float``).
        sigma_max: First sigma (converted with ``float``).
        device: Torch device the tensor is created on.

    Returns:
        The sigmas passed through ``append_zero`` (from ``.utils``), then
        moved to ``device``.
    """
    start, stop = float(sigma_max), float(sigma_min)
    evenly_spaced = torch.linspace(start, stop, n, device=device)
    return append_zero(evenly_spaced).to(device)
def get_sigmas_simple(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for the 'simple' scheduler name.

    Delegates unchanged to ``get_sigmas_normal``, so the result is the
    same evenly-spaced-in-sigma schedule.
    """
    return get_sigmas_normal(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
def get_sigmas_sgm_uniform(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for the 'sgm_uniform' scheduler name.

    Approximated by log-uniform spacing: delegates unchanged to
    ``get_sigmas_exponential``.
    """
    return get_sigmas_exponential(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
def get_sigmas_ddim_uniform(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for the 'ddim_uniform' scheduler name.

    Conservatively approximated by uniform spacing in sigma: delegates
    unchanged to ``get_sigmas_normal``.
    """
    return get_sigmas_normal(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
def get_sigmas_beta(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for the 'beta' scheduler name, approximated by the VP schedule.

    NOTE: ``sigma_min`` and ``sigma_max`` are accepted only for signature
    compatibility and are ignored; the returned range is whatever
    ``get_sigmas_vp`` produces with its default parameters. This may lie
    outside the caller's requested range — review whether that is acceptable.
    """
    return get_sigmas_vp(n=n, device=device)
def get_sigmas_linear_quadratic(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for the 'linear_quadratic' scheduler name.

    Uses ``get_sigmas_polyexponential`` with ``rho=2.0``, i.e. the log-sigma
    weights are a quadratic ramp from 1 to 0.
    """
    return get_sigmas_polyexponential(
        n=n, sigma_min=sigma_min, sigma_max=sigma_max, rho=2.0, device=device
    )
def get_sigmas_kl_optimal(n, sigma_min, sigma_max, device='cpu'):
    """Fallback for the 'kl_optimal' scheduler name.

    Delegates to ``get_sigmas_karras`` with its default ``rho``.
    """
    return get_sigmas_karras(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
def get_sigmas(scheduler_name, steps, sigma_min=0.03, sigma_max=14.6, device='cpu'):
    """Return the sigma schedule built by the scheduler named ``scheduler_name``.

    Args:
        scheduler_name: One of the names listed in ``STANDARD_SCHEDULERS``.
        steps: Forwarded as ``n`` to the selected ``get_sigmas_*`` builder.
        sigma_min: Lower end of the sigma range. Not forwarded for "vp", and
            ignored by the "beta" fallback.
        sigma_max: Upper end of the sigma range (same caveat as sigma_min).
        device: Torch device forwarded to the selected builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``scheduler_name`` is not a recognised scheduler.
    """
    if scheduler_name == "vp":
        # The VP builder takes no sigma range, so it cannot share the uniform
        # (n, sigma_min, sigma_max) call used for every other scheduler.
        return get_sigmas_vp(steps, device=device)

    builders = {
        "karras": get_sigmas_karras,
        "exponential": get_sigmas_exponential,
        "polyexponential": get_sigmas_polyexponential,
        "laplace": get_sigmas_laplace,
        "normal": get_sigmas_normal,
        "simple": get_sigmas_simple,
        "sgm_uniform": get_sigmas_sgm_uniform,
        "ddim_uniform": get_sigmas_ddim_uniform,
        "beta": get_sigmas_beta,
        "linear_quadratic": get_sigmas_linear_quadratic,
        "kl_optimal": get_sigmas_kl_optimal,
    }
    try:
        builder = builders[scheduler_name]
    except (KeyError, TypeError):
        # TypeError: an unhashable name is treated like any other unknown name.
        builder = None
    if builder is None:
        raise ValueError(f"Scheduler '{scheduler_name}' not implemented in comfy_copy. "
                         f"Available schedulers: {STANDARD_SCHEDULERS}")
    return builder(steps, sigma_min, sigma_max, device=device)
|