import tensorflow as tf
from tensorflow.keras.models import Model


class VAE(Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Loss trackers: running means reported at the end of each epoch.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them between epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @tf.function
    def call(self, x):
        # The encoder is expected to return (z, z_mean, z_log_var).
        z, z_mean, z_log_var = self.encoder(x)
        reconstruction = self.decoder(z)
        return reconstruction

    def full_summary(self):
        # layer.summary() prints and returns None, so don't wrap it in print().
        for layer in self.layers:
            layer.summary()

    @tf.function
    def train_step(self, x):
        with tf.GradientTape() as tape:
            z, z_mean, z_log_var = self.encoder(x)
            reconstruction = self.decoder(z)
            # Per-pixel binary cross-entropy, summed over the spatial axes
            # (1, 2), then averaged over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(x, reconstruction),
                    axis=(1, 2),
                )
            )
            # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and
            # the standard normal prior, summed over the latent dimensions.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            # A Python `if` on symbolic tensors fails inside tf.function;
            # clamp non-finite KL values (NaN/Inf) with tf.where instead.
            kl_loss = tf.where(tf.math.is_finite(kl_loss), kl_loss, tf.float32.max)
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
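

# --- Usage sketch (not part of the class above) ------------------------------
# A minimal example of wiring the VAE together, assuming 28x28 grayscale
# inputs (e.g. MNIST) and a 2-D latent space. The encoder must output
# (z, z_mean, z_log_var) in that order to match the unpacking in call() and
# train_step(); `Sampling`, `build_encoder`, and `build_decoder` are
# hypothetical helpers introduced here for illustration, not part of the
# original module.

import numpy as np
from tensorflow.keras import layers


class Sampling(layers.Layer):
    """Reparameterization trick: z = mean + exp(0.5 * log_var) * eps."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        eps = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * eps


def build_encoder(latent_dim=2):
    inputs = tf.keras.Input(shape=(28, 28, 1))
    x = layers.Conv2D(32, 3, strides=2, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(64, 3, strides=2, padding="same", activation="relu")(x)
    x = layers.Flatten()(x)
    x = layers.Dense(16, activation="relu")(x)
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    # Output order matches `z, z_mean, z_log_var = self.encoder(x)` above.
    return tf.keras.Model(inputs, [z, z_mean, z_log_var], name="encoder")


def build_decoder(latent_dim=2):
    latent_inputs = tf.keras.Input(shape=(latent_dim,))
    x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
    x = layers.Reshape((7, 7, 64))(x)
    x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(x)
    x = layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
    outputs = layers.Conv2DTranspose(1, 3, padding="same", activation="sigmoid")(x)
    return tf.keras.Model(latent_inputs, outputs, name="decoder")


if __name__ == "__main__":
    # Inputs are scaled to [0, 1] so binary cross-entropy is well defined.
    (x_train, _), _ = tf.keras.datasets.mnist.load_data()
    x_train = np.expand_dims(x_train, -1).astype("float32") / 255.0

    vae = VAE(build_encoder(), build_decoder())
    vae.compile(optimizer=tf.keras.optimizers.Adam())
    # The custom train_step takes only x; no targets are passed to fit().
    vae.fit(x_train, epochs=1, batch_size=128)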