# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333 (2ba4412)
| """ | |
| GaussianDiffusion wraps operators for denoising diffusion models, including the | |
| diffusion and denoising processes, as well as the loss evaluation. | |
| """ | |
| import torch | |
| import torchsde | |
| import random | |
| from tqdm.auto import trange | |
| __all__ = ['GaussianDiffusion'] | |
| def _i(tensor, t, x): | |
| """ | |
| Index tensor using t and format the output according to x. | |
| """ | |
| shape = (x.size(0), ) + (1, ) * (x.ndim - 1) | |
| return tensor[t.to(tensor.device)].view(shape).to(x.device) | |
class BatchedBrownianTree:
    """
    A wrapper around torchsde.BrownianTree that enables batches of entropy.

    If ``seed`` is a list with one entry per batch item, one BrownianTree is
    created per item (each with its own entropy); otherwise a single shared
    tree is used for the whole batch.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # EAFP: len() raises TypeError for a plain int seed, which routes
            # us to the non-batched branch below.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(
            t0, w0, t1, entropy=s, **kwargs
        ) for s in seed]

    @staticmethod
    def sort(a, b):
        # FIX: must be a @staticmethod. It is invoked as self.sort(t0, t1);
        # without the decorator, `self` would bind to `a` and `a < b` would
        # compare the instance against a timestep, raising TypeError.
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        # Flip the increment's sign if the query or construction interval
        # was reversed.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
    """
    A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        lo = self.transform(torch.as_tensor(sigma_min))
        hi = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, lo, hi, seed)

    def __call__(self, sigma, sigma_next):
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment to unit variance per unit time.
        interval_scale = (t1 - t0).abs().sqrt()
        return self.tree(t0, t1) / interval_scale
def get_scalings(sigma):
    """
    Karras-style preconditioning scalings for noise level ``sigma``.

    Returns:
        tuple: ``(c_out, c_in)`` where ``c_out = -sigma`` and
        ``c_in = 1 / sqrt(sigma ** 2 + 1)``.
    """
    c_out = -sigma
    c_in = 1 / (sigma ** 2 + 1. ** 2) ** 0.5
    return c_out, c_in
