File size: 6,170 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Scheduler functions copied from ComfyUI k_diffusion
# Original source: ComfyUI/comfy/k_diffusion/sampling.py

import math
import torch
from .utils import append_zero


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Build the noise schedule of Karras et al. (2022).

    Interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` over ``n`` points and raises the result back
    to the power ``rho``, so the first value is ``sigma_max`` and the last
    is ``sigma_min``.

    Args:
        n: Number of interpolation points.
        sigma_min: Smallest sigma of the schedule.
        sigma_max: Largest sigma of the schedule.
        rho: Exponent controlling the curvature of the schedule.
        device: Torch device the schedule is created on.

    Returns:
        The schedule after being passed through ``append_zero`` (imported
        from ``.utils``; presumably it appends a terminal 0 -- confirm
        against that helper), moved to ``device``.
    """
    inv_rho = 1 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    # 0 -> 1 ramp: position 0 maps to sigma_max, position 1 to sigma_min.
    position = torch.linspace(0, 1, n, device=device)
    interpolated = hi + position * (lo - hi)
    schedule = interpolated ** rho
    return append_zero(schedule).to(device)


def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Build an exponential noise schedule (uniform spacing in log-sigma).

    Args:
        n: Number of points between ``sigma_max`` and ``sigma_min``.
        sigma_min: Smallest sigma; must be valid input for ``math.log``.
        sigma_max: Largest sigma; must be valid input for ``math.log``.
        device: Torch device the schedule is created on.

    Returns:
        The descending schedule after being passed through ``append_zero``
        (imported from ``.utils``; presumably appends a terminal 0 --
        confirm against that helper).
    """
    log_hi = math.log(sigma_max)
    log_lo = math.log(sigma_min)
    # Evenly spaced in log space, then mapped back to sigma space.
    log_sigmas = torch.linspace(log_hi, log_lo, n, device=device)
    return append_zero(log_sigmas.exp())


def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Build a noise schedule that is polynomial in log-sigma.

    A 1 -> 0 ramp of ``n`` points is raised to the power ``rho`` and used as
    the interpolation weight between ``log(sigma_min)`` (weight 0) and
    ``log(sigma_max)`` (weight 1).

    Args:
        n: Number of points in the ramp.
        sigma_min: Smallest sigma; must be valid input for ``math.log``.
        sigma_max: Largest sigma; must be valid input for ``math.log``.
        rho: Power applied to the ramp.
        device: Torch device the schedule is created on.

    Returns:
        The schedule after being passed through ``append_zero`` (imported
        from ``.utils``; presumably appends a terminal 0 -- confirm against
        that helper).
    """
    weights = torch.linspace(1, 0, n, device=device).pow(rho)
    log_hi, log_lo = math.log(sigma_max), math.log(sigma_min)
    log_sigmas = weights * (log_hi - log_lo) + log_lo
    return append_zero(log_sigmas.exp())


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Build a continuous VP noise schedule.

    For ``n`` values of ``t`` evenly spaced from 1 down to ``eps_s`` the
    sigma is ``sqrt(expm1(beta_d * t**2 / 2 + beta_min * t))``. This
    schedule takes no sigma range; its extent is set by ``beta_d``,
    ``beta_min`` and ``eps_s`` alone.

    Args:
        n: Number of ``t`` samples.
        beta_d: Coefficient of the quadratic term in ``t``.
        beta_min: Coefficient of the linear term in ``t``.
        eps_s: Final (smallest) value of ``t``.
        device: Torch device the schedule is created on.

    Returns:
        The schedule after being passed through ``append_zero`` (imported
        from ``.utils``; presumably appends a terminal 0 -- confirm against
        that helper).
    """
    times = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * times ** 2 / 2 + beta_min * times
    squared_sigmas = torch.special.expm1(exponent)
    return append_zero(squared_sigmas.sqrt())


def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
    """Build the noise schedule proposed by Tiankai et al. (2024).

    Log-sigma is computed as
    ``mu - beta * sign(0.5 - x) * log(1 - 2 * |0.5 - x| + 1e-5)`` for ``n``
    values of ``x`` evenly spaced over [0, 1]; the exponentiated values are
    then clamped to ``[sigma_min, sigma_max]``.

    Args:
        n: Number of samples of ``x``.
        sigma_min: Lower clamp bound for the resulting sigmas.
        sigma_max: Upper clamp bound for the resulting sigmas.
        mu: Additive offset of log-sigma.
        beta: Multiplier of the signed log term.
        device: Torch device the schedule is created on.

    Returns:
        The clamped schedule after being passed through ``append_zero``
        (imported from ``.utils``; presumably appends a terminal 0 --
        confirm against that helper).
    """
    log_guard = 1e-5  # keeps the log argument above 0 at the interval ends
    grid = torch.linspace(0, 1, n, device=device)
    offset = 0.5 - grid
    direction = torch.sign(offset)
    log_term = torch.log(1 - 2 * torch.abs(offset) + log_guard)
    log_sigmas = mu - beta * direction * log_term
    bounded = torch.clamp(torch.exp(log_sigmas), min=sigma_min, max=sigma_max)
    return append_zero(bounded)


# Scheduler names accepted by this fallback module, in listing order.
STANDARD_SCHEDULERS = (
    # Schedules implemented directly in this module:
    #   karras           - Karras et al. (2022)
    #   exponential      - uniform in log-sigma
    #   polyexponential  - polynomial in log-sigma (rho)
    #   vp               - continuous VP schedule
    #   laplace          - Laplace distribution schedule
    ["karras", "exponential", "polyexponential", "vp", "laplace"]
    # Additional common ComfyUI names, served here by fallback approximations:
    #   normal           - uniform in sigma
    #   simple           - alias of normal
    #   sgm_uniform      - approximated by exponential
    #   ddim_uniform     - approximated by uniform in sigma
    #   beta             - approximated by VP
    #   linear_quadratic - polyexponential with rho=2
    #   kl_optimal       - alias of karras
    + [
        "normal",
        "simple",
        "sgm_uniform",
        "ddim_uniform",
        "beta",
        "linear_quadratic",
        "kl_optimal",
    ]
)

def get_sigmas_normal(n, sigma_min, sigma_max, device='cpu'):
    """Build a schedule uniformly spaced in sigma, descending.

    Fallback used for the 'normal' and 'simple' scheduler names.

    Args:
        n: Number of points between ``sigma_max`` and ``sigma_min``.
        sigma_min: Final sigma of the linear ramp (converted with ``float``).
        sigma_max: First sigma of the linear ramp (converted with ``float``).
        device: Torch device the schedule is created on.

    Returns:
        The ramp after being passed through ``append_zero`` (imported from
        ``.utils``; presumably appends a terminal 0 -- confirm against that
        helper), moved to ``device``.
    """
    start, stop = float(sigma_max), float(sigma_min)
    ramp = torch.linspace(start, stop, n, device=device)
    with_terminal = append_zero(ramp)
    return with_terminal.to(device)


def get_sigmas_simple(n, sigma_min, sigma_max, device='cpu'):
    """Return the 'simple' schedule, an alias of ``get_sigmas_normal``.

    All arguments are forwarded unchanged; see ``get_sigmas_normal`` for
    their meaning and for the return value.
    """
    schedule = get_sigmas_normal(n, sigma_min, sigma_max, device=device)
    return schedule


def get_sigmas_sgm_uniform(n, sigma_min, sigma_max, device='cpu'):
    """Approximate the SGM-uniform schedule with the exponential schedule.

    This is a fallback: the result is exactly what
    ``get_sigmas_exponential`` returns for the same arguments (uniform
    spacing in log-sigma).
    """
    schedule = get_sigmas_exponential(n, sigma_min, sigma_max, device=device)
    return schedule


def get_sigmas_ddim_uniform(n, sigma_min, sigma_max, device='cpu'):
    """Approximate the DDIM-uniform schedule with uniform sigma spacing.

    This is a conservative fallback: the result is exactly what
    ``get_sigmas_normal`` returns for the same arguments.
    """
    schedule = get_sigmas_normal(n, sigma_min, sigma_max, device=device)
    return schedule


def get_sigmas_beta(n, sigma_min, sigma_max, device='cpu'):
    """Approximate the 'beta' schedule with the VP schedule.

    This is a fallback that delegates to ``get_sigmas_vp`` with its default
    parameters.

    NOTE: ``sigma_min`` and ``sigma_max`` are accepted so the signature
    matches the other schedulers, but they are not used -- the VP schedule
    does not take a sigma range.
    """
    schedule = get_sigmas_vp(n, device=device)
    return schedule


def get_sigmas_linear_quadratic(n, sigma_min, sigma_max, device='cpu'):
    """Approximate 'linear_quadratic' with a rho=2 polyexponential schedule.

    This is a fallback: it delegates to ``get_sigmas_polyexponential`` with
    ``rho`` fixed at 2.0, giving a schedule that is quadratic in log-sigma.
    """
    quadratic_rho = 2.0
    return get_sigmas_polyexponential(
        n, sigma_min, sigma_max, rho=quadratic_rho, device=device
    )


def get_sigmas_kl_optimal(n, sigma_min, sigma_max, device='cpu'):
    """Stand in for 'kl_optimal' with the Karras schedule.

    This is a fallback: it delegates to ``get_sigmas_karras`` with that
    function's default ``rho``.
    """
    schedule = get_sigmas_karras(n, sigma_min, sigma_max, device=device)
    return schedule


def get_sigmas(scheduler_name, steps, sigma_min=0.03, sigma_max=14.6, device='cpu'):
    """Return the sigma schedule for the scheduler called ``scheduler_name``.

    Dispatches to the matching ``get_sigmas_*`` function in this module.

    Args:
        scheduler_name: One of the names listed in ``STANDARD_SCHEDULERS``.
        steps: Passed as ``n`` to the selected scheduler function.
        sigma_min: Smallest sigma, forwarded to schedulers that take a range.
        sigma_max: Largest sigma, forwarded to schedulers that take a range.
        device: Torch device forwarded to the selected scheduler function.

    Returns:
        Whatever the selected scheduler function returns.

    Raises:
        ValueError: If ``scheduler_name`` matches none of the known names.
    """
    # "vp" is the one schedule called without a sigma range.
    if scheduler_name == "vp":
        return get_sigmas_vp(steps, device=device)

    # Every other scheduler shares the (n, sigma_min, sigma_max, device=) call.
    ranged_builders = (
        ("karras", get_sigmas_karras),
        ("exponential", get_sigmas_exponential),
        ("polyexponential", get_sigmas_polyexponential),
        ("laplace", get_sigmas_laplace),
        ("normal", get_sigmas_normal),
        ("simple", get_sigmas_simple),
        ("sgm_uniform", get_sigmas_sgm_uniform),
        ("ddim_uniform", get_sigmas_ddim_uniform),
        ("beta", get_sigmas_beta),
        ("linear_quadratic", get_sigmas_linear_quadratic),
        ("kl_optimal", get_sigmas_kl_optimal),
    )
    for known_name, build in ranged_builders:
        if scheduler_name == known_name:
            return build(steps, sigma_min, sigma_max, device=device)

    raise ValueError(f"Scheduler '{scheduler_name}' not implemented in comfy_copy. "
                     f"Available schedulers: {STANDARD_SCHEDULERS}")