# Source: IRG/baselines/ClavaDDPM/tab_ddpm/gaussian_multinomial_diffsuion.py
# Committed by Zilong-Zhao in commit c4ac745 ("first commit").
"""
Based on https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
and https://github.com/ehoogeboom/multinomial_diffusion
"""
import torch.nn.functional as F
import torch
import math
import numpy as np
from .utils import *
"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
# Module-level small constant (1e-8). The methods in this file bind their own
# local ``eps`` names, so this global is not read by the code shown here; it may
# be imported by other modules -- TODO confirm before removing.
eps = 1.0e-8
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """Return a predefined beta schedule with ``num_diffusion_timesteps`` entries.

    Supported names:
      * ``"linear"`` -- the Ho et al. linear schedule (1e-4 .. 0.02 for 1000
        steps), with both endpoints scaled by ``1000 / num_diffusion_timesteps``
        so that it behaves similarly for other step counts.
      * ``"cosine"`` -- the cosine alpha-bar schedule, discretized with
        ``betas_for_alpha_bar``.

    Schedules may be added but should not be changed once committed, so that
    previously trained models remain reproducible.

    :raises NotImplementedError: if ``schedule_name`` is not a known schedule.
    """
    if schedule_name == "cosine":
        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    if schedule_name == "linear":
        factor = 1000 / num_diffusion_timesteps
        return np.linspace(
            factor * 0.0001,
            factor * 0.02,
            num_diffusion_timesteps,
            dtype=np.float64,
        )

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Discretize a continuous alpha-bar curve into a beta schedule.

    ``alpha_bar`` maps t in [0, 1] to the cumulative product of (1 - beta)
    up to t. Step ``i`` gets ``beta_i = 1 - alpha_bar((i+1)/n) / alpha_bar(i/n)``,
    capped at ``max_beta``.

    :param num_diffusion_timesteps: number of betas ``n`` to produce.
    :param alpha_bar: callable taking t in [0, 1].
    :param max_beta: upper bound on each beta; values below 1 avoid singularities.
    :return: 1-D numpy array of length ``n``.
    """
    n = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((step + 1) / n) / alpha_bar(step / n), max_beta)
        for step in range(n)
    ])
class GaussianMultinomialDiffusion(torch.nn.Module):
def __init__(
        self,
        num_classes: np.ndarray,
        num_numerical_features: int,
        denoise_fn,
        num_timesteps=1000,
        gaussian_loss_type='mse',
        gaussian_parametrization='eps',
        multinomial_loss_type='vb_stochastic',
        parametrization='x0',
        scheduler='cosine',
        device=torch.device('cpu')
    ):
    """Build a mixed Gaussian + multinomial diffusion model for tabular rows.

    The first ``num_numerical_features`` columns of a row are diffused with
    Gaussian noise. The remaining categorical columns are diffused on log
    one-hot vectors with a uniform-mixing (multinomial) kernel. Both parts
    share one beta schedule.

    :param num_classes: number of categories per categorical column,
        ``[K1, ..., Km]``. It is indexed with ``[0]`` below, so it must be
        non-empty. The sampling methods treat ``num_classes[0] == 0`` as
        "no categorical features".
    :param num_numerical_features: number of leading numerical columns.
    :param denoise_fn: network called as ``denoise_fn(x_in, t, **out_dict)``.
    :param num_timesteps: number of diffusion steps T.
    :param gaussian_loss_type: ``'mse'`` or ``'kl'`` (used by ``_gaussian_loss``).
    :param gaussian_parametrization: ``'eps'`` or ``'x0'`` (used by
        ``gaussian_p_mean_variance``).
    :param multinomial_loss_type: ``'vb_stochastic'`` or ``'vb_all'`` (asserted).
    :param parametrization: categorical parametrization, ``'x0'`` or
        ``'direct'`` (asserted).
    :param scheduler: schedule name passed to ``get_named_beta_schedule``.
    :param device: device the schedule tensors are moved to.
    """
    super(GaussianMultinomialDiffusion, self).__init__()

    assert multinomial_loss_type in ('vb_stochastic', 'vb_all')
    assert parametrization in ('x0', 'direct')

    if multinomial_loss_type == 'vb_all':
        print('Computing the loss using the bound on _all_ timesteps.'
              ' This is expensive both in terms of memory and computation.')

    self.num_numerical_features = num_numerical_features
    self.num_classes = num_classes # it as a vector [K1, K2, ..., Km]
    # One entry per one-hot column: K_i repeated K_i times. Used as the K in
    # the uniform 1/K term of the categorical transition.
    self.num_classes_expanded = torch.from_numpy(
        np.concatenate([num_classes[i].repeat(num_classes[i]) for i in range(len(num_classes))])
    ).to(device)

    # Column indices of each categorical feature's block in the one-hot layout.
    self.slices_for_classes = [np.arange(self.num_classes[0])]
    offsets = np.cumsum(self.num_classes)
    for i in range(1, len(offsets)):
        self.slices_for_classes.append(np.arange(offsets[i - 1], offsets[i]))
    # Block boundaries [0, K1, K1+K2, ...]; passed to sliced_logsumexp in q_posterior.
    self.offsets = torch.from_numpy(np.append([0], offsets)).to(device)

    self._denoise_fn = denoise_fn
    self.gaussian_loss_type = gaussian_loss_type
    self.gaussian_parametrization = gaussian_parametrization
    self.multinomial_loss_type = multinomial_loss_type
    self.num_timesteps = num_timesteps
    self.parametrization = parametrization
    self.scheduler = scheduler

    # alpha_t = 1 - beta_t, kept in float64 while deriving the schedule.
    alphas = 1. - get_named_beta_schedule(scheduler, num_timesteps)
    alphas = torch.tensor(alphas.astype('float64'))
    betas = 1. - alphas

    # NOTE(review): numpy functions (np.log, np.cumsum, np.cumprod, np.sqrt,
    # np.append) are applied to torch tensors below. The later .abs()/.float()
    # calls rely on most of these results still being torch tensors.
    log_alpha = np.log(alphas)
    log_cumprod_alpha = np.cumsum(log_alpha)

    # log(1 - alpha_t) and log(1 - alpha_bar_t); the asserts below check that
    # exp(log_alpha) + exp(log_1_min_alpha) == 1 (and likewise for cumprod).
    log_1_min_alpha = log_1_min_a(log_alpha)
    log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

    alphas_cumprod = np.cumprod(alphas, axis=0)
    # alpha_bar_{t-1} with alpha_bar_{-1} = 1, and alpha_bar_{t+1} with 0 past the end.
    alphas_cumprod_prev = torch.tensor(np.append(1.0, alphas_cumprod[:-1]))
    alphas_cumprod_next = torch.tensor(np.append(alphas_cumprod[1:], 0.0))
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
    sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
    sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)

    # Gaussian diffusion
    # Posterior q(x_{t-1} | x_t, x_0) variance and mean coefficients.
    # NOTE(review): posterior_variance stays a float64 CPU tensor. It and the
    # three tensors below are plain attributes rather than registered buffers,
    # so module.to(...) will not move them.
    self.posterior_variance = (
        betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    )
    # The posterior variance is 0 at t = 0 (alpha_bar_{-1} = 1), so index 0 is
    # replaced by index 1 before taking the log.
    self.posterior_log_variance_clipped = torch.from_numpy(
        np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
    ).float().to(device)
    self.posterior_mean_coef1 = (
        betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    ).float().to(device)
    self.posterior_mean_coef2 = (
        (1.0 - alphas_cumprod_prev)
        * np.sqrt(alphas.numpy())
        / (1.0 - alphas_cumprod)
    ).float().to(device)

    # Sanity checks on the log-space schedule. The third check is trivially
    # satisfied because log_cumprod_alpha is computed exactly this way above.
    assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
    assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
    assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5

    # Convert to float32 and register buffers.
    self.register_buffer('alphas', alphas.float().to(device))
    self.register_buffer('log_alpha', log_alpha.float().to(device))
    self.register_buffer('log_1_min_alpha', log_1_min_alpha.float().to(device))
    self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float().to(device))
    self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float().to(device))
    self.register_buffer('alphas_cumprod', alphas_cumprod.float().to(device))
    self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float().to(device))
    self.register_buffer('alphas_cumprod_next', alphas_cumprod_next.float().to(device))
    self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod.float().to(device))
    self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod.float().to(device))
    self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod.float().to(device))
    self.register_buffer('sqrt_recipm1_alphas_cumprod', sqrt_recipm1_alphas_cumprod.float().to(device))

    # Per-timestep loss statistics read by sample_time('importance'). Nothing
    # in the code shown here updates them, so importance sampling falls back
    # to uniform.
    self.register_buffer('Lt_history', torch.zeros(num_timesteps))
    self.register_buffer('Lt_count', torch.zeros(num_timesteps))
# Gaussian part
def gaussian_q_mean_variance(self, x_start, t):
    """Return the mean, variance and log-variance of q(x_t | x_0) (Gaussian part).

    mean = sqrt(alpha_bar_t) * x_0, variance = 1 - alpha_bar_t, and the
    log-variance is read from the precomputed ``log_1_min_cumprod_alpha`` buffer.
    Each coefficient is gathered at ``t`` and broadcast to ``x_start.shape``.
    """
    shape = x_start.shape
    signal_scale = extract(self.sqrt_alphas_cumprod, t, shape)
    var = extract(1.0 - self.alphas_cumprod, t, shape)
    log_var = extract(self.log_1_min_cumprod_alpha, t, shape)
    return signal_scale * x_start, var, log_var
def gaussian_q_sample(self, x_start, t, noise=None):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).

    :param noise: standard-normal noise shaped like ``x_start``; sampled with
        ``torch.randn_like`` when omitted.
    """
    noise = torch.randn_like(x_start) if noise is None else noise
    assert noise.shape == x_start.shape
    signal_coef = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_coef = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_coef * x_start + noise_coef * noise
