import six

import chainer
import chainer.functions as F
from chainer.functions.loss.vae import gaussian_kl_divergence
import chainer.links as L


class VAE(chainer.Chain):
    """Variational AutoEncoder.

    A single-hidden-layer encoder produces the mean and log-variance of
    the latent code; a single-hidden-layer decoder maps a latent code
    back to the input dimensionality.

    Args:
        n_in: Size of the input (and reconstructed output) vectors.
        n_latent: Size of the latent code.
        n_h: Size of the hidden layer used by both encoder and decoder.
    """

    def __init__(self, n_in, n_latent, n_h):
        super(VAE, self).__init__()
        with self.init_scope():
            # Encoder: input -> hidden -> (mu, ln_var).
            self.le1 = L.Linear(n_in, n_h)
            self.le2_mu = L.Linear(n_h, n_latent)
            self.le2_ln_var = L.Linear(n_h, n_latent)
            # Decoder: latent -> hidden -> output.
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h, n_in)

    def __call__(self, x, sigmoid=True):
        """Run ``x`` through the model as a plain autoencoder.

        The encoder mean is decoded directly; the log-variance output is
        discarded and no latent sampling takes place.

        Args:
            x: Input batch.
            sigmoid: Forwarded to :meth:`decode`.

        Returns:
            The decoder output for the encoded mean.
        """
        mean, _ = self.encode(x)
        return self.decode(mean, sigmoid)

    def encode(self, x):
        """Encode ``x`` into the parameters of the latent distribution.

        Args:
            x: Input batch.

        Returns:
            A ``(mu, ln_var)`` tuple, where ``ln_var`` is log(sigma**2).
        """
        hidden = F.tanh(self.le1(x))
        return self.le2_mu(hidden), self.le2_ln_var(hidden)

    def decode(self, z, sigmoid=True):
        """Decode the latent code ``z``.

        Args:
            z: Latent batch.
            sigmoid: When true, apply a sigmoid to the output layer;
                otherwise return the raw output-layer activations.

        Returns:
            The decoder output, with or without the final sigmoid.
        """
        pre_activation = self.ld2(F.tanh(self.ld1(z)))
        return F.sigmoid(pre_activation) if sigmoid else pre_activation

    def get_loss_func(self, C=1.0, k=1):
        """Get loss function of VAE.

        The loss value is equal to ELBO (Evidence Lower Bound)
        multiplied by -1.

        Args:
            C (float): Usually this is 1.0. Weight applied to the KL
                (second) term of the ELBO, which works as
                regularization.
            k (int): Number of Monte Carlo samples drawn from the
                encoded distribution for the reconstruction term.

        Returns:
            A callable ``lf(x)`` that computes the loss for a batch
            ``x``. As a side effect it stores the reconstruction term in
            ``self.rec_loss`` and the total in ``self.loss``.
        """
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)

            # Reconstruction term: average the Bernoulli NLL over the k
            # latent samples and over the batch. The decoder is asked
            # for raw (pre-sigmoid) outputs here.
            rec_loss = 0
            for _ in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                nll = F.bernoulli_nll(x, self.decode(z, sigmoid=False))
                rec_loss += nll / (k * batchsize)
            self.rec_loss = rec_loss

            # Regularization term: KL divergence scaled by C and
            # averaged over the batch.
            kl = gaussian_kl_divergence(mu, ln_var)
            self.loss = self.rec_loss + C * kl / batchsize
            return self.loss

        return lf