dikdimon commited on
Commit
bc40952
·
verified ·
1 Parent(s): d31be88

Upload sd_schedulers.py using SD-Hub

Browse files
Files changed (1) hide show
  1. sd_schedulers.py +164 -0
sd_schedulers.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import torch
3
+ import k_diffusion
4
+ import numpy as np
5
+ from scipy import stats
6
+ from modules import shared
7
+ from modules.sd_simple_kes.simple_kes import simple_kes_scheduler
8
+ from modules.sd_simple_kes_v1.simple_kes_v1 import simple_kes_scheduler_v1
9
+ from modules.sd_simple_kes_v2.simple_kes_v2 import simple_kes_scheduler_v22
10
+ from modules.sd_simple_kes_v2_old.simple_kes_v2 import simple_kes_scheduler_v2
11
+
12
+ def to_d(x, sigma, denoised):
13
+ """Converts a denoiser output to a Karras ODE derivative."""
14
+ return (x - denoised) / sigma
15
+
16
+
17
+ k_diffusion.sampling.to_d = to_d
18
+
19
+
20
+ @dataclasses.dataclass
21
+ class Scheduler:
22
+ name: str
23
+ label: str
24
+ function: any
25
+
26
+ default_rho: float = -1
27
+ need_inner_model: bool = False
28
+ aliases: list = None
29
+
30
+
31
+ def uniform(n, sigma_min, sigma_max, inner_model, device):
32
+ return inner_model.get_sigmas(n).to(device)
33
+
34
+
35
+ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
36
+ start = inner_model.sigma_to_t(torch.tensor(sigma_max))
37
+ end = inner_model.sigma_to_t(torch.tensor(sigma_min))
38
+ sigs = [
39
+ inner_model.t_to_sigma(ts)
40
+ for ts in torch.linspace(start, end, n + 1)[:-1]
41
+ ]
42
+ sigs += [0.0]
43
+ return torch.FloatTensor(sigs).to(device)
44
+
45
+
46
+ def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device):
47
+ # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
48
+ def loglinear_interp(t_steps, num_steps):
49
+ """
50
+ Performs log-linear interpolation of a given array of decreasing numbers.
51
+ """
52
+ xs = np.linspace(0, 1, len(t_steps))
53
+ ys = np.log(t_steps[::-1])
54
+
55
+ new_xs = np.linspace(0, 1, num_steps)
56
+ new_ys = np.interp(new_xs, xs, ys)
57
+
58
+ interped_ys = np.exp(new_ys)[::-1].copy()
59
+ return interped_ys
60
+
61
+ if shared.sd_model.is_sdxl:
62
+ sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
63
+ else:
64
+ # Default to SD 1.5 sigmas.
65
+ sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
66
+
67
+ if n != len(sigmas):
68
+ sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
69
+ else:
70
+ sigmas.append(0.0)
71
+
72
+ return torch.FloatTensor(sigmas).to(device)
73
+
74
+
75
+ def kl_optimal(n, sigma_min, sigma_max, device):
76
+ alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))
77
+ alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))
78
+ step_indices = torch.arange(n + 1, device=device)
79
+ sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
80
+ return sigmas
81
+
82
+
83
+ def simple_scheduler(n, sigma_min, sigma_max, inner_model, device):
84
+ sigs = []
85
+ ss = len(inner_model.sigmas) / n
86
+ for x in range(n):
87
+ sigs += [float(inner_model.sigmas[-(1 + int(x * ss))])]
88
+ sigs += [0.0]
89
+ return torch.FloatTensor(sigs).to(device)
90
+
91
+
92
+ def normal_scheduler(n, sigma_min, sigma_max, inner_model, device, sgm=False, floor=False):
93
+ start = inner_model.sigma_to_t(torch.tensor(sigma_max))
94
+ end = inner_model.sigma_to_t(torch.tensor(sigma_min))
95
+
96
+ if sgm:
97
+ timesteps = torch.linspace(start, end, n + 1)[:-1]
98
+ else:
99
+ timesteps = torch.linspace(start, end, n)
100
+
101
+ sigs = []
102
+ for x in range(len(timesteps)):
103
+ ts = timesteps[x]
104
+ sigs.append(inner_model.t_to_sigma(ts))
105
+ sigs += [0.0]
106
+ return torch.FloatTensor(sigs).to(device)
107
+
108
+
109
+ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
110
+ sigs = []
111
+ ss = max(len(inner_model.sigmas) // n, 1)
112
+ x = 1
113
+ while x < len(inner_model.sigmas):
114
+ sigs += [float(inner_model.sigmas[x])]
115
+ x += ss
116
+ sigs = sigs[::-1]
117
+ sigs += [0.0]
118
+ return torch.FloatTensor(sigs).to(device)
119
+
120
+
121
+ def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
122
+ # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
123
+ alpha = shared.opts.beta_dist_alpha
124
+ beta = shared.opts.beta_dist_beta
125
+ timesteps = 1 - np.linspace(0, 1, n)
126
+ timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
127
+ sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
128
+ sigmas += [0.0]
129
+ return torch.FloatTensor(sigmas).to(device)
130
+
131
+ def beta_scheduler_old(n, sigma_min, sigma_max, inner_model, device):
132
+ # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
133
+ alpha = shared.opts.beta_dist_alpha_old
134
+ beta = shared.opts.beta_dist_beta_old
135
+ timesteps = 1 - np.linspace(0, 1, n)
136
+ timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
137
+ sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
138
+
139
+
140
+
141
+ sigmas += [0.0]
142
+ return torch.FloatTensor(sigmas).to(device)
143
+
144
+ schedulers = [
145
+ Scheduler('automatic', 'Automatic', None),
146
+ Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
147
+ Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
148
+ Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
149
+ Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
150
+ Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
151
+ Scheduler('kl_optimal', 'KL Optimal', kl_optimal),
152
+ Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas),
153
+ Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
154
+ Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
155
+ Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
156
+ Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),
157
+ Scheduler('beta_old', 'Beta Old', beta_scheduler_old, need_inner_model=True),
158
+ Scheduler('karras_exponential', 'Karras Exponential', simple_kes_scheduler),
159
+ Scheduler('karras_exponential_v1', 'Karras Exponential v1', simple_kes_scheduler_v1),
160
+ Scheduler('karras_exponential_v2_olds', 'Karras Exponential v2 old', simple_kes_scheduler_v2),
161
+ Scheduler('karras_exponential_v2', 'Karras Exponential v2', simple_kes_scheduler_v22),
162
+ ]
163
+
164
+ schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}