def gaussian_q_posterior_mean_variance(self, x_start, x_t, t):
    """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0).

    mean = posterior_mean_coef1_t * x_0 + posterior_mean_coef2_t * x_t.
    All coefficients are gathered at ``t`` and broadcast to ``x_t.shape``.
    """
    assert x_start.shape == x_t.shape
    shape = x_t.shape
    coef_start = extract(self.posterior_mean_coef1, t, shape)
    coef_current = extract(self.posterior_mean_coef2, t, shape)
    mean = coef_start * x_start + coef_current * x_t
    variance = extract(self.posterior_variance, t, shape)
    log_variance = extract(self.posterior_log_variance_clipped, t, shape)
    batch = x_start.shape[0]
    assert mean.shape[0] == variance.shape[0] == log_variance.shape[0] == batch
    return mean, variance, log_variance
def gaussian_p_mean_variance(
    self, model_output, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None
):
    """Reverse-process Gaussian p(x_{t-1} | x_t) for the numerical features.

    The variance is fixed, not learned. Index 0 uses the posterior variance at
    t = 1, and every later index uses beta_t = 1 - alpha_t. ``model_output`` is
    read as the predicted noise (``'eps'``) or as the predicted x_0 (``'x0'``),
    depending on ``self.gaussian_parametrization``.

    ``clip_denoised``, ``denoised_fn`` and ``model_kwargs`` are accepted for
    interface compatibility but are not used.

    :return: dict with keys "mean", "variance", "log_variance", "pred_xstart".
    :raises NotImplementedError: for an unknown parametrization.
    """
    batch_size, _ = x.shape[:2]
    assert t.shape == (batch_size,)

    first_var = self.posterior_variance[1].unsqueeze(0).to(x.device)
    fixed_var = torch.cat([first_var, (1. - self.alphas)[1:]], dim=0)
    log_fixed_var = torch.log(fixed_var)
    variance = extract(fixed_var, t, x.shape)
    log_variance = extract(log_fixed_var, t, x.shape)

    if self.gaussian_parametrization == 'x0':
        pred_xstart = model_output
    elif self.gaussian_parametrization == 'eps':
        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
    else:
        raise NotImplementedError

    mean, _, _ = self.gaussian_q_posterior_mean_variance(
        x_start=pred_xstart, x_t=x, t=t
    )
    assert (
        mean.shape == log_variance.shape == pred_xstart.shape == x.shape
    ), f'{mean.shape}, {log_variance.shape}, {pred_xstart.shape}, {x.shape}'
    return {
        "mean": mean,
        "variance": variance,
        "log_variance": log_variance,
        "pred_xstart": pred_xstart,
    }
