Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| import torch.nn as nn | |
| from .encoder import EncoderBase | |
| from ..utils import log_sum_exp | |
class GaussianEncoderBase(EncoderBase):
    """Base class for encoders whose posterior q(z|x) is a diagonal Gaussian.

    Subclasses implement ``forward`` to return the posterior mean and
    log-variance. On top of that, this class provides:
    - reparameterized sampling;
    - the analytic KL term against a standard normal prior;
    - evaluation of log q(z|x);
    - a Monte-Carlo estimate of the mutual information I(x, z).
    """

    def __init__(self):
        super().__init__()

    def freeze(self):
        """Set ``requires_grad = False`` on every parameter of this module."""
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Compute the posterior statistics; must be overridden.

        Args:
            x: (batch_size, *)
        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """
        raise NotImplementedError

    def encode_stats(self, x):
        """Return ``(mu, logvar)`` for ``x``; thin alias of ``forward``."""
        return self.forward(x)

    def sample(self, input, nsamples):
        """Draw reparameterized samples from q(z|x).

        Returns: Tensor1, (Tensor2, Tensor3)
            Tensor1: the latent samples z with shape [batch, nsamples, nz]
            Tensor2, Tensor3: the posterior mean and logvar, each (batch, nz)
        """
        # (batch_size, nz)
        mu, logvar = self.forward(input)
        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)
        return z, (mu, logvar)

    def encode(self, input, nsamples):
        """Sample latents and compute the KL term against N(0, I).

        Returns: Tensor1, Tensor2
            Tensor1: the latent samples z with shape [batch, nsamples, nz]
            Tensor2: the KL(q(z|x) || N(0, I)) for each x, with shape [batch]
        """
        # (batch_size, nz)
        mu, logvar = self.forward(input)
        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)
        # Closed-form KL between N(mu, exp(logvar)) and N(0, I), summed over nz.
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
        return z, KL

    def reparameterize(self, mu, logvar, nsamples=1):
        """Sample from the posterior Gaussian via the reparameterization trick.

        Args:
            mu: Tensor
                Mean of the Gaussian distribution, shape (batch, nz)
            logvar: Tensor
                Log-variance of the Gaussian distribution, shape (batch, nz)
            nsamples: int
                Number of samples drawn per row.
        Returns: Tensor
            Sampled z with shape (batch, nsamples, nz)
        """
        batch_size, nz = mu.size()
        std = logvar.mul(0.5).exp()
        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
        # Fresh standard-normal noise; randn_like allocates a dense tensor,
        # so the expanded (stride-0) views above are never written to.
        eps = torch.randn_like(std_expd)
        return mu_expd + eps * std_expd

    @staticmethod
    def _gaussian_log_density(z, mu, logvar):
        """Log density of ``z`` under N(mu, diag(exp(logvar))).

        ``z``, ``mu`` and ``logvar`` are broadcast against each other. The
        last dimension is the latent dimension and is summed out, so the
        result has the broadcast shape without that dimension.
        """
        nz = z.size(-1)
        dev = z - mu
        return (-0.5 * ((dev ** 2) / logvar.exp()).sum(dim=-1)
                - 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)))

    def eval_inference_dist(self, x, z, param=None):
        """Compute log q(z | x).

        Args:
            x: encoder input; only used when ``param`` is None.
            z: tensor
                Different z points to evaluate, with shape [batch, nsamples, nz]
            param: optional ``(mu, logvar)`` pair, each (batch, nz). When
                given, it is used instead of recomputing ``self.forward(x)``.
        Returns: Tensor1
            Tensor1: log q(z|x) with shape [batch, nsamples]
        """
        # Compare against None explicitly: truth-testing a tensor-valued
        # argument would raise an "ambiguous boolean" error.
        if param is None:
            mu, logvar = self.forward(x)
        else:
            mu, logvar = param
        # (batch_size, 1, nz) so the statistics broadcast over nsamples.
        return self._gaussian_log_density(z, mu.unsqueeze(1), logvar.unsqueeze(1))

    def calc_mi(self, x):
        """Approximate the mutual information between x and z.

        I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)),
        where q(z) is the aggregate posterior over the batch ``x``.

        Returns: Float
        """
        # [x_batch, nz]
        mu, logvar = self.forward(x)
        x_batch, nz = mu.size()

        # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
        neg_entropy = (-0.5 * nz * math.log(2 * math.pi)
                       - 0.5 * (1 + logvar).sum(-1)).mean()

        # [x_batch, 1, nz]: one sample per data point; these form the z batch.
        z_samples = self.reparameterize(mu, logvar, 1)

        # Statistics as [1, x_batch, nz] so every z_i is scored under every
        # q(.|x_j); the result is (z_batch, x_batch).
        log_density = self._gaussian_log_density(
            z_samples, mu.unsqueeze(0), logvar.unsqueeze(0))

        # log q(z): aggregate posterior, log((1/x_batch) * sum_j q(z_i|x_j))
        # [z_batch]
        log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)

        return (neg_entropy - log_qz.mean(-1)).item()