| import torch |
|
|
| from lib_es.compat import to_d |
|
|
| from tqdm.auto import trange |
|
|
| from lib_es.utils import sampler_metadata |
|
|
|
|
@sampler_metadata("Euler Negative")
@torch.no_grad()
def sample_euler_negative(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Run the "Euler Negative" sampling loop over a sigma schedule.

    For each consecutive pair ``(sigmas[i], sigmas[i + 1])`` the loop:

    1. Picks ``gamma``: ``max(s_churn / (len(sigmas) - 1), 2**0.5 - 1)`` when
       ``s_tmin <= sigmas[i] <= s_tmax``, otherwise ``0.0``. Because ``max``
       is used, ``gamma`` is at least ``2**0.5 - 1`` whenever the sigma is in
       range, even with ``s_churn=0``.
    2. Inflates the current sigma to ``sigma_hat = sigmas[i] * (gamma + 1)``
       and, when ``gamma > 0``, *subtracts* scaled Gaussian noise from ``x``.
    3. Calls ``model(x, sigma_hat * ones, **extra_args)`` and turns the result
       into a step direction with ``to_d``.
    4. Applies the update ``x + d * dt``, except on steps where
       ``i // 2 == 1`` (i.e. ``i`` is 2 or 3) and the next sigma is positive,
       where ``x`` is negated first: ``-x - d * dt``.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``.
        x: Starting tensor; ``x.shape[0]`` sets the length of the sigma
            vector handed to ``model``.
        sigmas: Indexable sigma schedule; ``len(sigmas) - 1`` steps are run.
        extra_args: Extra keyword arguments forwarded to ``model``; ``None``
            is treated as an empty dict.
        callback: Optional callable that receives a dict with keys ``"x"``,
            ``"i"``, ``"sigma"``, ``"sigma_hat"`` and ``"denoised"`` once per
            step, before ``x`` is updated.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_churn: Numerator of the per-step churn term in ``gamma``.
        s_tmin: Lower bound of the sigma range in which ``gamma`` is non-zero.
        s_tmax: Upper bound of the sigma range in which ``gamma`` is non-zero.
        s_noise: Multiplier applied to the sampled noise.

    Returns:
        The tensor ``x`` after the final step.
    """
    if extra_args is None:
        extra_args = {}

    # One entry per batch item, used to broadcast the scalar sigma for `model`.
    batch_ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1

    for i in trange(n_steps, disable=disable):
        sigma_cur = sigmas[i]
        sigma_next = sigmas[i + 1]

        if s_tmin <= sigma_cur <= s_tmax:
            gamma = max(s_churn / n_steps, 2**0.5 - 1)
        else:
            gamma = 0.0

        # Drawn on every step, even when gamma == 0 and the noise is unused;
        # moving this under the branch below would change RNG consumption.
        eps = torch.randn_like(x) * s_noise

        sigma_hat = sigma_cur * (gamma + 1)
        dt = sigma_next - sigma_hat

        if gamma > 0:
            # Noise is subtracted here rather than added.
            noise_scale = (sigma_hat**2 - sigma_cur**2) ** 0.5
            x = x - eps * noise_scale

        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)

        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma_cur,
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )

        step = d * dt
        # `i // 2 == 1` holds only for i == 2 and i == 3: on those steps (and
        # only while the next sigma is still positive) the sample is negated.
        if sigma_next > 0 and i // 2 == 1:
            x = -x - step
        else:
            x = x + step

    return x
|
|