def _vb_terms_bpd(
    self, model_output, x_start, x_t, t, clip_denoised=False, model_kwargs=None
):
    """Per-example variational-bound term for the Gaussian part.

    For t == 0 this is the decoder negative log-likelihood of ``x_start``.
    For t > 0 it is KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)).
    Both are reduced with ``mean_flat`` and divided by ln 2, which converts
    them to bits if the helpers return nats.

    :return: dict with "output" (the per-example term), "pred_xstart",
        "out_mean" (model posterior mean) and "true_mean" (true posterior mean).
    """
    ln2 = np.log(2.0)
    true_mean, _, true_log_var = self.gaussian_q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t
    )
    pred = self.gaussian_p_mean_variance(
        model_output, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
    )
    kl_bits = mean_flat(
        normal_kl(true_mean, true_log_var, pred["mean"], pred["log_variance"])
    ) / ln2

    nll = -discretized_gaussian_log_likelihood(
        x_start, means=pred["mean"], log_scales=0.5 * pred["log_variance"]
    )
    assert nll.shape == x_start.shape
    nll_bits = mean_flat(nll) / ln2

    return {
        "output": torch.where((t == 0), nll_bits, kl_bits),
        "pred_xstart": pred["pred_xstart"],
        "out_mean": pred["mean"],
        "true_mean": true_mean,
    }
def _prior_gaussian(self, x_start):
    """Prior term KL(q(x_T | x_0) || N(0, I)) of the bound for the Gaussian part.

    It depends only on the forward (encoding) process, so it cannot be
    optimized. It is reported to complete the bound.

    :param x_start: the [N x C x ...] tensor of clean inputs.
    :return: [N] tensor of KL values, reduced with ``mean_flat`` and divided by ln 2.
    """
    final_t = torch.full(
        (x_start.shape[0],),
        self.num_timesteps - 1,
        device=x_start.device,
        dtype=torch.long,
    )
    mean_T, _, log_var_T = self.gaussian_q_mean_variance(x_start, final_t)
    kl = normal_kl(mean1=mean_T, logvar1=log_var_T, mean2=0.0, logvar2=0.0)
    return mean_flat(kl) / np.log(2.0)
def _gaussian_loss(self, model_out, x_start, x_t, t, noise, model_kwargs=None):
if model_kwargs is None:
model_kwargs = {}
terms = {}
if self.gaussian_loss_type == 'mse':
terms["loss"] = mean_flat((noise - model_out) ** 2)
elif self.gaussian_loss_type == 'kl':
terms["loss"] = self._vb_terms_bpd(
model_output=model_out,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)["output"]
return terms['loss']
def _predict_xstart_from_eps(self, x_t, t, eps):
    """Recover x_0 from x_t and predicted noise.

    x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
    """
    assert x_t.shape == eps.shape
    shape = x_t.shape
    scaled_xt = extract(self.sqrt_recip_alphas_cumprod, t, shape) * x_t
    scaled_eps = extract(self.sqrt_recipm1_alphas_cumprod, t, shape) * eps
    return scaled_xt - scaled_eps
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """Recover the noise implied by x_t and a predicted x_0.

    eps = (sqrt(1 / alpha_bar_t) * x_t - x_0) / sqrt(1 / alpha_bar_t - 1)
    """
    shape = x_t.shape
    numerator = extract(self.sqrt_recip_alphas_cumprod, t, shape) * x_t - pred_xstart
    denominator = extract(self.sqrt_recipm1_alphas_cumprod, t, shape)
    return numerator / denominator
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """Shift the reverse-step mean using the gradient of a conditional log-probability.

    ``cond_fn(x, t, **model_kwargs)`` should return grad_x log p(y | x). The
    new mean is ``mean + variance * gradient``, following the conditioning
    strategy of Sohl-Dickstein et al. (2015).

    ``model_kwargs`` defaults to an empty dict. Previously the ``None``
    default (reached from ``gaussian_p_sample``) made ``**model_kwargs`` raise
    TypeError.
    """
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, t, **model_kwargs)
    return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """Return the p_mean_variance output as if the model's score were guided by ``cond_fn``.

    The implied noise is shifted by ``-sqrt(1 - alpha_bar_t) * cond_fn(x, t)``.
    x_0 is re-derived from the shifted noise, and the mean is recomputed from
    the Gaussian posterior (Song et al., 2020).

    Fixes:
      * the posterior helper is ``gaussian_q_posterior_mean_variance``. The
        previous call to the nonexistent ``q_posterior_mean_variance`` raised
        AttributeError whenever guidance was used (e.g. from
        ``gaussian_ddim_step``);
      * ``model_kwargs`` defaults to an empty dict instead of crashing on
        ``**None``.
    """
    if model_kwargs is None:
        model_kwargs = {}
    alpha_bar = extract(self.alphas_cumprod, t, x.shape)
    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

    out = p_mean_var.copy()
    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
    out["mean"], _, _ = self.gaussian_q_posterior_mean_variance(
        x_start=out["pred_xstart"], x_t=x, t=t
    )
    return out
def gaussian_p_sample(
    self,
    model_out,
    x,
    t,
    clip_denoised=False,
    denoised_fn=None,
    model_kwargs=None,
    cond_fn=None
):
    """Draw x_{t-1} ~ p(x_{t-1} | x_t) for the numerical features.

    If ``cond_fn`` is given, the mean is guided with ``condition_mean``. Rows
    with t == 0 get the mean with no added noise.

    :return: dict with "sample" and "pred_xstart".
    """
    dist = self.gaussian_p_mean_variance(
        model_out,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = torch.randn_like(x)
    # 0 for rows at the final step (t == 0), 1 elsewhere, broadcast over features.
    keep_noise = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    if cond_fn is not None:
        dist["mean"] = self.condition_mean(
            cond_fn, dist, x, t, model_kwargs=model_kwargs
        )
    std = torch.exp(0.5 * dist["log_variance"])
    return {
        "sample": dist["mean"] + keep_noise * std * noise,
        "pred_xstart": dist["pred_xstart"],
    }
# Multinomial part
def multinomial_kl(self, log_prob1, log_prob2):
    """KL(p1 || p2) summed over dim 1, with both distributions given as log-probabilities."""
    p1 = torch.exp(log_prob1)
    return torch.sum(p1 * (log_prob1 - log_prob2), dim=1)
def q_pred_one_timestep(self, log_x_t, t):
    """Log-probabilities after one forward categorical step from ``log_x_t``.

    In probability space: alpha_t * x + (1 - alpha_t) / K for each one-hot
    column, where K is that column's feature cardinality.
    """
    shape = log_x_t.shape
    log_keep = extract(self.log_alpha, t, shape) + log_x_t
    log_uniform = (
        extract(self.log_1_min_alpha, t, shape)
        - torch.log(self.num_classes_expanded)
    )
    return log_add_exp(log_keep, log_uniform)
def q_pred(self, log_x_start, t):
    """Log-probabilities of q(x_t | x_0) for the categorical part.

    In probability space: alpha_bar_t * x_0 + (1 - alpha_bar_t) / K per column.
    """
    shape = log_x_start.shape
    log_keep = extract(self.log_cumprod_alpha, t, shape) + log_x_start
    log_uniform = (
        extract(self.log_1_min_cumprod_alpha, t, shape)
        - torch.log(self.num_classes_expanded)
    )
    return log_add_exp(log_keep, log_uniform)
def predict_start(self, model_out, log_x_t, t, out_dict):
    """Convert raw categorical logits into per-feature log-probabilities of x_0.

    Log-softmax is applied separately to each feature's block of columns
    (``self.slices_for_classes``). ``log_x_t`` is only used for a batch-size
    check; ``t`` and ``out_dict`` are kept for interface compatibility.
    """
    assert model_out.size(0) == log_x_t.size(0)
    assert model_out.size(1) == self.num_classes.sum(), f'{model_out.size()}'
    result = torch.empty_like(model_out)
    for cols in self.slices_for_classes:
        result[:, cols] = F.log_softmax(model_out[:, cols], dim=1)
    return result
def q_posterior(self, log_x_start, log_x_t, t):
    """Log-probabilities of q(x_{t-1} | x_t, x_0) for the categorical part.

    Uses q(x_{t-1} | x_t, x_0) proportional to q(x_t | x_{t-1}) * q(x_{t-1} | x_0),
    normalized per feature with ``sliced_logsumexp`` over ``self.offsets``.
    At t == 0 the factor q(x_{t-1} | x_0) is replaced by ``log_x_start``.
    """
    # t - 1 clamped at 0; the clamped rows are overwritten by the t == 0 branch below.
    t_prev = t - 1
    t_prev = torch.where(t_prev < 0, torch.zeros_like(t_prev), t_prev)
    log_prev_given_start = self.q_pred(log_x_start, t_prev)

    # Broadcast t over every non-batch dimension to build the t == 0 mask.
    trailing_dims = (1,) * (len(log_x_start.size()) - 1)
    t_full = t.to(log_x_start.device).view(-1, *trailing_dims) * torch.ones_like(log_x_start)
    log_prev_given_start = torch.where(
        t_full == 0, log_x_start, log_prev_given_start.to(torch.float32)
    )

    # The one-step factor is evaluated at x_t, not x_{t-1}. The kernel
    # alpha_t * I + (1 - alpha_t) / K is symmetric, so q(x_t | x_{t-1}) as a
    # function of x_{t-1} equals the one-step prediction from x_t.
    unnormalized = log_prev_given_start + self.q_pred_one_timestep(log_x_t, t)
    return unnormalized - sliced_logsumexp(unnormalized, self.offsets)
def p_pred(self, model_out, log_x, t, out_dict):
    """Model log-probabilities for x_{t-1} given x_t (categorical part).

    * ``'x0'``: predict x_0, then plug it into the true posterior ``q_posterior``.
    * ``'direct'``: use the per-feature log-softmax of the logits directly.

    :raises ValueError: for any other ``self.parametrization``.
    """
    if self.parametrization == 'direct':
        return self.predict_start(model_out, log_x, t=t, out_dict=out_dict)
    if self.parametrization == 'x0':
        log_x0_hat = self.predict_start(model_out, log_x, t=t, out_dict=out_dict)
        return self.q_posterior(log_x_start=log_x0_hat, log_x_t=log_x, t=t)
    raise ValueError
@torch.no_grad()
def p_sample(self, model_out, log_x, t, out_dict):
    """Sample x_{t-1} (as log one-hot) from the model's categorical reverse step."""
    log_probs = self.p_pred(model_out, log_x=log_x, t=t, out_dict=out_dict)
    return self.log_sample_categorical(log_probs)
@torch.no_grad()
def p_sample_loop(self, shape, out_dict):
    """Run the full categorical reverse chain from uniform noise (legacy helper).

    The previous body came from image diffusion. It started from Gaussian
    noise and called ``self.p_sample(img, t, out_dict)``, which does not match
    ``p_sample(model_out, log_x, t, out_dict)`` and so always raised TypeError.
    This version mirrors the categorical branch of ``sample``: start from a
    uniform categorical draw, then at each step T-1 ... 0 query the denoiser
    and sample x_{t-1}.

    The denoiser only receives the categorical log one-hot state, so this
    works only for models without numerical features.

    :param shape: only ``shape[0]`` (the batch size) is used. The state width
        is ``len(self.num_classes_expanded)``.
    :param out_dict: extra keyword arguments for the denoiser (e.g. ``{'y': ...}``).
    :return: log one-hot samples of shape ``(shape[0], sum(num_classes))``.
    :raises NotImplementedError: if the model has numerical features.
    """
    if self.num_numerical_features > 0:
        raise NotImplementedError(
            "p_sample_loop only supports purely categorical models; use sample()"
        )
    device = self.log_alpha.device
    b = shape[0]
    uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device)
    log_z = self.log_sample_categorical(uniform_logits)
    for i in reversed(range(0, self.num_timesteps)):
        t = torch.full((b,), i, device=device, dtype=torch.long)
        model_out = self._denoise_fn(log_z, t, **out_dict)
        log_z = self.p_sample(model_out, log_z, t, out_dict)
    return log_z
@torch.no_grad()
def _sample(self, image_size, out_dict, batch_size = 16):
return self.p_sample_loop((batch_size, 3, image_size, image_size), out_dict)
@torch.no_grad()
def interpolate(self, x1, x2, t=None, lam=0.5, out_dict=None):
    """Blend two categorical inputs in noised space, then run the reverse chain.

    Both log one-hot inputs are noised to step ``t`` (default T-1) with
    ``q_sample`` and linearly mixed as ``(1 - lam) * xt1 + lam * xt2``. The
    result is then denoised with ``p_sample``.

    Fix: ``p_sample`` needs the denoiser output (``p_sample(model_out, log_x,
    t, out_dict)``). The previous loop called it with two arguments and
    raised TypeError for any t > 0. ``out_dict`` (new, optional) is forwarded
    to the denoiser.

    NOTE(review): the denoiser receives only the categorical state, so this
    presumably requires ``num_numerical_features == 0``. The reverse loop
    starts at index t-1 even though the blend was produced at index t; confirm
    whether the first reverse step should use t.
    """
    if out_dict is None:
        out_dict = {}
    b, *_, device = *x1.shape, x1.device
    t = default(t, self.num_timesteps - 1)
    assert x1.shape == x2.shape

    t_batched = torch.stack([torch.tensor(t, device=device)] * b)
    xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
    img = (1 - lam) * xt1 + lam * xt2
    for i in reversed(range(0, t)):
        t_i = torch.full((b,), i, device=device, dtype=torch.long)
        model_out = self._denoise_fn(img, t_i, **out_dict)
        img = self.p_sample(model_out, img, t_i, out_dict)
    return img
def log_sample_categorical(self, logits):
    """Sample one category per feature with the Gumbel-max trick; return it as log one-hot.

    ``logits`` holds unnormalized log-probabilities for all one-hot columns.
    Each feature's block (``self.slices_for_classes``) is sampled independently.
    """
    def gumbel_argmax(block):
        # argmax(logits + Gumbel noise) is a sample from softmax(logits).
        u = torch.rand_like(block)
        gumbel = -torch.log(-torch.log(u + 1e-30) + 1e-30)
        return (gumbel + block).argmax(dim=1)

    picks = [
        gumbel_argmax(logits[:, self.slices_for_classes[k]]).unsqueeze(1)
        for k in range(len(self.num_classes))
    ]
    return index_to_log_onehot(torch.cat(picks, dim=1), self.num_classes)
def q_sample(self, log_x_start, t):
    """Sample x_t ~ q(x_t | x_0) for the categorical part, returned as log one-hot."""
    return self.log_sample_categorical(self.q_pred(log_x_start, t))
def nll(self, log_x_start, out_dict=None):
    """All-timestep categorical bound: sum over t of L_t, plus the KL prior.

    This is expensive: it makes one denoiser call per timestep. The previous
    version called ``compute_Lt`` without its required ``model_out``
    argument, so it always raised TypeError. The denoiser is now queried at
    every step. ``out_dict`` is now optional (defaults to ``{}``).

    NOTE(review): the denoiser receives only the categorical log one-hot
    state, so this is presumably only meaningful for purely categorical models.

    :return: per-example bound (positive; lower is better).
    """
    if out_dict is None:
        out_dict = {}
    b = log_x_start.size(0)
    device = log_x_start.device
    loss = 0
    for t in range(0, self.num_timesteps):
        t_array = (torch.ones(b, device=device) * t).long()
        log_x_t = self.q_sample(log_x_start=log_x_start, t=t_array)
        model_out = self._denoise_fn(log_x_t, t_array, **out_dict)
        kl = self.compute_Lt(
            model_out=model_out,
            log_x_start=log_x_start,
            log_x_t=log_x_t,
            t=t_array,
            out_dict=out_dict)
        loss += kl
    loss += self.kl_prior(log_x_start)
    return loss
def kl_prior(self, log_x_start):
    """KL(q(x_T | x_0) || uniform) for the categorical part, summed per example.

    The reference distribution assigns 1/K to each column of a K-class feature.
    """
    batch = log_x_start.size(0)
    last_t = (self.num_timesteps - 1) * torch.ones(batch, device=log_x_start.device).long()
    log_q_final = self.q_pred(log_x_start, t=last_t)
    log_uniform = -torch.log(self.num_classes_expanded * torch.ones_like(log_q_final))
    return sum_except_batch(self.multinomial_kl(log_q_final, log_uniform))
def compute_Lt(self, model_out, log_x_start, log_x_t, t, out_dict, detach_mean=False):
    """Categorical loss term L_t for each example.

    For t == 0 this is the decoder NLL -log p(x_0 | x_1). For t > 0 it is
    KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)). Both are summed over
    non-batch dimensions. ``detach_mean=True`` stops gradients through the
    model prediction.
    """
    log_true = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
    log_model = self.p_pred(model_out, log_x=log_x_t, t=t, out_dict=out_dict)
    if detach_mean:
        log_model = log_model.detach()

    kl = sum_except_batch(self.multinomial_kl(log_true, log_model))
    decoder_nll = sum_except_batch(-log_categorical(log_x_start, log_model))

    is_first_step = (t == torch.zeros_like(t)).float()
    return is_first_step * decoder_nll + (1. - is_first_step) * kl
def sample_time(self, b, device, method='uniform'):
    """Sample ``b`` timesteps and their sampling probabilities.

    * ``'uniform'``: t ~ U{0, ..., T-1}, with pt = 1 / T.
    * ``'importance'``: t is drawn with probability proportional to
      sqrt(Lt_history), and step 0 reuses step 1's weight. It falls back to
      ``'uniform'`` until every ``Lt_count`` entry exceeds 10.

    :return: ``(t, pt)`` tensors of length ``b``.
    :raises ValueError: for any other ``method``.
    """
    if method == 'uniform':
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        return t, torch.ones_like(t).float() / self.num_timesteps
    if method != 'importance':
        raise ValueError
    if not (self.Lt_count > 10).all():
        return self.sample_time(b, device, method='uniform')
    weights = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
    weights[0] = weights[1]  # Overwrite decoder term with L1.
    probs = (weights / weights.sum()).to(device)
    t = torch.multinomial(probs, num_samples=b, replacement=True).to(device)
    return t, probs.gather(dim=0, index=t)
def _multinomial_loss(self, model_out, log_x_start, log_x_t, t, pt, out_dict):
if self.multinomial_loss_type == 'vb_stochastic':
kl = self.compute_Lt(
model_out, log_x_start, log_x_t, t, out_dict
)
kl_prior = self.kl_prior(log_x_start)
# Upweigh loss term of the kl
vb_loss = kl / pt + kl_prior
return vb_loss
elif self.multinomial_loss_type == 'vb_all':
# Expensive, dont do it ;).
# DEPRECATED
return -self.nll(log_x_start)
else:
raise ValueError()
def log_prob(self, x, out_dict):
    """Stochastic estimate of a lower bound on log p(x) for categorical indices ``x``.

    A timestep is sampled (uniformly in training mode, by importance in eval
    mode). ``x`` is noised to x_t and the denoiser is queried. The result is
    the negated bound ``-(L_t / p(t) + KL prior)``; higher is better.

    Fix: previously the training branch called
    ``_multinomial_loss(x, out_dict)`` and the eval branch called
    ``compute_Lt`` without ``model_out``. Both always raised TypeError. The
    training branch now also returns a log-probability (negated loss), which
    is consistent with the eval branch and the method name.

    NOTE(review): the denoiser receives only the categorical state, so this is
    presumably only meaningful for purely categorical models.
    """
    b, device = x.size(0), x.device
    log_x_start = index_to_log_onehot(x, self.num_classes)
    method = 'uniform' if self.training else 'importance'
    t, pt = self.sample_time(b, device, method)
    log_x_t = self.q_sample(log_x_start=log_x_start, t=t)
    model_out = self._denoise_fn(log_x_t, t, **out_dict)
    if self.training:
        return -self._multinomial_loss(model_out, log_x_start, log_x_t, t, pt, out_dict)
    kl = self.compute_Lt(model_out, log_x_start, log_x_t, t, out_dict)
    kl_prior = self.kl_prior(log_x_start)
    # Upweigh loss term of the kl
    loss = kl / pt + kl_prior
    return -loss
