| """ | |
| Based on https://github.com/openai/guided-diffusion/blob/main/guided_diffusion | |
| and https://github.com/ehoogeboom/multinomial_diffusion | |
| """ | |
| import torch.nn.functional as F | |
| import torch | |
| import math | |
| import numpy as np | |
| from .utils import * | |
| """ | |
| Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 | |
| """ | |
| eps = 1e-8 | |
| def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): | |
| """ | |
| Get a pre-defined beta schedule for the given name. | |
| The beta schedule library consists of beta schedules which remain similar | |
| in the limit of num_diffusion_timesteps. | |
| Beta schedules may be added, but should not be removed or changed once | |
| they are committed to maintain backwards compatibility. | |
| """ | |
| if schedule_name == "linear": | |
| # Linear schedule from Ho et al, extended to work for any number of | |
| # diffusion steps. | |
| scale = 1000 / num_diffusion_timesteps | |
| beta_start = scale * 0.0001 | |
| beta_end = scale * 0.02 | |
| return np.linspace( | |
| beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 | |
| ) | |
| elif schedule_name == "cosine": | |
| return betas_for_alpha_bar( | |
| num_diffusion_timesteps, | |
| lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, | |
| ) | |
| else: | |
| raise NotImplementedError(f"unknown beta schedule: {schedule_name}") | |
| def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): | |
| """ | |
| Create a beta schedule that discretizes the given alpha_t_bar function, | |
| which defines the cumulative product of (1-beta) over time from t = [0,1]. | |
| :param num_diffusion_timesteps: the number of betas to produce. | |
| :param alpha_bar: a lambda that takes an argument t from 0 to 1 and | |
| produces the cumulative product of (1-beta) up to that | |
| part of the diffusion process. | |
| :param max_beta: the maximum beta to use; use values lower than 1 to | |
| prevent singularities. | |
| """ | |
| betas = [] | |
| for i in range(num_diffusion_timesteps): | |
| t1 = i / num_diffusion_timesteps | |
| t2 = (i + 1) / num_diffusion_timesteps | |
| betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) | |
| return np.array(betas) | |
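
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): for any schedule,
# beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), so the cumulative product
# of (1 - beta_i) recovers alpha_bar at the grid points. The helper below just
# prints a few values for the two named schedules and is never called by the
# library code.
# ---------------------------------------------------------------------------
def _beta_schedule_example(num_steps=10):
    for name in ("linear", "cosine"):
        betas = get_named_beta_schedule(name, num_steps)
        alpha_bar = np.cumprod(1.0 - betas)
        print(name, np.round(betas, 4), np.round(alpha_bar, 4))
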
class GaussianMultinomialDiffusion(torch.nn.Module):
    def __init__(
        self,
        num_classes: np.array,
        num_numerical_features: int,
        denoise_fn,
        num_timesteps=1000,
        gaussian_loss_type='mse',
        gaussian_parametrization='eps',
        multinomial_loss_type='vb_stochastic',
        parametrization='x0',
        scheduler='cosine',
        device=torch.device('cpu')
    ):
        super(GaussianMultinomialDiffusion, self).__init__()
        assert multinomial_loss_type in ('vb_stochastic', 'vb_all')
        assert parametrization in ('x0', 'direct')

        if multinomial_loss_type == 'vb_all':
            print('Computing the loss using the bound on _all_ timesteps.'
                  ' This is expensive both in terms of memory and computation.')

        self.num_numerical_features = num_numerical_features
        self.num_classes = num_classes  # a vector [K1, K2, ..., Km] of category counts
        self.num_classes_expanded = torch.from_numpy(
            np.concatenate([num_classes[i].repeat(num_classes[i]) for i in range(len(num_classes))])
        ).to(device)

        self.slices_for_classes = [np.arange(self.num_classes[0])]
        offsets = np.cumsum(self.num_classes)
        for i in range(1, len(offsets)):
            self.slices_for_classes.append(np.arange(offsets[i - 1], offsets[i]))
        self.offsets = torch.from_numpy(np.append([0], offsets)).to(device)

        self._denoise_fn = denoise_fn
        self.gaussian_loss_type = gaussian_loss_type
        self.gaussian_parametrization = gaussian_parametrization
        self.multinomial_loss_type = multinomial_loss_type
        self.num_timesteps = num_timesteps
        self.parametrization = parametrization
        self.scheduler = scheduler

        alphas = 1. - get_named_beta_schedule(scheduler, num_timesteps)
        alphas = torch.tensor(alphas.astype('float64'))
        betas = 1. - alphas

        log_alpha = np.log(alphas)
        log_cumprod_alpha = np.cumsum(log_alpha)

        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.tensor(np.append(1.0, alphas_cumprod[:-1]))
        alphas_cumprod_next = torch.tensor(np.append(alphas_cumprod[1:], 0.0))
        sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
        sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)

        # Gaussian diffusion
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_log_variance_clipped = torch.from_numpy(
            np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
        ).float().to(device)
        self.posterior_mean_coef1 = (
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        ).float().to(device)
        self.posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev)
            * np.sqrt(alphas.numpy())
            / (1.0 - alphas_cumprod)
        ).float().to(device)

        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        # Convert to float32 and register buffers.
        self.register_buffer('alphas', alphas.float().to(device))
        self.register_buffer('log_alpha', log_alpha.float().to(device))
        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float().to(device))
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float().to(device))
        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float().to(device))
        self.register_buffer('alphas_cumprod', alphas_cumprod.float().to(device))
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float().to(device))
        self.register_buffer('alphas_cumprod_next', alphas_cumprod_next.float().to(device))
        self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod.float().to(device))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod.float().to(device))
        self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod.float().to(device))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', sqrt_recipm1_alphas_cumprod.float().to(device))

        self.register_buffer('Lt_history', torch.zeros(num_timesteps))
        self.register_buffer('Lt_count', torch.zeros(num_timesteps))
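
    # Naming conventions used throughout (descriptive comment, added for
    # orientation): `q_*` methods describe the fixed forward (noising) process,
    # `p_*` methods the learned reverse process. Methods prefixed with
    # `gaussian_` act on the first `num_numerical_features` columns; the
    # unprefixed multinomial methods act on the remaining columns, which hold
    # the categorical features as (log) one-hot blocks whose widths are given
    # by `num_classes`.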
    # Gaussian part

    def gaussian_q_mean_variance(self, x_start, t):
        mean = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(
            self.log_1_min_cumprod_alpha, t, x_start.shape
        )
        return mean, variance, log_variance

    def gaussian_q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def gaussian_q_posterior_mean_variance(self, x_start, x_t, t):
        assert x_start.shape == x_t.shape
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def gaussian_p_mean_variance(
        self, model_output, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None
    ):
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)

        model_variance = torch.cat([self.posterior_variance[1].unsqueeze(0).to(x.device), (1. - self.alphas)[1:]], dim=0)
        # model_variance = self.posterior_variance.to(x.device)
        model_log_variance = torch.log(model_variance)

        model_variance = extract(model_variance, t, x.shape)
        model_log_variance = extract(model_log_variance, t, x.shape)

        if self.gaussian_parametrization == 'eps':
            pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
        elif self.gaussian_parametrization == 'x0':
            pred_xstart = model_output
        else:
            raise NotImplementedError

        model_mean, _, _ = self.gaussian_q_posterior_mean_variance(
            x_start=pred_xstart, x_t=x, t=t
        )

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        ), f'{model_mean.shape}, {model_log_variance.shape}, {pred_xstart.shape}, {x.shape}'

        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def _vb_terms_bpd(
        self, model_output, x_start, x_t, t, clip_denoised=False, model_kwargs=None
    ):
        true_mean, _, true_log_variance_clipped = self.gaussian_q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.gaussian_p_mean_variance(
            model_output, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = torch.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"], "out_mean": out["mean"], "true_mean": true_mean}

    def _prior_gaussian(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.gaussian_q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None):
        if model_kwargs is None:
            model_kwargs = {}

        terms = {}
        if self.gaussian_loss_type == 'mse':
            terms["loss"] = mean_flat((noise - model_out) ** 2)
        elif self.gaussian_loss_type == 'kl':
            terms["loss"] = self._vb_terms_bpd(
                model_output=model_out,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]

        return terms['loss']

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean
    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = extract(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, t, **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        # The class defines no `q_posterior_mean_variance`; the Gaussian
        # posterior helper above is the intended call here.
        out["mean"], _, _ = self.gaussian_q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out
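
    # Ancestral sampling step for the numerical columns (descriptive comment):
    #   x_{t-1} = mu_theta(x_t, t) + sigma_t * z,  z ~ N(0, I),
    # with the noise term masked out at t == 0 so the final step returns the
    # predicted mean directly.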
    def gaussian_p_sample(
        self,
        model_out,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        model_kwargs=None,
        cond_fn=None
    ):
        out = self.gaussian_p_mean_variance(
            model_out,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = torch.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0

        if cond_fn is not None:
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs
            )

        sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
    # Multinomial part

    def multinomial_kl(self, log_prob1, log_prob2):
        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
        return kl

    def q_pred_one_timestep(self, log_x_t, t):
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - torch.log(self.num_classes_expanded)
        )
        return log_probs

    def q_pred(self, log_x_start, t):
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - torch.log(self.num_classes_expanded)
        )
        return log_probs

    def predict_start(self, model_out, log_x_t, t, out_dict):
        # model_out = self._denoise_fn(x_t, t.to(x_t.device), **out_dict)

        assert model_out.size(0) == log_x_t.size(0)
        assert model_out.size(1) == self.num_classes.sum(), f'{model_out.size()}'

        log_pred = torch.empty_like(model_out)
        for ix in self.slices_for_classes:
            log_pred[:, ix] = F.log_softmax(model_out[:, ix], dim=1)
        return log_pred

    def q_posterior(self, log_x_start, log_x_t, t):
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        # EV_log_qxt_x0 = self.q_pred(log_x_start, t)
        # print('sum exp', EV_log_qxt_x0.exp().sum(1).mean())
        # assert False
        # log_qxt_x0 = (log_x_t.exp() * EV_log_qxt_x0).sum(dim=1)

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.to(log_x_start.device).view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0.to(torch.float32))
        # unnormed_logprobs = log_EV_qxtmin_x0 + log q(x_t | x_{t-1}).
        # Note that q_pred_one_timestep is evaluated at x_t, _not_ at x_{t-1},
        # which is how the formula is usually written. This works because the
        # one-step kernel alpha_t * delta(x_t, x_{t-1}) + (1 - alpha_t) / K is
        # symmetric in x_t and x_{t-1}, so plugging in the one-hot of x_t
        # yields q(x_t | x_{t-1} = k) for every candidate class k.
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs \
            - sliced_logsumexp(unnormed_logprobs, self.offsets)

        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, model_out, log_x, t, out_dict):
        if self.parametrization == 'x0':
            log_x_recon = self.predict_start(model_out, log_x, t=t, out_dict=out_dict)
            log_model_pred = self.q_posterior(
                log_x_start=log_x_recon, log_x_t=log_x, t=t)
        elif self.parametrization == 'direct':
            log_model_pred = self.predict_start(model_out, log_x, t=t, out_dict=out_dict)
        else:
            raise ValueError
        return log_model_pred

    def p_sample(self, model_out, log_x, t, out_dict):
        model_log_prob = self.p_pred(model_out, log_x=log_x, t=t, out_dict=out_dict)
        out = self.log_sample_categorical(model_log_prob)
        return out
    # The three helpers below are carried over from the image-diffusion code
    # this file is based on (see the module docstring). They call p_sample and
    # q_sample with image-style arguments that no longer match the signatures
    # above, and the tabular sampling entry points further down (`sample`,
    # `sample_ddim`, `sample_all`) do not use them.
    def p_sample_loop(self, shape, out_dict):
        device = self.log_alpha.device

        b = shape[0]
        # start with random normal image.
        img = torch.randn(shape, device=device)

        for i in reversed(range(1, self.num_timesteps)):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), out_dict)
        return img

    def _sample(self, image_size, out_dict, batch_size=16):
        return self.p_sample_loop((batch_size, 3, image_size, image_size), out_dict)

    def interpolate(self, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)
        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in reversed(range(0, t)):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
        return img
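
    # The next method samples one category per feature with the Gumbel-max
    # trick: adding Gumbel(0, 1) noise to the (log-)probabilities and taking
    # the argmax of each block draws from the corresponding categorical
    # distribution. The result is re-encoded as log one-hot so it can flow
    # back through the log-space transition formulas.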
    def log_sample_categorical(self, logits):
        full_sample = []
        for i in range(len(self.num_classes)):
            one_class_logits = logits[:, self.slices_for_classes[i]]
            uniform = torch.rand_like(one_class_logits)
            gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
            sample = (gumbel_noise + one_class_logits).argmax(dim=1)
            full_sample.append(sample.unsqueeze(1))
        full_sample = torch.cat(full_sample, dim=1)
        log_sample = index_to_log_onehot(full_sample, self.num_classes)
        return log_sample

    def q_sample(self, log_x_start, t):
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        log_sample = self.log_sample_categorical(log_EV_qxt_x0)
        return log_sample

    def nll(self, log_x_start, out_dict):
        b = log_x_start.size(0)
        device = log_x_start.device
        loss = 0
        for t in range(0, self.num_timesteps):
            t_array = (torch.ones(b, device=device) * t).long()

            kl = self.compute_Lt(
                log_x_start=log_x_start,
                log_x_t=self.q_sample(log_x_start=log_x_start, t=t_array),
                t=t_array,
                out_dict=out_dict)

            loss += kl

        loss += self.kl_prior(log_x_start)

        return loss

    def kl_prior(self, log_x_start):
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        log_half_prob = -torch.log(self.num_classes_expanded * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)

    def compute_Lt(self, model_out, log_x_start, log_x_t, t, out_dict, detach_mean=False):
        log_true_prob = self.q_posterior(
            log_x_start=log_x_start, log_x_t=log_x_t, t=t)
        log_model_prob = self.p_pred(model_out, log_x=log_x_t, t=t, out_dict=out_dict)

        if detach_mean:
            log_model_prob = log_model_prob.detach()

        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        decoder_nll = -log_categorical(log_x_start, log_model_prob)
        decoder_nll = sum_except_batch(decoder_nll)

        mask = (t == torch.zeros_like(t)).float()
        loss = mask * decoder_nll + (1. - mask) * kl

        return loss
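
    # With method='importance', the next method draws timesteps proportionally
    # to the square root of the running per-timestep loss history (Lt_history)
    # and returns the sampling probabilities pt so the loss can be importance-
    # weighted. Until every timestep has been visited more than 10 times it
    # falls back to uniform sampling.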
    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = (Lt_sqrt / Lt_sqrt.sum()).to(device)

            t = torch.multinomial(pt_all, num_samples=b, replacement=True).to(device)
            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt
        else:
            raise ValueError
    def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt, out_dict):
        if self.multinomial_loss_type == 'vb_stochastic':
            kl = self.compute_Lt(
                model_out, log_x_start, log_x_t, t, out_dict
            )
            kl_prior = self.kl_prior(log_x_start)

            # Upweigh loss term of the kl
            vb_loss = kl / pt + kl_prior

            return vb_loss

        elif self.multinomial_loss_type == 'vb_all':
            # Expensive, don't do it ;).
            # DEPRECATED
            return -self.nll(log_x_start)
        else:
            raise ValueError()
    def log_prob(self, x, out_dict):
        b, device = x.size(0), x.device
        if self.training:
            return self._multinomial_loss(x, out_dict)
        else:
            log_x_start = index_to_log_onehot(x, self.num_classes)

            t, pt = self.sample_time(b, device, 'importance')

            kl = self.compute_Lt(
                log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t, out_dict)
            kl_prior = self.kl_prior(log_x_start)

            # Upweigh loss term of the kl
            loss = kl / pt + kl_prior

            return -loss
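
    # mixed_loss below is the main training objective: an MSE (or VLB) term on
    # the numerical columns plus a multinomial VLB term on the categorical
    # columns, the latter divided by the number of categorical features
    # (len(self.num_classes)) so the two terms stay on a roughly comparable
    # scale.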
    def mixed_loss(self, x, out_dict):
        b = x.shape[0]
        device = x.device
        t, pt = self.sample_time(b, device, 'uniform')

        x_num = x[:, :self.num_numerical_features]
        x_cat = x[:, self.num_numerical_features:]

        x_num_t = x_num
        log_x_cat_t = x_cat
        if x_num.shape[1] > 0:
            noise = torch.randn_like(x_num)
            x_num_t = self.gaussian_q_sample(x_num, t, noise=noise)
        if x_cat.shape[1] > 0:
            log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes)
            log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t)

        x_in = torch.cat([x_num_t, log_x_cat_t], dim=1)

        model_out = self._denoise_fn(
            x_in,
            t,
            **out_dict
        )

        model_out_num = model_out[:, :self.num_numerical_features]
        model_out_cat = model_out[:, self.num_numerical_features:]

        loss_multi = torch.zeros((1,)).float()
        loss_gauss = torch.zeros((1,)).float()
        if x_cat.shape[1] > 0:
            loss_multi = self._multinomial_loss(model_out_cat, log_x_cat, log_x_cat_t, t, pt, out_dict) / len(self.num_classes)

        if x_num.shape[1] > 0:
            loss_gauss = self._gaussian_loss(model_out_num, x_num, x_num_t, t, noise)

        # loss_multi = torch.where(out_dict['y'] == 1, loss_multi, 2 * loss_multi)
        # loss_gauss = torch.where(out_dict['y'] == 1, loss_gauss, 2 * loss_gauss)

        return loss_multi.mean(), loss_gauss.mean()
    def mixed_elbo(self, x0, out_dict):
        b = x0.size(0)
        device = x0.device

        x_num = x0[:, :self.num_numerical_features]
        x_cat = x0[:, self.num_numerical_features:]
        has_cat = x_cat.shape[1] > 0
        if has_cat:
            log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes).to(device)

        gaussian_loss = []
        xstart_mse = []
        mse = []
        mu_mse = []
        out_mean = []
        true_mean = []
        multinomial_loss = []
        for t in range(self.num_timesteps):
            t_array = (torch.ones(b, device=device) * t).long()
            noise = torch.randn_like(x_num)

            x_num_t = self.gaussian_q_sample(x_start=x_num, t=t_array, noise=noise)
            if has_cat:
                log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t_array)
            else:
                log_x_cat_t = x_cat

            model_out = self._denoise_fn(
                torch.cat([x_num_t, log_x_cat_t], dim=1),
                t_array,
                **out_dict
            )

            model_out_num = model_out[:, :self.num_numerical_features]
            model_out_cat = model_out[:, self.num_numerical_features:]

            kl = torch.tensor([0.0])
            if has_cat:
                kl = self.compute_Lt(
                    model_out=model_out_cat,
                    log_x_start=log_x_cat,
                    log_x_t=log_x_cat_t,
                    t=t_array,
                    out_dict=out_dict
                )

            out = self._vb_terms_bpd(
                model_out_num,
                x_start=x_num,
                x_t=x_num_t,
                t=t_array,
                clip_denoised=False
            )

            multinomial_loss.append(kl)
            gaussian_loss.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_num) ** 2))
            # mu_mse.append(mean_flat(out["mean_mse"]))
            out_mean.append(mean_flat(out["out_mean"]))
            true_mean.append(mean_flat(out["true_mean"]))

            eps = self._predict_eps_from_xstart(x_num_t, t_array, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        gaussian_loss = torch.stack(gaussian_loss, dim=1)
        multinomial_loss = torch.stack(multinomial_loss, dim=1)
        xstart_mse = torch.stack(xstart_mse, dim=1)
        mse = torch.stack(mse, dim=1)
        # mu_mse = torch.stack(mu_mse, dim=1)
        out_mean = torch.stack(out_mean, dim=1)
        true_mean = torch.stack(true_mean, dim=1)

        prior_gauss = self._prior_gaussian(x_num)

        prior_multin = torch.tensor([0.0])
        if has_cat:
            prior_multin = self.kl_prior(log_x_cat)

        total_gauss = gaussian_loss.sum(dim=1) + prior_gauss
        total_multin = multinomial_loss.sum(dim=1) + prior_multin
        return {
            "total_gaussian": total_gauss,
            "total_multinomial": total_multin,
            "losses_gaussian": gaussian_loss,
            "losses_multinimial": multinomial_loss,
            "xstart_mse": xstart_mse,
            "mse": mse,
            # "mu_mse": mu_mse
            "out_mean": out_mean,
            "true_mean": true_mean
        }
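
    # The next method implements the DDIM update of Song et al. (2020) for the
    # numerical columns:
    #   x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat
    #             + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * eps_hat
    #             + sigma_t * z,
    # which is deterministic for eta == 0 (the default used by sample_ddim).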
    def gaussian_ddim_step(
        self,
        model_out_num,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        eta=0.0,
        model_kwargs=None,
        cond_fn=None
    ):
        out = self.gaussian_p_mean_variance(
            model_out_num,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=None,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = extract(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )

        noise = torch.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
            + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise

        return sample
    def gaussian_ddim_sample(
        self,
        noise,
        T,
        out_dict,
        eta=0.0,
        model_kwargs=None,
        cond_fn=None
    ):
        x = noise
        b = x.shape[0]
        device = x.device
        for t in reversed(range(T)):
            print(f'Sample timestep {t:4d}', end='\r')
            t_array = (torch.ones(b, device=device) * t).long()
            out_num = self._denoise_fn(x, t_array, **out_dict)
            x = self.gaussian_ddim_step(
                out_num,
                x,
                t_array,
                model_kwargs=model_kwargs,
                cond_fn=cond_fn
            )
        print()
        return x

    def gaussian_ddim_reverse_step(
        self,
        model_out_num,
        x,
        t,
        clip_denoised=False,
        eta=0.0
    ):
        assert eta == 0.0, "Eta must be zero."
        out = self.gaussian_p_mean_variance(
            model_out_num,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=None,
            model_kwargs=None,
        )

        eps = (
            extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = extract(self.alphas_cumprod_next, t, x.shape)

        mean_pred = (
            out["pred_xstart"] * torch.sqrt(alpha_bar_next)
            + torch.sqrt(1 - alpha_bar_next) * eps
        )

        return mean_pred

    def gaussian_ddim_reverse_sample(
        self,
        x,
        T,
        out_dict,
    ):
        b = x.shape[0]
        device = x.device
        for t in range(T):
            print(f'Reverse timestep {t:4d}', end='\r')
            t_array = (torch.ones(b, device=device) * t).long()
            out_num = self._denoise_fn(x, t_array, **out_dict)
            x = self.gaussian_ddim_reverse_step(
                out_num,
                x,
                t_array,
                eta=0.0
            )
        print()

        return x

    def multinomial_ddim_step(
        self,
        model_out_cat,
        log_x_t,
        t,
        out_dict,
        eta=0.0
    ):
        # not ddim, essentially
        log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t, t=t, out_dict=out_dict)

        alpha_bar = extract(self.alphas_cumprod, t, log_x_t.shape)
        alpha_bar_prev = extract(self.alphas_cumprod_prev, t, log_x_t.shape)
        sigma = (
            eta
            * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )

        coef1 = sigma
        coef2 = alpha_bar_prev - sigma * alpha_bar
        coef3 = 1 - coef1 - coef2

        log_ps = torch.stack([
            torch.log(coef1) + log_x_t,
            torch.log(coef2) + log_x0,
            torch.log(coef3) - torch.log(self.num_classes_expanded)
        ], dim=2)

        log_prob = torch.logsumexp(log_ps, dim=2)

        out = self.log_sample_categorical(log_prob)

        return out

    def sample_ddim(self, num_samples, y_dist, model_kwargs=None, cond_fn=None):
        b = num_samples
        device = self.log_alpha.device
        z_norm = torch.randn((b, self.num_numerical_features), device=device)

        has_cat = self.num_classes[0] != 0
        log_z = torch.zeros((b, 0), device=device).float()
        if has_cat:
            uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device)
            log_z = self.log_sample_categorical(uniform_logits)

        y = torch.multinomial(
            y_dist,
            num_samples=b,
            replacement=True
        )
        out_dict = {'y': y.long().to(device)}
        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')
            t = torch.full((b,), i, device=device, dtype=torch.long)
            model_out = self._denoise_fn(
                torch.cat([z_norm, log_z], dim=1).float(),
                t,
                **out_dict
            )
            model_out_num = model_out[:, :self.num_numerical_features]
            model_out_cat = model_out[:, self.num_numerical_features:]
            z_norm = self.gaussian_ddim_step(
                model_out_num,
                z_norm,
                t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                cond_fn=cond_fn
            )
            if has_cat:
                log_z = self.multinomial_ddim_step(model_out_cat, log_z, t, out_dict)

        print()
        z_ohe = torch.exp(log_z).round()
        z_cat = log_z
        if has_cat:
            z_cat = ohe_to_categories(z_ohe, self.num_classes)
        sample = torch.cat([z_norm, z_cat], dim=1).cpu()
        return sample, out_dict

    def conditional_sample(self, ys, model_kwargs=None, cond_fn=None):
        device = self.log_alpha.device
        b = len(ys)
        z_norm = torch.randn((b, self.num_numerical_features), device=device)

        has_cat = self.num_classes[0] != 0
        log_z = torch.zeros((b, 0), device=device).float()

        out_dict = {'y': ys.long().to(device)}
        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')
            t = torch.full((b,), i, device=device, dtype=torch.long)
            model_out = self._denoise_fn(
                torch.cat([z_norm, log_z], dim=1).float(),
                t,
                **out_dict
            )
            model_out_num = model_out[:, :self.num_numerical_features]
            model_out_cat = model_out[:, self.num_numerical_features:]
            z_norm = self.gaussian_p_sample(
                model_out_num,
                z_norm,
                t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                cond_fn=cond_fn
            )['sample']
            if has_cat:
                log_z = self.p_sample(model_out_cat, log_z, t, out_dict)

        print()
        z_ohe = torch.exp(log_z).round()
        z_cat = log_z
        if has_cat:
            z_cat = ohe_to_categories(z_ohe, self.num_classes)
        sample = torch.cat([z_norm, z_cat], dim=1).cpu()
        return sample, out_dict

    def sample(self, num_samples, y_dist, model_kwargs=None, cond_fn=None):
        b = num_samples
        device = self.log_alpha.device
        z_norm = torch.randn((b, self.num_numerical_features), device=device)

        has_cat = self.num_classes[0] != 0
        log_z = torch.zeros((b, 0), device=device).float()
        if has_cat:
            uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device)
            log_z = self.log_sample_categorical(uniform_logits)

        y = torch.multinomial(
            y_dist,
            num_samples=b,
            replacement=True
        )
        out_dict = {'y': y.long().to(device)}
        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')
            t = torch.full((b,), i, device=device, dtype=torch.long)
            model_out = self._denoise_fn(
                torch.cat([z_norm, log_z], dim=1).float(),
                t,
                **out_dict
            )
            model_out_num = model_out[:, :self.num_numerical_features]
            model_out_cat = model_out[:, self.num_numerical_features:]
            z_norm = self.gaussian_p_sample(
                model_out_num,
                z_norm,
                t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                cond_fn=cond_fn
            )['sample']
            if has_cat:
                log_z = self.p_sample(model_out_cat, log_z, t, out_dict)

        print()
        z_ohe = torch.exp(log_z).round()
        z_cat = log_z
        if has_cat:
            z_cat = ohe_to_categories(z_ohe, self.num_classes)
        sample = torch.cat([z_norm, z_cat], dim=1).cpu()
        return sample, out_dict

    def sample_all(self, num_samples, batch_size, y_dist, ddim=False, model_kwargs=None, cond_fn=None):
        if ddim:
            print('Sample using DDIM.')
            sample_fn = self.sample_ddim
        else:
            sample_fn = self.sample

        b = batch_size

        all_y = []
        all_samples = []
        num_generated = 0
        while num_generated < num_samples:
            sample, out_dict = sample_fn(b, y_dist, model_kwargs=model_kwargs, cond_fn=cond_fn)
            mask_nan = torch.any(sample.isnan(), dim=1)
            sample = sample[~mask_nan]
            out_dict['y'] = out_dict['y'][~mask_nan]

            all_samples.append(sample)
            all_y.append(out_dict['y'].cpu())
            if sample.shape[0] != b:
                raise FoundNANsError
            num_generated += sample.shape[0]

        x_gen = torch.cat(all_samples, dim=0)[:num_samples]
        y_gen = torch.cat(all_y, dim=0)[:num_samples]

        return x_gen, y_gen
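

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes the
# helpers imported from .utils behave as in the code above and that the
# denoiser follows the calling convention used by mixed_loss/sample:
# denoise_fn(x, t, y=...) returning a tensor with
# num_numerical_features + sum(num_classes) columns. The tiny MLP below is a
# hypothetical stand-in for the real denoising network.
# ---------------------------------------------------------------------------
def _usage_sketch(device=torch.device('cpu')):
    num_numerical = 2
    num_classes = np.array([3, 4])
    d_in = num_numerical + int(num_classes.sum())

    class _TinyDenoiser(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(d_in + 1, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, d_in),
            )

        def forward(self, x, t, y=None):
            # Crude conditioning: append the (scaled) timestep as a feature.
            t_feat = t.float().unsqueeze(1) / 100.0
            return self.net(torch.cat([x, t_feat], dim=1))

    diffusion = GaussianMultinomialDiffusion(
        num_classes=num_classes,
        num_numerical_features=num_numerical,
        denoise_fn=_TinyDenoiser(),
        num_timesteps=100,
        device=device,
    ).to(device)

    # A batch of 8 rows: 2 numerical columns followed by 2 categorical columns.
    x = torch.cat([
        torch.randn(8, num_numerical),
        torch.randint(0, 3, (8, 1)).float(),
        torch.randint(0, 4, (8, 1)).float(),
    ], dim=1).to(device)
    out_dict = {'y': torch.zeros(8, dtype=torch.long, device=device)}

    # One training step on the mixed objective ...
    loss_multi, loss_gauss = diffusion.mixed_loss(x, out_dict)
    (loss_multi + loss_gauss).backward()

    # ... and ancestral sampling of 4 synthetic rows with a uniform y prior.
    samples, _ = diffusion.sample_all(4, batch_size=4, y_dist=torch.tensor([0.5, 0.5]))
    return samples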