File size: 15,714 Bytes
fabd6c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 |
import torch
from torch import no_grad, FloatTensor
from tqdm import tqdm
from itertools import pairwise
from typing import Protocol, Optional, Dict, Any, TypedDict, NamedTuple, Union, List
import math
from tqdm.auto import trange
# copied from kdiffusion/sampling.py and utils.py
def default_noise_sampler(x):
    """Build a noise sampler producing standard-normal noise shaped like ``x``.

    The returned callable takes ``(sigma, sigma_next)`` to match the k-diffusion
    noise-sampler calling convention, but ignores both arguments.
    """
    def sample_noise(sigma, sigma_next):
        return torch.randn_like(x)

    return sample_noise
def append_dims(x, target_dims):
    """Append trailing singleton dimensions to ``x`` until it has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        index = (...,) + (None,) * missing
        return x[index]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
def to_d(x, sigma, denoised):
    """Convert a denoiser prediction into the Karras ODE derivative.

    Computes ``(x - denoised) / sigma``, where ``sigma`` is broadcast against
    ``x`` by appending trailing singleton dimensions.
    """
    sigma_b = append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_b
class DenoiserModel(Protocol):
    """Structural type for a denoiser callable.

    Called with a noised input ``x``, its noise level ``t`` and optional extra
    positional/keyword arguments. The samplers in this module use the returned
    tensor as the denoised prediction.
    """
    def __call__(self, x: FloatTensor, t: FloatTensor, *args, **kwargs) -> FloatTensor: ...
class RefinedExpCallbackPayload(TypedDict):
    """Progress payload handed to a ``RefinedExpCallback``.

    ``sample_refined_exp_s`` builds every payload with ``denoised`` and
    ``denoised2`` as well, so both keys are declared here. Callbacks can then
    rely on them, and type checkers accept the construction.
    """
    # latents at the start of the reported step
    x: FloatTensor
    # index of the reported step
    i: int
    # noise level of the step
    sigma: FloatTensor
    # noise level after the optional noise injection (sigma * (1 + ita))
    sigma_hat: FloatTensor
    # first denoised prediction of the step
    denoised: FloatTensor
    # second denoised prediction reported for the step
    denoised2: FloatTensor
class RefinedExpCallback(Protocol):
    """Structural type for a progress callback: receives one payload per reported step, returns nothing."""
    def __call__(self, payload: RefinedExpCallbackPayload) -> None: ...
class NoiseSampler(Protocol):
    """Structural type for a noise source mapping a tensor to a noise tensor (``torch.randn_like`` fits)."""
    def __call__(self, x: FloatTensor) -> FloatTensor: ...
class StepOutput(NamedTuple):
    """Result of one second-order RES step (``_refined_exp_sosu_step``)."""
    # latents after the step
    x_next: FloatTensor
    # model prediction at the starting point
    denoised: FloatTensor
    # model prediction at the intermediate point
    denoised2: FloatTensor
    # final-update increment, carried forward as momentum state
    vel: FloatTensor
    # intermediate-update increment, carried forward as momentum state
    vel_2: FloatTensor
def _gamma(
n: int,
) -> int:
"""
https://en.wikipedia.org/wiki/Gamma_function
for every positive integer n,
Γ(n) = (n-1)!
"""
return math.factorial(n-1)
def _incomplete_gamma(
s: int,
x: float,
gamma_s: Optional[int] = None
) -> float:
"""
https://en.wikipedia.org/wiki/Incomplete_gamma_function#Special_values
if s is a positive integer,
Γ(s, x) = (s-1)!*∑{k=0..s-1}(x^k/k!)
"""
if gamma_s is None:
gamma_s = _gamma(s)
sum_: float = 0
# {k=0..s-1} inclusive
for k in range(s):
numerator: float = x**k
denom: int = math.factorial(k)
quotient: float = numerator/denom
sum_ += quotient
incomplete_gamma_: float = sum_ * math.exp(-x) * gamma_s
return incomplete_gamma_
# by Katherine Crowson
def _phi_1(neg_h: FloatTensor):
return torch.nan_to_num(torch.expm1(neg_h) / neg_h, nan=1.0)
# by Katherine Crowson
def _phi_2(neg_h: FloatTensor):
return torch.nan_to_num((torch.expm1(neg_h) - neg_h) / neg_h**2, nan=0.5)
# by Katherine Crowson
def _phi_3(neg_h: FloatTensor):
return torch.nan_to_num((torch.expm1(neg_h) - neg_h - neg_h**2 / 2) / neg_h**3, nan=1 / 6)
def _phi(
    neg_h: float,
    j: int,
):
    """General ϕj(-h) for any positive integer j (Lemma 1, https://arxiv.org/abs/2308.02157).

    ϕj(-h) = 1/h^j * ∫{0..h} e^(τ-h) * τ^(j-1)/(j-1)! dτ
           = e^(-h) * (-h)^(-j) * (1 - Γ(j, -h)/Γ(j))

    Γ(j) is the gamma function and Γ(j, x) the upper incomplete gamma function,
    both evaluated for integer j via the helpers above. For j in {1, 2, 3},
    Kat's _phi_1/_phi_2/_phi_3 give the same quantity with fewer operations.
    Unlike those helpers, h == 0 is not special-cased here.
    Requires j > 0.
    """
    assert j > 0
    gamma_j = _gamma(j)
    gamma_ratio = _incomplete_gamma(j, neg_h, gamma_s=gamma_j) / gamma_j
    return math.exp(neg_h) * neg_h**-j * (1 - gamma_ratio)
class RESDECoeffsSecondOrder(NamedTuple):
    """Coefficients of the second-order RES step, as produced by ``_de_second_order``."""
    # weight of the first prediction when forming the intermediate point
    a2_1: float
    # weight of the first prediction in the final update
    b1: float
    # weight of the intermediate prediction in the final update
    b2: float
def _de_second_order(
    h: float,
    c2: float,
    simple_phi_calc = False,
) -> RESDECoeffsSecondOrder:
    """
    Second-order RES coefficients (Table 3, https://arxiv.org/abs/2308.02157).

    With ϕi,j := ϕi,j(-h) = ϕi(-cj*h) and ϕi := ϕi(-h):
      a2_1 = c2 * ϕ1,2 = c2 * ϕ1(-c2*h)
      b1   = ϕ1 - ϕ2/c2
      b2   = ϕ2/c2

    simple_phi_calc: True uses Kat's closed forms (_phi_1/_phi_2); False uses
    the general-j _phi. They are mathematically equivalent but can differ
    slightly in floating point.
    """
    if simple_phi_calc:
        phi1_c2 = _phi_1(-c2*h)
        phi1 = _phi_1(-h)
        phi2 = _phi_2(-h)
    else:
        phi1_c2 = _phi(j=1, neg_h=-c2*h)
        phi1 = _phi(j=1, neg_h=-h)
        phi2 = _phi(j=2, neg_h=-h)
    b2 = phi2/c2
    return RESDECoeffsSecondOrder(
        a2_1=c2 * phi1_c2,
        b1=phi1 - b2,
        b2=b2,
    )
def _refined_exp_sosu_step(
    model: DenoiserModel,
    x: FloatTensor,
    sigma: FloatTensor,
    sigma_next: FloatTensor,
    c2 = 0.5,
    extra_args: Optional[Dict[str, Any]] = None,
    pbar: Optional[tqdm] = None,
    simple_phi_calc = False,
    momentum = 0.0,
    vel = None,
    vel_2 = None,
    time = None
) -> StepOutput:
    """
    Algorithm 1 "RES Second order Single Update Step with c2"
    https://arxiv.org/abs/2308.02157

    Parameters:
      model (`DenoiserModel`): a k-diffusion wrapped denoiser model (e.g. a subclass of DiscreteEpsDDPMDenoiser)
      x (`FloatTensor`): noised latents at noise level `sigma`
      sigma (`FloatTensor`): current noise level; `sigma.repeat(x.size(0))` is passed to the model, so this is expected to hold a single element
      sigma_next (`FloatTensor`): noise level to step to (its log is taken, so it must be > 0)
      c2 (`float`, *optional*, defaults to .5): partial step size for solving ODE. .5 = midpoint method
      extra_args (`Dict[str, Any]`, *optional*, defaults to `None`): kwargs passed to every model call; `None` means no extra kwargs
      pbar (`tqdm`, *optional*, defaults to `None`): progress bar updated once after both model calls
      simple_phi_calc (`bool`, *optional*, defaults to `False`): True = calculate phi_i,j(-h) via simplified formulae specific to j={1,2}. False = use the general solution that works for any j. Mathematically equivalent, but results may differ numerically.
      momentum (`float`, *optional*, defaults to 0.): weight for blending the previous step's increment into this step's. 0 keeps the plain RES update.
      vel, vel_2 (*optional*): the final and intermediate increments returned by the previous step (`None` on the first step)
      time (*optional*): timescale for the momentum blend; `None` is treated as 1.0

    Returns:
      StepOutput(x_next, denoised, denoised2, vel, vel_2). `vel`/`vel_2` are this step's
      increments, to be fed back into the next call.
    """
    if extra_args is None:
        extra_args = {}
    # `time=None` previously reached `None + offset` inside the blend and raised TypeError
    # as soon as a previous velocity existed; fall back to the blend's default timescale.
    timescale = 1.0 if time is None else time

    def momentum_func(diff, velocity, timescale=1.0, offset=-momentum / 2.0):
        # Blend the current increment `diff` with the previous step's increment `velocity`.
        if velocity is None:
            return diff
        weight = momentum * (timescale + offset)
        return weight * velocity + (1 - weight) * diff

    # lambda = -log(sigma)
    lam_next, lam = (s.log().neg() for s in (sigma_next, sigma))
    # type hints aren't strictly true regarding float vs FloatTensor.
    # everything gets promoted to `FloatTensor` after interacting with `sigma: FloatTensor`.
    # float is used to indicate variables which are scalars.
    h: float = lam_next - lam
    a2_1, b1, b2 = _de_second_order(h=h, c2=c2, simple_phi_calc=simple_phi_calc)

    denoised: FloatTensor = model(x, sigma.repeat(x.size(0)), **extra_args)

    # first stage: move a fraction c2 of the step to the intermediate point
    c2_h: float = c2*h
    diff_2 = momentum_func(a2_1*h*denoised, vel_2, timescale)
    vel_2 = diff_2
    x_2: FloatTensor = math.exp(-c2_h)*x + diff_2
    lam_2: float = lam + c2_h
    sigma_2: float = lam_2.neg().exp()

    denoised2: FloatTensor = model(x_2, sigma_2.repeat(x_2.size(0)), **extra_args)
    if pbar is not None:
        pbar.update()

    # second stage: full step using both predictions
    diff = momentum_func(h*(b1*denoised + b2*denoised2), vel, timescale)
    vel = diff
    x_next: FloatTensor = math.exp(-h)*x + diff
    return StepOutput(
        x_next=x_next,
        denoised=denoised,
        denoised2=denoised2,
        vel=vel,
        vel_2=vel_2,
    )
@no_grad()
def sample_refined_exp_s(
    model: FloatTensor,
    x: FloatTensor,
    sigmas: FloatTensor,
    denoise_to_zero: bool = True,
    extra_args: Optional[Dict[str, Any]] = None,
    callback: Optional[RefinedExpCallback] = None,
    disable: Optional[bool] = None,
    ita: FloatTensor = torch.zeros((1,)),
    c2 = .5,
    noise_sampler: NoiseSampler = torch.randn_like,
    simple_phi_calc = False,
    momentum = 0.0,
):
    """
    Refined Exponential Solver (S).
    Algorithm 2 "RES Single-Step Sampler" with Algorithm 1 second-order step
    https://arxiv.org/abs/2308.02157

    The RES steps walk `sigmas[:-1]` pairwise, so the last entry of `sigmas`
    (presumably 0, k-diffusion style) is never a RES step target. After the
    loop, x is at noise level `sigmas[-2]`.

    Parameters:
      model (`DenoiserModel`): a k-diffusion wrapped denoiser model (e.g. a subclass of DiscreteEpsDDPMDenoiser)
      x (`FloatTensor`): noised latents, e.g. torch.randn((B, C, H, W)) * sigmas[0]
      sigmas (`FloatTensor`): 1-D noise schedule (ideally exponential), e.g. get_sigmas_exponential(n=25, sigma_min=model.sigma_min, sigma_max=model.sigma_max)
      denoise_to_zero (`bool`, *optional*, defaults to `True`): True = finish with one model call from `sigmas[-2]` and return its prediction (fully denoised). False = return the latents at `sigmas[-2]`.
      extra_args (`Dict[str, Any]`, *optional*, defaults to `None`): kwargs passed to every model call; `None` means no extra kwargs
      callback (`RefinedExpCallback`, *optional*, defaults to `None`): called once per step with a `RefinedExpCallbackPayload`, e.g. to preview intermediate results
      disable (`bool`, *optional*, defaults to `None`): forwarded to `tqdm` to hide the progress bar
      ita (`FloatTensor`, *optional*, defaults to 0.): degree of stochasticity η. Before each step, noise is added to raise the level from sigma to sigma * (1 + ita). The same tensor is applied at every step (it is not indexed per step), so it should hold a single element.
      c2 (`float`, *optional*, defaults to .5): partial step size for solving ODE. .5 = midpoint method
      noise_sampler (`NoiseSampler`, *optional*): NOTE: currently unused; noise is drawn with `torch.randn_like`. Kept for interface compatibility.
      simple_phi_calc (`bool`, *optional*, defaults to `False`): True = calculate phi_i,j(-h) via simplified formulae specific to j={1,2}. False = use the general solution that works for any j. Mathematically equivalent, but results may differ numerically.
      momentum (`float`, *optional*, defaults to 0.): momentum blend weight passed to each step; the per-step timescale is `sigmas[i] / sigmas.max()`.

    Returns:
      the latents at `sigmas[-2]`, or the final denoised prediction when `denoise_to_zero`.
    """
    if extra_args is None:
        extra_args = {}
    device = x.device
    ita = ita.to(device)
    sigmas = sigmas.to(device)
    sigma_max = sigmas.max()
    vel, vel_2 = None, None
    with tqdm(disable=disable, total=len(sigmas)-(1 if denoise_to_zero else 2)) as pbar:
        for i, (sigma, sigma_next) in enumerate(pairwise(sigmas[:-1].split(1))):
            # normalised noise level, used as the momentum timescale
            time = sigmas[i] / sigma_max
            eps = torch.randn_like(x).float()
            sigma_hat = sigma * (1 + ita)
            # raise the noise level from sigma to sigma_hat (no-op when ita == 0)
            x_hat = x + (sigma_hat ** 2 - sigma ** 2).sqrt() * eps
            x_next, denoised, denoised2, vel, vel_2 = _refined_exp_sosu_step(
                model,
                x_hat,
                sigma_hat,
                sigma_next,
                c2=c2,
                extra_args=extra_args,
                pbar=pbar,
                simple_phi_calc=simple_phi_calc,
                momentum=momentum,
                vel=vel,
                vel_2=vel_2,
                time=time,
            )
            if callback is not None:
                callback(RefinedExpCallbackPayload(
                    x=x,
                    i=i,
                    sigma=sigma,
                    sigma_hat=sigma_hat,
                    denoised=denoised,
                    denoised2=denoised2,
                ))
            x = x_next
        if denoise_to_zero:
            # x is now at sigmas[-2] (the last element of sigmas[:-1]). The final
            # prediction must start from that level, not from the previous step's
            # starting sigma. Slicing keeps the (1,) shape used inside the loop and
            # also works when the loop ran zero times.
            sigma_last = sigmas[-2:-1]
            eps = torch.randn_like(x).float()
            sigma_hat = sigma_last * (1 + ita)
            x_hat = x + (sigma_hat ** 2 - sigma_last ** 2).sqrt() * eps
            # x_hat carries noise level sigma_hat, so that is what the model is told
            x_next = model(x_hat, sigma_hat.repeat(x_hat.size(0)), **extra_args)
            pbar.update()
            if callback is not None:
                # the final step makes a single model call, so both predictions are its output
                callback(RefinedExpCallbackPayload(
                    x=x,
                    i=len(sigmas) - 2,
                    sigma=sigma_last,
                    sigma_hat=sigma_hat,
                    denoised=x_next,
                    denoised2=x_next,
                ))
            x = x_next
    return x
# Many thanks to Kat + Birch-San for this wonderful sampler implementation! https://github.com/Birch-san/sdxl-play/commits/res/
def sample_res_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler_type="gaussian", noise_sampler=None, denoise_to_zero=True, simple_phi_calc=False, c2=0.5, ita=torch.Tensor((0.0,)), momentum=0.0):
    """k-diffusion-style entry point for the single-step RES sampler (`sample_refined_exp_s`).

    `extra_args=None` (the default here) is normalised to `{}` before forwarding,
    because the model calls downstream unpack it with `**`, which rejects `None`.
    `noise_sampler_type` is accepted for interface compatibility and is not used.
    All other arguments are forwarded unchanged.
    """
    if extra_args is None:
        extra_args = {}
    return sample_refined_exp_s(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, noise_sampler=noise_sampler, denoise_to_zero=denoise_to_zero, simple_phi_calc=simple_phi_calc, c2=c2, ita=ita, momentum=momentum)
## modified from ReForge, original implementation ComfyUI
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfgpp=False):
    """RES second-order multistep sampler (https://arxiv.org/pdf/2308.02157), optionally with CFG++.

    The first step, and any step whose target `sigmas[i + 1]` is 0, is a
    first-order Euler step. Every other step combines this step's denoised
    prediction with the previous step's.

    Parameters:
        model: denoiser called as `model(x, sigma_hat * s_in, **extra_args)`; its
            output is used as the denoised prediction. When `cfgpp` is True it must
            also expose `need_last_noise_uncond`, `last_noise_uncond` and
            `inner_model.inner_model.forge_objects.unet.model_options`.
        x: starting latents.
        sigmas: noise schedule; step i goes from `sigmas[i]` to `sigmas[i + 1]`.
        extra_args: kwargs forwarded to every model call (`None` -> `{}`).
        callback: called once per step with a dict holding `x`, `i`, `sigma`,
            `sigma_hat` and `denoised`.
        disable: forwarded to `trange` to hide the progress bar.
        s_churn, s_tmin, s_tmax, s_noise: when `s_churn > 0` and
            `s_tmin <= sigmas[i] <= s_tmax`, Gaussian noise scaled by `s_noise`
            raises the level to `sigma_hat = sigmas[i] * (gamma + 1)` before the
            model call.
        noise_sampler: assigned below but never used afterwards; churn noise comes
            from `torch.randn_like`.
        cfgpp: enable the CFG++ variant (see inline notes).

    Returns:
        The latents after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    # NOTE(review): `seed` is read but never used below.
    seed = extra_args.get("seed", None)
    # NOTE(review): `noise_sampler` is never used after this assignment.
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # conversions between sigma and t = -log(sigma)
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    # phi1(t) = (e^t - 1)/t, phi2(t) = (phi1(t) - 1)/t; both are 0/0 (NaN) at t = 0,
    # which the nan_to_num calls on b1/b2 below turn into 0.
    phi1_fn = lambda t: torch.expm1(t) / t
    phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
    old_denoised = None
    sigmas = sigmas.to(x.device)
    if cfgpp:
        # Presumably makes the Forge model wrapper record `last_noise_uncond` on every call
        # and skip its cfg=1 shortcut (inferred from the attribute names; confirm).
        # NOTE(review): these settings are not restored after sampling.
        model.need_last_noise_uncond = True
        model.inner_model.inner_model.forge_objects.unet.model_options["disable_cfg1_optimization"] = True
    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]
        if gamma > 0:
            # add noise so x goes from level sigmas[i] to sigma_hat
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
        if sigmas[i + 1] == 0 or old_denoised is None:
            # Euler method
            if cfgpp:
                # CFG++ Euler: re-noise the guided prediction to sigmas[i + 1] using
                # `last_noise_uncond` (assumed to be the unconditional noise prediction; confirm).
                d = model.last_noise_uncond
                x = denoised + d * sigmas[i + 1]
            else:
                d = to_d(x, sigma_hat, denoised)
                dt = sigmas[i + 1] - sigma_hat
                x = x + d * dt
        else:
            # Second order multistep method in https://arxiv.org/pdf/2308.02157
            # NOTE(review): t uses sigmas[i], not sigma_hat, so churned steps integrate from the un-churned level.
            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1])
            h = t_next - t
            # position of the previous point relative to this step size; negative for a decreasing schedule
            c2 = (t_prev - t) / h
            phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
            b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
            b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
            if cfgpp:
                # If last_noise_uncond == (x - uncond_denoised) / sigma_hat, this equals
                # x + (denoised - uncond_denoised) (assumption; confirm against the wrapper).
                d = model.last_noise_uncond
                x = denoised + d * sigma_hat
            # sigma_fn(t_next) / sigma_fn(t) == sigmas[i + 1] / sigmas[i] == exp(-h)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised)
        # keep this step's prediction for the next multistep update
        old_denoised = denoised
    return x
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
    """RES second-order multistep sampler, plain CFG (``res_multistep`` with ``cfgpp=False``)."""
    options = dict(
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_churn=s_churn,
        s_tmin=s_tmin,
        s_tmax=s_tmax,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
    )
    return res_multistep(model, x, sigmas, cfgpp=False, **options)
@torch.no_grad()
def sample_res_multistep_cfgpp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
    """RES second-order multistep sampler with CFG++ (``res_multistep`` with ``cfgpp=True``)."""
    options = dict(
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_churn=s_churn,
        s_tmin=s_tmin,
        s_tmax=s_tmax,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
    )
    return res_multistep(model, x, sigmas, cfgpp=True, **options)
|