import torch
import numpy as np


class AbstractDistribution:
    """Minimal interface: a distribution that can be sampled and has a mode."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    """Degenerate distribution that always yields the stored ``value``."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(AbstractDistribution):
    """Gaussian with diagonal covariance, parameterized by mean and log-variance.

    Mathematical notation:
        self.mean   = μ
        self.std    = σ
        self.var    = σ^2
        self.logvar = log(σ^2) = 2 log(σ)

    The log-variance appears directly in many statistical formulas (e.g. the
    Gaussian log-likelihood), so working with it makes those formulas simpler
    and more numerically stable.
    """

    def __init__(self, parameters, deterministic=False):
        """
        Args:
            parameters: tensor whose dimension 1 holds the means followed by
                the log-variances; it is split into two halves along dim 1.
                The default reductions in ``kl``/``nll`` (dims 1, 2, 3) expect
                a 4D layout such as ``(B, 2C, H, W)``.
            deterministic: if True, the distribution has no randomness: the
                variance and standard deviation are set to zero, ``sample()``
                returns the mean, and ``kl()``/``nll()`` return zero.
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Bound the log-variance so var = exp(logvar) stays finite and
        # strictly positive (in float32).
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already inherits the device and dtype of self.mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw a sample using the reparameterization trick.

        If Z is standard normal (mean 0, std 1), then X = μ + σZ is normal
        with mean μ and standard deviation σ.
        """
        # randn_like draws the noise directly on the mean's device and with its
        # dtype, avoiding a CPU allocation + copy and an unintended upcast of
        # half-precision means to float32.
        noise = torch.randn_like(self.mean)
        return self.mean + self.std * noise

    def kl(self, other=None, dims=(1, 2, 3)):
        r"""KL divergence KL(self || other), summed over ``dims``.

        If ``other`` is None, the divergence is taken with respect to the
        standard normal distribution N(0, I).

        $ KL(P||Q) = \log(σ_2 / σ_1) + \frac{σ_1^2 + (μ_1 - μ_2)^2}{2σ_2^2} - \frac{1}{2} $

        Args:
            other: another DiagonalGaussianDistribution whose tensors
                broadcast against this one, or None.
            dims: dimensions to sum over (defaults to 1, 2, 3).

        Returns:
            The element-wise divergence with ``dims`` summed out. In
            deterministic mode, a one-element zero tensor on the parameters'
            device.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=dims,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=dims,
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``.

        Per element, for x drawn from N(μ, σ^2):
            NLL = 0.5 * log(2πσ^2) + (x - μ)^2 / (2σ^2)

        In deterministic mode, returns a one-element zero tensor on the
        parameters' device.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """The mode of a Gaussian is its mean."""
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two Gaussians, element-wise.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases. At least one argument must be a Tensor.
    """
    tensor = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )