# Author: Zilong-Zhao
# Commit: c4ac745 ("first commit")
import torch
import numpy as np
import torch.nn.functional as F
from torch.profiler import record_function
from inspect import isfunction
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Arguments may be tensors or Python scalars and are combined with normal
    broadcasting, so e.g. a batch of means can be compared against a scalar.
    At least one argument must be a torch.Tensor; it is used as the reference
    for the dtype/device of any log-variance that has to be converted.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # torch.exp() needs a Tensor, so scalar log-variances are promoted
    # explicitly (broadcasting alone would not do it).
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(reference)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(reference)

    variance_ratio = torch.exp(logvar1 - logvar2)
    mean_term = ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + mean_term)
def approx_standard_normal_cdf(x):
    """
    Fast tanh-based approximation of the standard normal CDF, evaluated
    elementwise on the tensor ``x``.
    """
    cubic_corrected = x + 0.044715 * torch.pow(x, 3)
    scaled = np.sqrt(2.0 / np.pi) * cubic_corrected
    return 0.5 * (1.0 + torch.tanh(scaled))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of ``x`` under a Gaussian discretized into bins of
    half-width 1/255 (i.e. uint8 values rescaled to [-1, 1]).

    :param x: target tensor; assumed to hold uint8 values rescaled to [-1, 1].
    :param means: Gaussian means, same shape as ``x``.
    :param log_scales: Gaussian log standard deviations, same shape as ``x``.
    :return: tensor shaped like ``x`` of log probabilities (in nats). Values
        below -0.999 use the left tail (CDF up to the upper bin edge); values
        above 0.999 use the right tail (1 - CDF at the lower bin edge).
    """
    assert x.shape == means.shape == log_scales.shape
    half_bin = 1.0 / 255.0
    offset = x - means
    inv_std = torch.exp(-log_scales)

    cdf_upper = approx_standard_normal_cdf(inv_std * (offset + half_bin))
    cdf_lower = approx_standard_normal_cdf(inv_std * (offset - half_bin))

    # Clamp before each log so empty bins yield a finite (very negative) value.
    log_left_tail = torch.log(cdf_upper.clamp(min=1e-12))
    log_right_tail = torch.log((1.0 - cdf_lower).clamp(min=1e-12))
    log_interior = torch.log((cdf_upper - cdf_lower).clamp(min=1e-12))

    inner = torch.where(x > 0.999, log_right_tail, log_interior)
    log_probs = torch.where(x < -0.999, log_left_tail, inner)
    assert log_probs.shape == x.shape
    return log_probs
def sum_except_batch(x, num_dims=1):
    '''
    Sum over every dimension after the first ``num_dims`` ones.

    Args:
        x: Tensor of shape (*batch_dims, ...), where ``batch_dims`` are the
           first ``num_dims`` dimensions.
        num_dims: int, number of leading dimensions to keep (default=1).
    Returns:
        Tensor of shape ``x.shape[:num_dims]``; with the default this is
        (batch_size,).
    '''
    kept_shape = x.shape[:num_dims]
    flattened = x.reshape(*kept_shape, -1)
    return flattened.sum(-1)
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions (every dim except dim 0).

    A tensor with at most one dimension has no non-batch dimensions, so it is
    returned unchanged. Passing an empty ``dim`` list to ``Tensor.mean``
    would instead reduce over *all* dimensions and collapse the batch.
    """
    reduce_dims = list(range(1, tensor.dim()))
    if not reduce_dims:
        return tensor
    return tensor.mean(dim=reduce_dims)
def ohe_to_categories(ohe, K):
    """
    Convert concatenated one-hot blocks back to per-feature category indices.

    :param ohe: 2-D tensor whose columns are one block per categorical
        feature laid side by side; presumably shape (batch, sum(K)).
    :param K: width of each block in column order. Accepts a numpy array
        (as before), a list/tuple, or a CPU tensor of integer widths.
    :return: int64 tensor of shape (batch, len(K)) holding the argmax
        position inside each block. An empty ``K`` yields shape (batch, 0).
    """
    # Widths are taken as Python ints, so block offsets are exact. The old
    # float32 cumsum could lose precision for very wide encodings.
    sizes = [int(k) for k in np.asarray(K).reshape(-1)]
    if not sizes:
        return torch.empty((ohe.shape[0], 0), dtype=torch.long, device=ohe.device)

    columns = []
    start = 0
    for size in sizes:
        columns.append(ohe[:, start:start + size].argmax(dim=1))
        start += size
    return torch.stack(columns, dim=1)
def log_1_min_a(a):
    """Return log(1 - exp(a)) elementwise for a log-probability tensor ``a``.

    The tiny 1e-40 offset guards against log(0) when exp(a) == 1. Note that
    1e-40 is subnormal in float32.
    """
    one_minus = 1 - torch.exp(a) + 1e-40
    return one_minus.log()
