dikdimon's picture
Upload 58 files
97923d1 verified
import torch
from lib_es.compat import to_d
from tqdm.auto import trange
from lib_es.utils import sampler_metadata
@sampler_metadata("Euler Negative")
@torch.no_grad()
def sample_euler_negative(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Run an Euler sampler variant with churn and a sign flip on some steps.

    Each step does the following:

    1. Draw fresh Gaussian noise (scaled by ``s_noise``).
    2. Inflate the current sigma by ``(1 + gamma)``.
    3. When ``gamma > 0``, subtract noise scaled by
       ``sqrt(sigma_hat**2 - sigma**2)`` from ``x``.
    4. Query ``model`` at the inflated sigma and turn the prediction into a
       derivative via ``to_d``.
    5. Take an Euler step to the next sigma. On steps whose index is 2 or 3,
       provided the next sigma is positive, the stepped result is negated.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``,
            where ``sigma_batch`` is the current inflated sigma repeated once
            per entry of ``x``'s first dimension.
        x: Input tensor; its first dimension is treated as the batch size.
        sigmas: Indexable sequence of noise levels; ``len(sigmas) - 1`` steps
            are taken.
        extra_args: Extra keyword arguments forwarded to ``model``
            (``None`` means no extras).
        callback: Optional callable, called once per step with a dict holding
            ``"x"``, ``"i"``, ``"sigma"``, ``"sigma_hat"`` and ``"denoised"``.
        disable: Forwarded to ``trange`` to control the progress bar.
        s_churn: Churn amount. Inside the churn window,
            ``gamma = max(s_churn / steps, sqrt(2) - 1)``.
        s_tmin: Lower bound of the churn window (inclusive).
        s_tmax: Upper bound of the churn window (inclusive).
        s_noise: Multiplier applied to the drawn noise.

    Returns:
        The tensor after the final step.
    """
    extra_args = extra_args if extra_args is not None else {}
    batch_ones = x.new_ones([x.shape[0]])
    num_steps = len(sigmas) - 1

    # Lower bound on gamma. Because max() is used, churn stays active even
    # with the default s_churn == 0.
    # NOTE(review): k-diffusion's sample_euler caps gamma with min() at this
    # point; confirm max() is the intended behaviour for this variant.
    gamma_floor = 2**0.5 - 1

    for step in trange(num_steps, disable=disable):
        sigma = sigmas[step]
        sigma_next = sigmas[step + 1]

        if s_tmin <= sigma <= s_tmax:
            gamma = max(s_churn / num_steps, gamma_floor)
        else:
            gamma = 0.0

        # Noise is drawn on every step, even if it goes unused below. This
        # keeps the RNG stream advancing the same way regardless of gamma.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            # Churn: the scaled noise is subtracted from x rather than added.
            x = x - noise * (sigma_hat**2 - sigma**2) ** 0.5

        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        derivative = to_d(x, sigma_hat, denoised)

        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": step,
                    "sigma": sigma,
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )

        # Euler step from sigma_hat to the next sigma.
        step_size = sigma_next - sigma_hat
        # step // 2 == 1 is true only for step 2 and step 3.
        # NOTE(review): this may have been meant as step % 2 == 1 (every odd
        # step); confirm before relying on it.
        if sigma_next > 0 and step // 2 == 1:
            x = -x - derivative * step_size
        else:
            x = x + derivative * step_size

    return x