|
|
import six |
|
|
|
|
|
import chainer |
|
|
import chainer.functions as F |
|
|
from chainer.functions.loss.vae import gaussian_kl_divergence |
|
|
import chainer.links as L |
|
|
|
|
|
|
|
|
class VAE(chainer.Chain):

    """Variational AutoEncoder.

    A single-hidden-layer Gaussian-latent VAE with a Bernoulli decoder.
    The encoder maps inputs to the mean and log-variance of a diagonal
    Gaussian over the latent space; the decoder maps latent vectors back
    to input space.

    Args:
        n_in (int): Dimensionality of the input (and reconstruction).
        n_latent (int): Dimensionality of the latent variable ``z``.
        n_h (int): Number of hidden units in both encoder and decoder.
    """

    def __init__(self, n_in, n_latent, n_h):
        super(VAE, self).__init__()
        with self.init_scope():
            # encoder: input -> hidden -> (mu, ln_var)
            self.le1 = L.Linear(n_in, n_h)
            self.le2_mu = L.Linear(n_h, n_latent)
            self.le2_ln_var = L.Linear(n_h, n_latent)
            # decoder: latent -> hidden -> input logits
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h, n_in)

    def __call__(self, x, sigmoid=True):
        """AutoEncoder forward pass.

        Encodes ``x`` and decodes deterministically from the posterior
        mean (no sampling is performed here).

        Args:
            x: Input minibatch.
            sigmoid (bool): If ``True``, apply a sigmoid to the decoder
                output; otherwise return raw logits.

        Returns:
            The reconstruction of ``x``.
        """
        return self.decode(self.encode(x)[0], sigmoid)

    def encode(self, x):
        """Map ``x`` to the parameters of the latent Gaussian.

        Args:
            x: Input minibatch.

        Returns:
            tuple: ``(mu, ln_var)`` — mean and log-variance of the
            diagonal Gaussian posterior ``q(z|x)``.
        """
        h1 = F.tanh(self.le1(x))
        mu = self.le2_mu(h1)
        ln_var = self.le2_ln_var(h1)
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        """Map a latent vector ``z`` back to input space.

        Args:
            z: Latent minibatch.
            sigmoid (bool): If ``True``, apply a sigmoid to the output.
                Pass ``False`` to obtain raw logits (as required by
                ``F.bernoulli_nll``, which applies the sigmoid itself).

        Returns:
            The decoded reconstruction (probabilities or logits).
        """
        h1 = F.tanh(self.ld1(z))
        h2 = self.ld2(h1)
        if sigmoid:
            return F.sigmoid(h2)
        else:
            return h2

    def get_loss_func(self, C=1.0, k=1):
        """Get loss function of VAE.

        The loss value is equal to ELBO (Evidence Lower Bound)
        multiplied by -1.

        Args:
            C (float): Usually this is 1.0. Can be changed to control the
                second term of ELBO bound, which works as regularization.
            k (int): Number of Monte Carlo samples used in encoded vector.

        Returns:
            A function of one argument ``x`` that computes the negative
            ELBO and also stores ``self.rec_loss`` (reconstruction term)
            and ``self.loss`` (total) as side effects.
        """
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # Monte Carlo estimate of E_q[-log p(x|z)] using the
            # reparameterization trick (F.gaussian samples z = mu + eps*sigma).
            rec_loss = 0
            for _ in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                # decode with sigmoid=False: bernoulli_nll applies the
                # sigmoid internally for numerical stability.
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \
                    / (k * batchsize)
            self.rec_loss = rec_loss
            # Negative ELBO = reconstruction NLL + C * KL(q(z|x) || N(0, I)).
            self.loss = self.rec_loss + \
                C * gaussian_kl_divergence(mu, ln_var) / batchsize
            return self.loss
        return lf
|
|
|