# Variational AutoEncoder (VAE) example implemented with Chainer.
import six
import chainer
import chainer.functions as F
from chainer.functions.loss.vae import gaussian_kl_divergence
import chainer.links as L
class VAE(chainer.Chain):
    """Variational AutoEncoder.

    Encoder: x -> (mu, ln_var) through one tanh hidden layer.
    Decoder: z -> reconstruction logits through one tanh hidden layer.
    """

    def __init__(self, n_in, n_latent, n_h):
        """Initialize encoder/decoder layers.

        Args:
            n_in (int): Dimensionality of the input (and reconstruction).
            n_latent (int): Dimensionality of the latent variable z.
            n_h (int): Number of hidden units in both encoder and decoder.
        """
        super(VAE, self).__init__()
        with self.init_scope():
            # encoder
            self.le1 = L.Linear(n_in, n_h)
            self.le2_mu = L.Linear(n_h, n_latent)
            self.le2_ln_var = L.Linear(n_h, n_latent)
            # decoder
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h, n_in)

    def __call__(self, x, sigmoid=True):
        """AutoEncoder: reconstruct x deterministically via the latent mean.

        Uses only mu (no sampling); pass sigmoid=False for raw logits.
        """
        return self.decode(self.encode(x)[0], sigmoid)

    def encode(self, x):
        """Map input x to the latent Gaussian parameters (mu, ln_var)."""
        h1 = F.tanh(self.le1(x))
        mu = self.le2_mu(h1)
        ln_var = self.le2_ln_var(h1)  # log(sigma**2)
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        """Decode latent z into a reconstruction.

        Args:
            z: Latent variable.
            sigmoid (bool): If True, squash the output to (0, 1);
                otherwise return raw logits (needed by bernoulli_nll).
        """
        h1 = F.tanh(self.ld1(z))
        h2 = self.ld2(h1)
        if sigmoid:
            return F.sigmoid(h2)
        else:
            return h2

    def get_loss_func(self, C=1.0, k=1):
        """Get loss function of VAE.

        The loss value is equal to ELBO (Evidence Lower Bound)
        multiplied by -1.

        Args:
            C (float): Usually this is 1.0. Can be changed to control the
                second term of ELBO bound, which works as regularization.
            k (int): Number of Monte Carlo samples used in encoded vector.
        """
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction loss: Monte Carlo estimate over k samples
            rec_loss = 0
            for _ in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                # bernoulli_nll applies the sigmoid internally, hence
                # the decoder is asked for raw logits here.
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \
                    / (k * batchsize)
            self.rec_loss = rec_loss
            # negative ELBO = reconstruction loss + C * KL(q(z|x) || N(0, I))
            self.loss = self.rec_loss + \
                C * gaussian_kl_divergence(mu, ln_var) / batchsize
            return self.loss
        return lf