def mixed_loss(self, x, out_dict):
    """Training losses for a mixed batch ``x = [numerical | categorical indices]``.

    One timestep per row is sampled uniformly. The numerical part is noised
    with Gaussian diffusion and the categorical part with multinomial
    diffusion. The denoiser runs once on the concatenated noisy input.

    :return: ``(multinomial_loss, gaussian_loss)`` as scalar means. The
        multinomial loss is divided by the number of categorical features. A
        part with no columns contributes ``torch.zeros((1,))``, created on the
        CPU.
    """
    n_num = self.num_numerical_features
    t, pt = self.sample_time(x.shape[0], x.device, 'uniform')
    x_num, x_cat = x[:, :n_num], x[:, n_num:]
    has_num = x_num.shape[1] > 0
    has_cat = x_cat.shape[1] > 0

    x_num_t, log_x_cat_t = x_num, x_cat
    if has_num:
        noise = torch.randn_like(x_num)
        x_num_t = self.gaussian_q_sample(x_num, t, noise=noise)
    if has_cat:
        log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes)
        log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t)

    model_out = self._denoise_fn(torch.cat([x_num_t, log_x_cat_t], dim=1), t, **out_dict)
    out_num, out_cat = model_out[:, :n_num], model_out[:, n_num:]

    loss_multi = torch.zeros((1,)).float()
    loss_gauss = torch.zeros((1,)).float()
    if has_cat:
        loss_multi = self._multinomial_loss(
            out_cat, log_x_cat, log_x_cat_t, t, pt, out_dict
        ) / len(self.num_classes)
    if has_num:
        loss_gauss = self._gaussian_loss(out_num, x_num, x_num_t, t, noise)
    return loss_multi.mean(), loss_gauss.mean()
@torch.no_grad()
def mixed_elbo(self, x0, out_dict):
    """Evaluate per-timestep variational-bound terms for a batch ``x0`` (no gradients).

    For every t in 0 .. T-1 the numerical part is noised with Gaussian
    diffusion and the categorical part with multinomial diffusion. The
    denoiser runs once per step. The method records the categorical L_t
    (``compute_Lt``), the Gaussian bound term (``_vb_terms_bpd``) and some
    diagnostics.

    :return: dict with
        "total_gaussian" / "total_multinomial": per-example sum over t plus the prior KL;
        "losses_gaussian" / "losses_multinimial": per-step terms stacked along dim 1;
        "xstart_mse": per-step MSE of the predicted x_0;
        "mse": per-step MSE between the implied and the true noise;
        "out_mean" / "true_mean": per-step means of the model / true posterior means.
      Without categorical columns the multinomial entries are built from
      ``torch.tensor([0.0])`` placeholders on the CPU.
    """
    b = x0.size(0)
    device = x0.device

    x_num = x0[:, :self.num_numerical_features]
    x_cat = x0[:, self.num_numerical_features:]
    has_cat = x_cat.shape[1] > 0
    if has_cat:
        log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes).to(device)

    # Per-timestep accumulators; stacked along dim 1 after the loop.
    gaussian_loss = []
    xstart_mse = []
    mse = []
    mu_mse = []
    out_mean = []
    true_mean = []
    multinomial_loss = []
    for t in range(self.num_timesteps):
        t_array = (torch.ones(b, device=device) * t).long()
        # The same noise is reused below to measure the eps-prediction error.
        noise = torch.randn_like(x_num)

        x_num_t = self.gaussian_q_sample(x_start=x_num, t=t_array, noise=noise)
        if has_cat:
            log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t_array)
        else:
            log_x_cat_t = x_cat

        model_out = self._denoise_fn(
            torch.cat([x_num_t, log_x_cat_t], dim=1),
            t_array,
            **out_dict
        )

        model_out_num = model_out[:, :self.num_numerical_features]
        model_out_cat = model_out[:, self.num_numerical_features:]

        # Placeholder (on the CPU) when there are no categorical columns.
        kl = torch.tensor([0.0])
        if has_cat:
            kl = self.compute_Lt(
                model_out=model_out_cat,
                log_x_start=log_x_cat,
                log_x_t=log_x_cat_t,
                t=t_array,
                out_dict=out_dict
            )

        out = self._vb_terms_bpd(
            model_out_num,
            x_start=x_num,
            x_t=x_num_t,
            t=t_array,
            clip_denoised=False
        )

        multinomial_loss.append(kl)
        gaussian_loss.append(out["output"])
        xstart_mse.append(mean_flat((out["pred_xstart"] - x_num) ** 2))
        # mu_mse.append(mean_flat(out["mean_mse"]))
        out_mean.append(mean_flat(out["out_mean"]))
        true_mean.append(mean_flat(out["true_mean"]))

        eps = self._predict_eps_from_xstart(x_num_t, t_array, out["pred_xstart"])
        mse.append(mean_flat((eps - noise) ** 2))

    gaussian_loss = torch.stack(gaussian_loss, dim=1)
    multinomial_loss = torch.stack(multinomial_loss, dim=1)
    xstart_mse = torch.stack(xstart_mse, dim=1)
    mse = torch.stack(mse, dim=1)
    # mu_mse = torch.stack(mu_mse, dim=1)
    out_mean = torch.stack(out_mean, dim=1)
    true_mean = torch.stack(true_mean, dim=1)

    # Prior terms KL(q(x_T | x_0) || prior) complete the bound.
    prior_gauss = self._prior_gaussian(x_num)

    prior_multin = torch.tensor([0.0])
    if has_cat:
        prior_multin = self.kl_prior(log_x_cat)

    total_gauss = gaussian_loss.sum(dim=1) + prior_gauss
    total_multin = multinomial_loss.sum(dim=1) + prior_multin
    # NOTE(review): the key "losses_multinimial" is misspelled, but it is a
    # runtime key that callers may read, so it is kept as-is.
    return {
        "total_gaussian": total_gauss,
        "total_multinomial": total_multin,
        "losses_gaussian": gaussian_loss,
        "losses_multinimial": multinomial_loss,
        "xstart_mse": xstart_mse,
        "mse": mse,
        # "mu_mse": mu_mse
        "out_mean": out_mean,
        "true_mean": true_mean
    }
