| |
|
| |
|
| |
|
| |
|
| | import numpy as np
|
| | import torch as th
|
| |
|
| |
|
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Arguments may be Tensors or Python scalars and are broadcast against each
    other, so a batch of distributions can be compared to a single scalar one.
    At least one argument must be a Tensor; scalar log-variances are converted
    to a Tensor matching the first Tensor argument's dtype/device.

    :param mean1: mean of the first Gaussian.
    :param logvar1: log-variance of the first Gaussian.
    :param mean2: mean of the second Gaussian.
    :param logvar2: log-variance of the second Gaussian.
    :return: a Tensor of KL values (in nats), broadcast over the inputs.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # th.exp() needs a Tensor argument, so scalar log-variances are promoted
        # here; the means only appear in arithmetic that broadcasts with Tensors.
        if isinstance(value, th.Tensor):
            return value
        return th.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    variance_ratio = th.exp(logvar1 - logvar2)
    mean_penalty = ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + mean_penalty)
|
| |
|
| |
|
def approx_standard_normal_cdf(x):
    """
    Fast tanh-based approximation of the standard normal CDF.

    Evaluates 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))) elementwise.

    :param x: a Tensor of points at which to evaluate the CDF.
    :return: a Tensor like x with approximate values of Phi(x).
    """
    slope = np.sqrt(2.0 / np.pi)
    cubic_arg = x + 0.044715 * th.pow(x, 3)
    return 0.5 * (1.0 + th.tanh(slope * cubic_arg))
|
| |
|
| |
|
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.

    Evaluates log N(x; means, exp(log_scales)**2) elementwise by standardizing
    ``x`` and applying the change-of-variables correction:

        log p(x) = log N(z; 0, 1) - log_scales,  z = (x - means) / exp(log_scales)

    :param x: the targets.
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    standard_log_probs = th.distributions.Normal(
        th.zeros_like(x), th.ones_like(x)
    ).log_prob(normalized_x)
    # The density of x equals the standard-normal density of z times the
    # Jacobian |dz/dx| = 1 / sigma, so log(sigma) = log_scales must be
    # subtracted. Without this term the result is off by +log_scales.
    log_probs = standard_log_probs - log_scales
    return log_probs
|
| |
|
| |
|
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of ``x`` under a Gaussian discretized into image bins.

    Each value of ``x`` is treated as the centre of a bin of half-width 1/255.
    The bin's probability mass is the (approximate) Gaussian CDF at the upper
    edge minus the CDF at the lower edge. Values below -0.999 instead take all
    mass from -inf up to their upper edge. Values above 0.999 take all mass
    from their lower edge up to +inf.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor (same shape as x).
    :param log_scales: the Gaussian log stddev Tensor (same shape as x).
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    half_bin = 1.0 / 255.0
    floor = 1e-12

    offset = x - means
    inv_sigma = th.exp(-log_scales)

    upper_cdf = approx_standard_normal_cdf(inv_sigma * (offset + half_bin))
    lower_cdf = approx_standard_normal_cdf(inv_sigma * (offset - half_bin))

    # Clamp before each log so an (almost) empty bin yields log(1e-12), not -inf.
    log_lower_tail = th.log(upper_cdf.clamp(min=floor))
    log_upper_tail = th.log((1.0 - lower_cdf).clamp(min=floor))
    log_bin_mass = th.log((upper_cdf - lower_cdf).clamp(min=floor))

    interior_or_top = th.where(x > 0.999, log_upper_tail, log_bin_mass)
    log_probs = th.where(x < -0.999, log_lower_tail, interior_or_top)
    assert log_probs.shape == x.shape
    return log_probs
|
| |
|