|
|
import tensorflow as tf |
|
|
from tensorflow.keras.models import Model |
|
|
|
|
|
|
|
|
class VAE(Model):
    """Variational autoencoder assembled from a separate encoder and decoder.

    The encoder is called as ``self.encoder(x)`` and must return exactly three
    values, unpacked here as ``(z, z_mean, z_log_var)``. The decoder is called
    on the latent ``z`` and its output is treated as the reconstruction of
    ``x``.

    ``train_step`` minimises the sum of a binary-crossentropy reconstruction
    loss and a KL-divergence term, and reports both terms plus their total
    through ``tf.keras.metrics.Mean`` trackers.

    Args:
        encoder: Callable (presumably a ``tf.keras.Model``) mapping ``x`` to
            ``(z, z_mean, z_log_var)``.
        decoder: Callable (presumably a ``tf.keras.Model``) mapping ``z`` to a
            reconstruction whose shape is compatible with ``x`` under binary
            crossentropy.
        **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

        # Running means of the per-batch losses, reported by train_step.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the loss trackers owned by this model.

        Listing the trackers here lets Keras reset them automatically between
        epochs, following the custom ``train_step`` pattern.
        """
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @tf.function()
    def call(self, x):
        """Encode ``x`` and return the decoder's reconstruction.

        Only the first encoder output (``z``) is decoded; ``z_mean`` and
        ``z_log_var`` are discarded here.
        """
        z, _, _ = self.encoder(x)
        return self.decoder(z)

    def full_summary(self):
        """Print the Keras summary of every direct sub-layer that provides one.

        ``summary()`` prints its table itself and returns ``None``, so it is
        called directly rather than wrapped in ``print`` (which would emit a
        stray ``None`` line after each table).
        """
        for layer in self.layers:
            # NOTE(review): ``self.layers`` may contain tracked objects without
            # a ``summary()`` method (plain Layers, and depending on the Keras
            # version possibly the Mean trackers above). Skip them rather than
            # raising AttributeError.
            summary = getattr(layer, "summary", None)
            if callable(summary):
                summary()

    @tf.function()
    def train_step(self, x):
        """Run one optimisation step on a batch and return the running losses.

        Args:
            x: Input batch, used both as the encoder input and as the
                reconstruction target. NOTE(review): this assumes ``fit`` is
                given inputs only (no ``(x, y)`` tuples) — confirm with
                callers. Binary crossentropy reduces the last axis and the
                result is then summed over axes 1 and 2, so ``x`` must have
                rank >= 4, presumably ``(batch, height, width, channels)``.

        Returns:
            Dict with the current running means under ``"loss"``,
            ``"reconstruction_loss"`` and ``"kl_loss"``.
        """
        # Every op that contributes to total_loss must run inside the tape's
        # context; ops executed after the ``with`` block are not recorded, and
        # tape.gradient would then return None for every weight.
        with tf.GradientTape() as tape:
            z, z_mean, z_log_var = self.encoder(x)
            reconstruction = self.decoder(z)

            # Binary crossentropy (last axis reduced), summed over axes 1 and 2
            # per example, then averaged over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(x, reconstruction),
                    axis=(1, 2),
                )
            )

            # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over
            # axis 1 and averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            # Clamp a NaN/Inf KL term to the largest finite value of its dtype
            # (equal to tf.float32.max for float32). tf.where keeps this a plain
            # graph op instead of relying on AutoGraph to convert a Python
            # ``if``/``or`` on tensors.
            # NOTE(review): this sanitises the loss value only; gradients that
            # flow back through non-finite intermediates may still be NaN.
            kl_loss = tf.where(
                tf.math.is_finite(kl_loss),
                kl_loss,
                tf.constant(kl_loss.dtype.max, dtype=kl_loss.dtype),
            )

            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
|
|
|