def sample_dpmpp_2m_sde(
    noise,
    model,
    sigmas,
    eta=1.,
    s_noise=1.,
    solver_type='midpoint',
    show_progress=True
):
    """
    DPM-Solver++ (2M) SDE.

    Args:
        noise (Tensor): Initial Gaussian noise sample.
        model (callable): Denoiser, called as ``model(x, sigma)``; expected to
            return the denoised estimate x0 at that noise level.
        sigmas (Tensor): Decreasing noise schedule; may start at inf and is
            expected to end at 0.
        eta (float): Strength of the stochastic (SDE) term.
        s_noise (float): Extra scale on the injected noise.
        solver_type (str): Second-order correction style, 'heun' or 'midpoint'.
        show_progress (bool): Whether to show a tqdm progress bar.

    Returns:
        Tensor: The final sample after the last step.
    """
    assert solver_type in {'heun', 'midpoint'}
    x = noise * sigmas[0]
    # Valid sigma interval for the Brownian tree: exclude the 0 and inf
    # endpoints of the schedule.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float('inf')].max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    old_denoised = None  # previous step's x0 estimate, for 2nd-order correction
    h_last = None        # previous step size in log-sigma time
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        if sigmas[i] == float('inf'):
            # Euler method
            denoised = model(noise, sigmas[i])
            x = denoised + sigmas[i + 1] * noise
        else:
            _, c_in = get_scalings(sigmas[i])
            denoised = model(x * c_in, sigmas[i])
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE
                # Work in negative log-sigma time: t = -log(sigma).
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h
                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
                    (-h - eta_h).expm1().neg() * denoised
                if old_denoised is not None:
                    # Second-order correction using the previous x0 estimate.
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
                            (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * \
                            (1 / r) * (denoised - old_denoised)
                # Inject fresh correlated noise (the stochastic term of the SDE).
                x = x + noise_sampler(
                    sigmas[i],
                    sigmas[i + 1]
                ) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
            old_denoised = denoised
            h_last = h
    return x
class GaussianDiffusion(object):
    """
    Wraps the diffusion (noising) and denoising operators of a denoising
    diffusion model defined by a discrete schedule of noise coefficients
    ``sigmas`` (with signal coefficients ``alphas = sqrt(1 - sigmas ** 2)``).
    """

    def __init__(self, sigmas, prediction_type='eps'):
        # sigmas: 1-D tensor of per-timestep noise coefficients.
        # prediction_type: quantity the network predicts ('x0', 'eps' or 'v').
        assert prediction_type in {'x0', 'eps', 'v'}
        self.sigmas = sigmas.float()  # noise coefficients
        self.alphas = torch.sqrt(1 - sigmas ** 2).float()  # signal coefficients
        self.num_timesteps = len(sigmas)
        self.prediction_type = prediction_type

    def diffuse(self, x0, t, noise=None):
        """
        Add Gaussian noise to signal x0 according to:
        q(x_t | x_0) = N(x_t | alpha_t x_0, sigma_t^2 I).
        """
        noise = torch.randn_like(x0) if noise is None else noise
        xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
        return xt

    def denoise(
        self,
        xt,
        t,
        s,
        model,
        model_kwargs={},
        guide_scale=None,
        guide_rescale=None,
        clamp=None,
        percentile=None
    ):
        r"""
        Apply one step of denoising from the posterior distribution q(x_s | x_t, x0).
        Since x0 is not available, estimate the denoising results using the learned
        distribution p(x_s | x_t, \hat{x}_0 == f(x_t)).

        Returns:
            tuple: (mu, var, log_var, x0, eps) — posterior mean/variance, the
            (optionally range-restricted) x0 estimate, and the matching eps.
        """
        # Default to a single-step transition t -> t-1.
        s = t - 1 if s is None else s
        # hyperparams
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        # s < 0 means "fully denoised": signal coefficient 1, noise 0.
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s ** 2)
        # precompute variables
        betas = 1 - (alphas / alphas_s) ** 2
        coef1 = betas * alphas_s / sigmas ** 2
        coef2 = (alphas * sigmas_s ** 2) / (alphas_s * sigmas ** 2)
        var = betas * (sigmas_s / sigmas) ** 2
        log_var = torch.log(var).clamp_(-20, 20)
        # prediction
        if guide_scale is None:
            assert isinstance(model_kwargs, dict)
            out = model(xt, t=t, **model_kwargs)
        else:
            # classifier-free guidance (arXiv:2207.12598)
            # model_kwargs[0]: conditional kwargs
            # model_kwargs[1]: non-conditional kwargs
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, t=t, **model_kwargs[0])
            if guide_scale == 1.:
                out = y_out
            else:
                u_out = model(xt, t=t, **model_kwargs[1])
                out = u_out + guide_scale * (y_out - u_out)
                # rescale the output according to arXiv:2305.08891
                if guide_rescale is not None:
                    assert guide_rescale >= 0 and guide_rescale <= 1
                    # Match the guided output's per-sample std to the
                    # conditional output's std, blended by guide_rescale.
                    ratio = (y_out.flatten(1).std(dim=1) / (
                        out.flatten(1).std(dim=1) + 1e-12
                    )).view((-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
        # compute x0
        if self.prediction_type == 'x0':
            x0 = out
        elif self.prediction_type == 'eps':
            x0 = (xt - sigmas * out) / alphas
        elif self.prediction_type == 'v':
            x0 = alphas * xt - sigmas * out
        else:
            raise NotImplementedError(
                f'prediction_type {self.prediction_type} not implemented'
            )
        # restrict the range of x0
        if percentile is not None:
            # NOTE: percentile should only be used when data is within range [-1, 1]
            assert percentile > 0 and percentile <= 1
            # Dynamic thresholding: clamp to the per-sample |x0| quantile
            # (at least 1.0), then rescale into [-1, 1].
            # NOTE(review): this rebinds the `s` parameter, which is no longer
            # needed at this point — confirm that is intentional.
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
            s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)
        # recompute eps using the restricted x0
        eps = (xt - alphas * x0) / sigmas
        # compute mu (mean of posterior distribution) using the restricted x0
        mu = coef1 * x0 + coef2 * xt
        return mu, var, log_var, x0, eps

    def sample(
        self,
        noise,
        model,
        model_kwargs={},
        condition_fn=None,
        guide_scale=None,
        guide_rescale=None,
        clamp=None,
        percentile=None,
        solver='euler_a',
        steps=20,
        t_max=None,
        t_min=None,
        discretization=None,
        discard_penultimate_step=None,
        return_intermediate=None,
        show_progress=False,
        seed=-1,
        **kwargs
    ):
        """
        Draw a sample by running the chosen ODE/SDE solver from pure noise.

        Returns x0, or (x0, intermediates) when return_intermediate is set.
        NOTE(review): `condition_fn` is accepted but never used in this body —
        confirm against callers.
        """
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')
        # function of diffusion solver
        solver_fn = {
            # 'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }[solver]
        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        # NOTE(review): `seed` is computed here but never passed on to the
        # solver or noise sampler — confirm whether this is intentional.
        seed = seed if seed >= 0 else random.randint(0, 2 ** 31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            discard_penultimate_step = True if solver in (
                'dpm2',
                'dpm2_ancestral',
                'dpmpp_2m_sde',
                'dpm2_karras',
                'dpm2_ancestral_karras',
                'dpmpp_2m_sde_karras'
            ) else False
        # function for denoising xt to get x0
        intermediates = []

        def model_fn(xt, sigma):
            # denoising: map the solver's sigma back to a (rounded) timestep
            # and take the x0 estimate (index -2 of denoise's return tuple).
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(
                xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
                percentile
            )[-2]
            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                intermediates.append(x0)
            return x0

        # get timesteps
        if isinstance(steps, int):
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min
            # discretize timesteps
            if discretization == 'leading':
                steps = torch.arange(
                    t_min, t_max + 1, (t_max - t_min + 1) / steps
                ).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps))
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented'
                )
            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(steps, dtype=torch.float32, device=noise.device)
        # get sigmas
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            # NOTE(review): `karras_schedule` is not defined in this chunk —
            # presumably provided elsewhere in the module.
            if sigmas[0] == float('inf'):
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.
                ).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas, sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.
                ).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
        # sampling
        x0 = solver_fn(
            noise,
            model_fn,
            sigmas,
            show_progress=show_progress,
            **kwargs
        )
        return (x0, intermediates) if return_intermediate is not None else x0

    def ddim_reverse_sample(
        self,
        xt,
        t,
        model,
        model_kwargs={},
        clamp=None,
        percentile=None,
        guide_scale=None,
        guide_rescale=None,
        ddim_timesteps=20,
        reverse_steps=600
    ):
        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
        """
        stride = reverse_steps // ddim_timesteps
        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0, eps = self.denoise(
            xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
            percentile
        )
        # derive variables: step *forward* in time by `stride`
        s = (t + stride).clamp(0, reverse_steps-1)
        # hyperparams
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s ** 2)
        # reverse sample: re-noise the x0 estimate at the later timestep
        mu = alphas_s * x0 + sigmas_s * eps
        return mu, x0

    def ddim_reverse_sample_loop(
        self,
        x0,
        model,
        model_kwargs={},
        clamp=None,
        percentile=None,
        guide_scale=None,
        guide_rescale=None,
        ddim_timesteps=20,
        reverse_steps=600
    ):
        """
        DDIM inversion: repeatedly apply ddim_reverse_sample to map a clean
        sample x0 forward to a noised latent, deterministically.
        """
        # prepare input
        b = x0.size(0)
        xt = x0
        # reconstruction steps
        steps = torch.arange(0, reverse_steps, reverse_steps // ddim_timesteps)
        for step in steps:
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, guide_rescale, ddim_timesteps, reverse_steps)
        return xt

    def _sigma_to_t(self, sigma):
        """
        Map a sigma value to a (fractional) timestep by linear interpolation
        in log-sigma space; inf maps to the last timestep. A 0-dim input is
        returned as a 1-element tensor.
        """
        if sigma == float('inf'):
            t = torch.full_like(sigma, len(self.sigmas) - 1)
        else:
            # Convert the schedule's sigmas (sigma_t of xt) to the solver's
            # sigma convention sigma / alpha, then interpolate in log space.
            log_sigmas = torch.sqrt(
                self.sigmas ** 2 / (1 - self.sigmas ** 2)
            ).log().to(sigma)
            log_sigma = sigma.log()
            dists = log_sigma - log_sigmas[:, None]
            # Index of the last schedule entry with log_sigma >= log_sigmas[i],
            # capped so high_idx stays in range.
            low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
                max=log_sigmas.shape[0] - 2
            )
            high_idx = low_idx + 1
            low, high = log_sigmas[low_idx], log_sigmas[high_idx]
            w = (low - log_sigma) / (low - high)
            w = w.clamp(0, 1)
            t = (1 - w) * low_idx + w * high_idx
            t = t.view(sigma.shape)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        return t

    def _t_to_sigma(self, t):
        """
        Inverse of _sigma_to_t: linearly interpolate log-sigma between the
        floor and ceil timesteps; nan/inf results (e.g. alpha == 0) map to inf.
        """
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigmas = torch.sqrt(self.sigmas ** 2 / (1 - self.sigmas ** 2)).log().to(t)
        log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
        log_sigma[torch.isnan(log_sigma) | torch.isinf(log_sigma)] = float('inf')
        return log_sigma.exp()

    def prev_step(self, model_out, t, xt, inference_steps=50):
        """
        One deterministic DDIM step backward in time (t -> t - stride),
        assuming `model_out` is a v-prediction.
        """
        prev_t = t - self.num_timesteps // inference_steps
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_prev = _i(self.alphas, prev_t.clamp(0), xt)
        # prev_t < 0 means fully denoised: signal coefficient 1, noise 0.
        alphas_prev[prev_t < 0] = 1.
        sigmas_prev = torch.sqrt(1 - alphas_prev ** 2)
        # v-prediction: x0 = alpha * xt - sigma * v
        x0 = alphas * xt - sigmas * model_out
        eps = (xt - alphas * x0) / sigmas
        prev_sample = alphas_prev * x0 + sigmas_prev * eps
        return prev_sample

    def next_step(self, model_out, t, xt, inference_steps=50):
        """
        One deterministic DDIM step forward in time (inversion), assuming
        `model_out` is a v-prediction.
        NOTE(review): `min(tensor, int)` below relies on `t` being a scalar
        (0-dim / single-element) tensor — confirm against callers.
        """
        t, next_t = min(t - self.num_timesteps // inference_steps, 999), t
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_next = _i(self.alphas, next_t.clamp(0), xt)
        alphas_next[next_t < 0] = 1.
        sigmas_next = torch.sqrt(1 - alphas_next ** 2)
        # v-prediction: x0 = alpha * xt - sigma * v
        x0 = alphas * xt - sigmas * model_out
        eps = (xt - alphas * x0) / sigmas
        next_sample = alphas_next * x0 + sigmas_next * eps
        return next_sample

    def get_noise_pred_single(self, xt, t, model, model_kwargs):
        """Run the model once (no guidance) and return its raw prediction."""
        assert isinstance(model_kwargs, dict)
        out = model(xt, t=t, **model_kwargs)
        return out