def log_add_exp(a, b):
    """Return log(exp(a) + exp(b)) elementwise, with broadcasting.

    Delegates to ``torch.logaddexp``, which is numerically stable and returns
    -inf when both inputs are -inf. The previous max-shift formula produced
    NaN there, because (-inf) - (-inf) is NaN.
    """
    return torch.logaddexp(a, b)
def exists(x):
    """Return True unless ``x`` is None (falsy values like 0 still count)."""
    if x is None:
        return False
    return True
def extract(a, t, x_shape):
    """
    Gather entries of ``a`` at indices ``t`` and broadcast them to ``x_shape``.

    ``t`` is moved to ``a``'s device and gathered along the last dimension of
    ``a``. The result gets trailing singleton dims until it has
    ``len(x_shape)`` dims, then is expanded (as a view) to ``x_shape``.
    Presumably ``a`` is a 1-D per-timestep table and ``t`` a 1-D batch of
    timestep indices; TODO confirm against callers.
    """
    out = a.gather(-1, t.to(a.device))
    while out.dim() < len(x_shape):
        out = out.unsqueeze(-1)
    return out.expand(x_shape)
def default(val, d):
    """Return ``val`` unless it is None; otherwise fall back to ``d``.

    When falling back, a plain Python function ``d`` (per
    ``inspect.isfunction``) is called with no arguments to build the value;
    anything else is returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
def log_categorical(log_x_start, log_prob):
    """Sum over dim 1 of exp(log_x_start) * log_prob.

    This is the expected log-probability under the distribution whose log
    is ``log_x_start``.
    """
    weights = torch.exp(log_x_start)
    return torch.sum(weights * log_prob, dim=1)
def index_to_log_onehot(x, num_classes):
    """
    One-hot encode each column of ``x`` and return the log of the result.

    Column ``i`` of ``x`` is encoded with ``num_classes[i]`` classes. The
    encodings are concatenated along dim 1, and zeros are clamped to 1e-30
    before the log so they become a large negative number instead of -inf.
    """
    blocks = [F.one_hot(x[:, col], n_cls) for col, n_cls in enumerate(num_classes)]
    onehot = torch.cat(blocks, dim=1).float()
    return onehot.clamp(min=1e-30).log()
def log_sum_exp_by_classes(x, slices):
    """
    For each column slice in ``slices``, replace its columns with the
    logsumexp over that slice (computed per row and broadcast back).

    Columns not covered by any slice are left at 0.
    """
    out = torch.zeros_like(x)
    for cols in slices:
        out[:, cols] = x[:, cols].logsumexp(dim=1, keepdim=True)
    assert out.size() == x.size()
    return out
@torch.jit.script
def log_sub_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return log(exp(a) - exp(b)) elementwise, shifted by max(a, b).

    The result is NaN wherever b > a, since the difference inside the log is
    negative there.
    """
    shift = torch.maximum(a, b)
    diff = torch.exp(a - shift) - torch.exp(b - shift)
    return torch.log(diff) + shift
@torch.jit.script
def sliced_logsumexp(x: torch.Tensor, slices: torch.Tensor) -> torch.Tensor:
    """
    Per-column logsumexp over contiguous column groups.

    ``slices`` holds boundary indices [b0, b1, ..., bK]; group i covers
    columns [b_i, b_{i+1}) of the last dimension. Each group's logsumexp is
    repeated once per column in that group. The indexing below assumes ``x``
    is 2-D (rows, columns), presumably with b0 == 0 and bK == x.size(-1);
    TODO confirm against callers.
    """
    # Prepend a -inf column so cumulative[:, k] == logsumexp(x[:, :k]).
    padded = torch.nn.functional.pad(x, [1, 0, 0, 0], value=-float('inf'))
    cumulative = torch.logcumsumexp(padded, dim=-1)
    starts = slices[:-1]
    ends = slices[1:]
    # logsumexp over [start, end) = log(exp(cum[end]) - exp(cum[start])).
    group_lse = log_sub_exp(cumulative[:, ends], cumulative[:, starts])
    return torch.repeat_interleave(group_lse, ends - starts, dim=-1)
def log_onehot_to_index(log_x):
    """Return the index of the largest entry along dim 1 of ``log_x``."""
    return torch.argmax(log_x, dim=1)
class FoundNANsError(Exception):
    """Found NANs during sampling.

    Derives from ``Exception`` rather than ``BaseException``. Python reserves
    ``BaseException`` subclasses for system-exiting conditions such as
    KeyboardInterrupt and SystemExit.
    """

    def __init__(self, message='Found NANs during sampling.'):
        super().__init__(message)