from inspect import isfunction

import numpy as np
import torch
import torch.nn.functional as F
from torch.profiler import record_function

def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two Gaussians.

    Shapes are automatically broadcast, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
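
# Illustrative sketch (not part of the original module): the return value is
# the elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))), and
# broadcasting lets a whole batch be compared to a standard-normal prior:
#
#   mean1, logvar1 = torch.randn(4, 8), torch.randn(4, 8)
#   kl = normal_kl(mean1, logvar1, 0.0, 0.0)  # KL against N(0, 1), shape (4, 8)
#   kl_per_sample = kl.sum(dim=1)             # nats per sample, shape (4,)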

def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
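
# Sanity-check sketch (illustrative, not in the original file): this is the
# same tanh-based form used by the "tanh" approximation of GELU, and it stays
# close to the exact standard-normal CDF computed via torch.erf:
#
#   x = torch.linspace(-3.0, 3.0, 61)
#   exact = 0.5 * (1.0 + torch.erf(x / np.sqrt(2.0)))
#   assert torch.allclose(approx_standard_normal_cdf(x), exact, atol=1e-2)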

def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that these were uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    # CDF evaluated at the upper and lower edges of each pixel's bin
    # (each bin is 2/255 wide after rescaling uint8 values to [-1, 1]).
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    # The edge bins absorb all remaining probability mass: everything below
    # -1 for the lowest pixel value and everything above 1 for the highest.
    log_probs = torch.where(
        x < -0.999,
        log_cdf_plus,
        torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
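
# Usage sketch (illustrative; shapes and values are made up): score a batch
# of quantized images under a unit Gaussian per pixel:
#
#   x = (torch.randint(0, 256, (2, 3, 8, 8)).float() / 127.5) - 1.0
#   log_probs = discretized_gaussian_log_likelihood(
#       x, means=torch.zeros_like(x), log_scales=torch.zeros_like(x))
#   nll_nats = -log_probs.mean()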

def sum_except_batch(x, num_dims=1):
    """
    Sums all dimensions except the first `num_dims` (batch) dimensions.

    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, number of batch dims (default=1)

    Returns:
        x_sum: Tensor, shape (batch_size,)
    """
    return x.reshape(*x.shape[:num_dims], -1).sum(-1)
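
# Illustrative example (not in the original file):
#
#   x = torch.ones(2, 3, 4)
#   sum_except_batch(x)              # tensor([12., 12.])
#   sum_except_batch(x, num_dims=2)  # shape (2, 3): sums only the last dim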

def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

def ohe_to_categories(ohe, K):
    """
    Convert concatenated one-hot blocks back to integer categories.
    `K` is a NumPy array with the number of classes per categorical column.
    """
    K = torch.from_numpy(K)
    indices = torch.cat([torch.zeros((1,)), K.cumsum(dim=0)], dim=0).int().tolist()
    res = []
    for i in range(len(indices) - 1):
        res.append(ohe[:, indices[i]:indices[i + 1]].argmax(dim=1))
    return torch.stack(res, dim=1)
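
# Round-trip sketch (illustrative): two categorical columns with 3 and 2
# classes, one-hot encoded and concatenated along dim 1:
#
#   K = np.array([3, 2])
#   x = torch.tensor([[2, 0], [1, 1]])
#   ohe = torch.cat([F.one_hot(x[:, 0], 3), F.one_hot(x[:, 1], 2)], dim=1)
#   assert torch.equal(ohe_to_categories(ohe, K), x)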

def log_1_min_a(a):
    # Numerically-guarded log(1 - exp(a)); the 1e-40 avoids log(0) when
    # exp(a) is very close to 1.
    return torch.log(1 - a.exp() + 1e-40)


def log_add_exp(a, b):
    # Stable log(exp(a) + exp(b)): factor out the elementwise maximum.
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
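
# Illustrative check (not in the original file): staying in log space keeps
# large-magnitude inputs from underflowing:
#
#   a, b = torch.tensor(-1000.0), torch.tensor(-1001.0)
#   log_add_exp(a, b)   # finite, ~ -999.69; naive log(exp(a) + exp(b)) is -inf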

def exists(x):
    return x is not None

def extract(a, t, x_shape):
    # Gather the schedule values a[t] for each sample's timestep t, then
    # append singleton dims so the result broadcasts against x_shape.
    b, *_ = t.shape
    t = t.to(a.device)
    out = a.gather(-1, t)
    while len(out.shape) < len(x_shape):
        out = out[..., None]
    return out.expand(x_shape)
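
# Typical diffusion usage (illustrative; `alphas_cumprod` is an assumption
# for the sketch): pick per-sample coefficients at timestep t and broadcast
# them over the feature dimensions:
#
#   alphas_cumprod = torch.linspace(1.0, 0.01, 1000)
#   t = torch.randint(0, 1000, (4,))
#   x = torch.randn(4, 3, 32, 32)
#   coef = extract(alphas_cumprod, t, x.shape)   # shape (4, 3, 32, 32)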

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def log_categorical(log_x_start, log_prob):
    # Expected log-probability under the (log-)one-hot distribution: when
    # log_x_start is a log-one-hot, this picks out log_prob at the true class.
    return (log_x_start.exp() * log_prob).sum(dim=1)

def index_to_log_onehot(x, num_classes):
    # One-hot encode each categorical column, concatenate along dim 1, and
    # take the (clamped) log so zeros become a large negative constant.
    onehots = []
    for i in range(len(num_classes)):
        onehots.append(F.one_hot(x[:, i], num_classes[i]))
    x_onehot = torch.cat(onehots, dim=1)
    log_onehot = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_onehot
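
# Round-trip sketch (illustrative): for a single column this is inverted by
# log_onehot_to_index, defined below:
#
#   x = torch.tensor([[3], [0]])
#   log_x = index_to_log_onehot(x, [5])   # shape (2, 5)
#   assert torch.equal(log_onehot_to_index(log_x), x[:, 0])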

def log_sum_exp_by_classes(x, slices):
    # logsumexp within each class slice, written back so every position in
    # a slice holds that slice's logsumexp.
    res = torch.zeros_like(x)
    for ixs in slices:
        res[:, ixs] = torch.logsumexp(x[:, ixs], dim=1, keepdim=True)
    assert x.size() == res.size()
    return res

def log_sub_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stable log(exp(a) - exp(b)); assumes a >= b elementwise, otherwise the
    # result is NaN.
    m = torch.maximum(a, b)
    return torch.log(torch.exp(a - m) - torch.exp(b - m)) + m
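
# Illustrative check (not in the original file):
#
#   a, b = torch.tensor(2.0), torch.tensor(1.0)
#   expected = torch.log(torch.exp(a) - torch.exp(b))
#   assert torch.allclose(log_sub_exp(a, b), expected)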

def sliced_logsumexp(x, slices):
    # logsumexp over contiguous slices of the last dim. `slices` is a tensor
    # of boundary indices starting at 0; each slice's logsumexp is computed
    # as a difference of cumulative logsumexps and repeated to x's shape.
    lse = torch.logcumsumexp(
        F.pad(x, [1, 0, 0, 0], value=-float('inf')),
        dim=-1)

    slice_starts = slices[:-1]
    slice_ends = slices[1:]

    slice_lse = log_sub_exp(lse[:, slice_ends], lse[:, slice_starts])
    slice_lse_repeated = torch.repeat_interleave(
        slice_lse,
        slice_ends - slice_starts,
        dim=-1
    )
    return slice_lse_repeated
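
# Usage sketch (illustrative): boundaries [0, 3, 5] split 5 columns into
# slices of width 3 and 2:
#
#   x = torch.log(torch.tensor([[1., 1., 1., 2., 2.]]))
#   slices = torch.tensor([0, 3, 5])
#   sliced_logsumexp(x, slices)   # [[log 3, log 3, log 3, log 4, log 4]]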

def log_onehot_to_index(log_x):
    # Inverse of index_to_log_onehot for a single column: the argmax over
    # classes recovers the category index.
    return log_x.argmax(1)

class FoundNANsError(BaseException):
    """Found NaNs during sampling."""

    # Note: derives from BaseException, so broad `except Exception` handlers
    # will not swallow it.
    def __init__(self, message='Found NaNs during sampling.'):
        super().__init__(message)