@torch.no_grad()
def gaussian_ddim_step(
    self,
    model_out_num,
    x,
    t,
    clip_denoised=False,
    denoised_fn=None,
    eta=0.0,
    model_kwargs=None,
    cond_fn=None
):
    """One DDIM update x_t -> x_{t-1} for the numerical features.

    sigma = eta * sqrt((1 - a_prev) / (1 - a)) * sqrt(1 - a / a_prev), where
    a = alpha_bar_t and a_prev = alpha_bar_{t-1}. With eta = 0 the update is
    deterministic. If ``cond_fn`` is given, the prediction is guided with
    ``condition_score``. Rows with t == 0 receive no noise.
    """
    pred = self.gaussian_p_mean_variance(
        model_out_num,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=None,
    )
    if cond_fn is not None:
        pred = self.condition_score(cond_fn, pred, x, t, model_kwargs=model_kwargs)

    x0_hat = pred["pred_xstart"]
    eps = self._predict_eps_from_xstart(x, t, x0_hat)
    a_bar = extract(self.alphas_cumprod, t, x.shape)
    a_bar_prev = extract(self.alphas_cumprod_prev, t, x.shape)
    sigma = (
        eta
        * torch.sqrt((1 - a_bar_prev) / (1 - a_bar))
        * torch.sqrt(1 - a_bar / a_bar_prev)
    )
    noise = torch.randn_like(x)
    direction = torch.sqrt(1 - a_bar_prev - sigma ** 2) * eps
    mean_pred = x0_hat * torch.sqrt(a_bar_prev) + direction
    step_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    return mean_pred + step_mask * sigma * noise
@torch.no_grad()
def gaussian_ddim_sample(
    self,
    noise,
    T,
    out_dict,
    eta=0.0,
    model_kwargs=None,
    cond_fn=None
):
    """Run DDIM steps T-1 ... 0 on purely numerical inputs, starting from ``noise``.

    Progress is printed on one line. Fix: ``eta`` was accepted but never
    forwarded to ``gaussian_ddim_step``, so sampling was always the
    deterministic eta = 0 variant regardless of the argument.

    :param noise: initial x_T tensor; its first dimension is the batch size.
    :param T: number of steps to run.
    :param out_dict: extra keyword arguments for the denoiser.
    :return: the final sample x_0.
    """
    x = noise
    b = x.shape[0]
    device = x.device
    for t in reversed(range(T)):
        print(f'Sample timestep {t:4d}', end='\r')
        t_array = (torch.ones(b, device=device) * t).long()
        out_num = self._denoise_fn(x, t_array, **out_dict)
        x = self.gaussian_ddim_step(
            out_num,
            x,
            t_array,
            eta=eta,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn
        )
    print()
    return x
@torch.no_grad()
def gaussian_ddim_reverse_step(
    self,
    model_out_num,
    x,
    t,
    clip_denoised=False,
    eta=0.0
):
    """Deterministic DDIM inversion step x_t -> x_{t+1} (only eta = 0 is supported).

    x_{t+1} = x0_hat * sqrt(alpha_bar_{t+1}) + sqrt(1 - alpha_bar_{t+1}) * eps,
    where eps is the noise implied by x_t and the predicted x_0.
    """
    assert eta == 0.0, "Eta must be zero."
    pred = self.gaussian_p_mean_variance(
        model_out_num,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=None,
        model_kwargs=None,
    )
    x0_hat = pred["pred_xstart"]
    eps = self._predict_eps_from_xstart(x, t, x0_hat)
    a_bar_next = extract(self.alphas_cumprod_next, t, x.shape)
    return x0_hat * torch.sqrt(a_bar_next) + torch.sqrt(1 - a_bar_next) * eps
@torch.no_grad()
def gaussian_ddim_reverse_sample(
    self,
    x,
    T,
    out_dict,
):
    """Map numerical data to DDIM latents by running inversion steps 0 ... T-1.

    Progress is printed on one line.
    """
    batch = x.shape[0]
    device = x.device
    for step in range(T):
        print(f'Reverse timestep {step:4d}', end='\r')
        t_array = (torch.ones(batch, device=device) * step).long()
        denoised = self._denoise_fn(x, t_array, **out_dict)
        x = self.gaussian_ddim_reverse_step(denoised, x, t_array, eta=0.0)
    print()
    return x
@torch.no_grad()
def multinomial_ddim_step(
    self,
    model_out_cat,
    log_x_t,
    t,
    out_dict,
    eta=0.0
):
    """DDIM-flavoured categorical update (not true DDIM).

    Per one-hot column it mixes, in probability space:
        sigma * x_t + (alpha_bar_prev - sigma * alpha_bar) * x0_hat + rest * (1 / K),
    where rest is the remaining mass. The mixture is computed in log space
    with logsumexp, and x_{t-1} is then sampled with Gumbel-max. With eta = 0,
    sigma is 0 and the x_t component gets weight log(0) = -inf.
    """
    log_x0 = self.predict_start(model_out_cat, log_x_t=log_x_t, t=t, out_dict=out_dict)
    shape = log_x_t.shape
    a_bar = extract(self.alphas_cumprod, t, shape)
    a_bar_prev = extract(self.alphas_cumprod_prev, t, shape)
    sigma = (
        eta
        * torch.sqrt((1 - a_bar_prev) / (1 - a_bar))
        * torch.sqrt(1 - a_bar / a_bar_prev)
    )
    w_current = sigma
    w_start = a_bar_prev - sigma * a_bar
    w_uniform = 1 - w_current - w_start
    components = [
        torch.log(w_current) + log_x_t,
        torch.log(w_start) + log_x0,
        torch.log(w_uniform) - torch.log(self.num_classes_expanded),
    ]
    log_mix = torch.logsumexp(torch.stack(components, dim=2), dim=2)
    return self.log_sample_categorical(log_mix)
@torch.no_grad()
def sample_ddim(self, num_samples, y_dist, model_kwargs=None, cond_fn=None):
    """DDIM-style sampling of ``num_samples`` mixed rows.

    Labels y are drawn from ``y_dist`` with replacement. Numerical features
    start from N(0, I) and categorical features from a uniform categorical
    draw. Steps T-1 ... 0 use ``gaussian_ddim_step`` / ``multinomial_ddim_step``.

    :return: ``(samples, {'y': labels})``. Samples are on the CPU, laid out as
        ``[numerical | categorical indices]`` (categorical as log one-hot when
        there are no categorical features, i.e. an empty block).
    """
    device = self.log_alpha.device
    n_num = self.num_numerical_features
    z_norm = torch.randn((num_samples, n_num), device=device)

    has_cat = self.num_classes[0] != 0
    log_z = torch.zeros((num_samples, 0), device=device).float()
    if has_cat:
        log_z = self.log_sample_categorical(
            torch.zeros((num_samples, len(self.num_classes_expanded)), device=device)
        )

    labels = torch.multinomial(y_dist, num_samples=num_samples, replacement=True)
    out_dict = {'y': labels.long().to(device)}

    for step in reversed(range(self.num_timesteps)):
        print(f'Sample timestep {step:4d}', end='\r')
        t = torch.full((num_samples,), step, device=device, dtype=torch.long)
        x_in = torch.cat([z_norm, log_z], dim=1).float()
        model_out = self._denoise_fn(x_in, t, **out_dict)
        z_norm = self.gaussian_ddim_step(
            model_out[:, :n_num],
            z_norm,
            t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
        )
        if has_cat:
            log_z = self.multinomial_ddim_step(model_out[:, n_num:], log_z, t, out_dict)
    print()

    z_cat = log_z
    if has_cat:
        z_cat = ohe_to_categories(torch.exp(log_z).round(), self.num_classes)
    return torch.cat([z_norm, z_cat], dim=1).cpu(), out_dict
@torch.no_grad()
def conditional_sample(self, ys, model_kwargs=None, cond_fn=None):
    """Ancestral sampling with caller-supplied labels ``ys`` (one row per label).

    This is the same procedure as ``sample``, except that the labels are
    given rather than drawn from a distribution.

    Fix: the categorical state was never initialized and stayed an empty
    ``(b, 0)`` tensor. The denoiser then received an input missing every
    categorical column, and ``p_sample`` failed its width check. It now starts
    from a uniform categorical draw, as ``sample`` does.

    :return: ``(samples, {'y': ys})``. Samples are on the CPU, laid out as
        ``[numerical | categorical indices]``.
    """
    device = self.log_alpha.device
    b = len(ys)
    z_norm = torch.randn((b, self.num_numerical_features), device=device)

    has_cat = self.num_classes[0] != 0
    log_z = torch.zeros((b, 0), device=device).float()
    if has_cat:
        uniform_logits = torch.zeros((b, len(self.num_classes_expanded)), device=device)
        log_z = self.log_sample_categorical(uniform_logits)

    out_dict = {'y': ys.long().to(device)}
    for i in reversed(range(0, self.num_timesteps)):
        print(f'Sample timestep {i:4d}', end='\r')
        t = torch.full((b,), i, device=device, dtype=torch.long)
        model_out = self._denoise_fn(
            torch.cat([z_norm, log_z], dim=1).float(),
            t,
            **out_dict
        )
        model_out_num = model_out[:, :self.num_numerical_features]
        model_out_cat = model_out[:, self.num_numerical_features:]
        z_norm = self.gaussian_p_sample(
            model_out_num,
            z_norm,
            t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn
        )['sample']
        if has_cat:
            log_z = self.p_sample(model_out_cat, log_z, t, out_dict)
    print()

    z_cat = log_z
    if has_cat:
        z_ohe = torch.exp(log_z).round()
        z_cat = ohe_to_categories(z_ohe, self.num_classes)
    sample = torch.cat([z_norm, z_cat], dim=1).cpu()
    return sample, out_dict
@torch.no_grad()
def sample(self, num_samples, y_dist, model_kwargs=None, cond_fn=None):
    """Ancestral sampling of ``num_samples`` mixed rows.

    Labels y are drawn from ``y_dist`` with replacement. Numerical features
    start from N(0, I) and categorical features from a uniform categorical
    draw. Steps T-1 ... 0 use ``gaussian_p_sample`` / ``p_sample``.

    :return: ``(samples, {'y': labels})``. Samples are on the CPU, laid out as
        ``[numerical | categorical indices]``.
    """
    device = self.log_alpha.device
    n_num = self.num_numerical_features
    z_norm = torch.randn((num_samples, n_num), device=device)

    has_cat = self.num_classes[0] != 0
    log_z = torch.zeros((num_samples, 0), device=device).float()
    if has_cat:
        log_z = self.log_sample_categorical(
            torch.zeros((num_samples, len(self.num_classes_expanded)), device=device)
        )

    labels = torch.multinomial(y_dist, num_samples=num_samples, replacement=True)
    out_dict = {'y': labels.long().to(device)}

    for step in reversed(range(self.num_timesteps)):
        print(f'Sample timestep {step:4d}', end='\r')
        t = torch.full((num_samples,), step, device=device, dtype=torch.long)
        x_in = torch.cat([z_norm, log_z], dim=1).float()
        model_out = self._denoise_fn(x_in, t, **out_dict)
        z_norm = self.gaussian_p_sample(
            model_out[:, :n_num],
            z_norm,
            t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
        )['sample']
        if has_cat:
            log_z = self.p_sample(model_out[:, n_num:], log_z, t, out_dict)
    print()

    z_cat = log_z
    if has_cat:
        z_cat = ohe_to_categories(torch.exp(log_z).round(), self.num_classes)
    return torch.cat([z_norm, z_cat], dim=1).cpu(), out_dict
def sample_all(self, num_samples, batch_size, y_dist, ddim=False, model_kwargs=None, cond_fn=None):
    """Generate ``num_samples`` rows in batches of ``batch_size``.

    ``sample_ddim`` is used when ``ddim`` is true, otherwise ``sample``. Rows
    containing NaN are filtered out, but any such row then triggers
    ``FoundNANsError`` for that batch.

    :return: ``(x, y)`` CPU tensors truncated to ``num_samples`` rows.
    """
    if ddim:
        print('Sample using DDIM.')
    sample_fn = self.sample_ddim if ddim else self.sample

    x_chunks = []
    y_chunks = []
    produced = 0
    while produced < num_samples:
        batch, out_dict = sample_fn(batch_size, y_dist, model_kwargs=model_kwargs, cond_fn=cond_fn)
        keep = ~torch.any(batch.isnan(), dim=1)
        batch = batch[keep]
        out_dict['y'] = out_dict['y'][keep]
        x_chunks.append(batch)
        y_chunks.append(out_dict['y'].cpu())
        if batch.shape[0] != batch_size:
            raise FoundNANsError
        produced += batch.shape[0]

    x_gen = torch.cat(x_chunks, dim=0)[:num_samples]
    y_gen = torch.cat(y_chunks, dim=0)[:num_samples]
    return x_gen, y_gen