import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
import numpy as np
import tensorflow_probability as tfp


# Width of the hidden layers and weight of the L2 penalty used in the
# coupling networks below.
output_dim = 256
reg = 0.01

def Coupling(input_shape):
    """Builds the scale (`s`) and translation (`t`) networks of one
    affine coupling layer as a single two-output Keras model."""
    input = keras.layers.Input(shape=input_shape)

    # Translation network: four hidden ReLU layers with a linear output.
    t_layer_1 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(input)
    t_layer_2 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_1)
    t_layer_3 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_2)
    t_layer_4 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_3)
    t_layer_5 = keras.layers.Dense(
        input_shape, activation="linear", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_4)

    # Scale network: same architecture, but a tanh output keeps the
    # log-scale bounded.
    s_layer_1 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(input)
    s_layer_2 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_1)
    s_layer_3 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_2)
    s_layer_4 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_3)
    s_layer_5 = keras.layers.Dense(
        input_shape, activation="tanh", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_4)

    return keras.Model(inputs=input, outputs=[s_layer_5, t_layer_5])

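
# A minimal usage sketch (illustrative only, not part of the original
# module): a coupling network built for 2-D inputs returns scale and
# translation tensors with the same shape as the input batch.
def _demo_coupling():
    coupling = Coupling(2)
    s, t = coupling(tf.zeros((8, 2)))  # both `s` and `t` have shape (8, 2)
    return s, t
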
class RealNVP(keras.Model):
    def __init__(self, num_coupling_layers):
        super(RealNVP, self).__init__()

        self.num_coupling_layers = num_coupling_layers

        # Distribution of the latent space: a standard 2-D Gaussian.
        self.distribution = tfp.distributions.MultivariateNormalDiag(
            loc=[0.0, 0.0], scale_diag=[1.0, 1.0]
        )
        # Alternating binary masks ([0, 1], [1, 0], [0, 1], ...): each coupling
        # layer transforms the dimension that the previous layer left unchanged.
        self.masks = np.array(
            [[0, 1], [1, 0]] * (num_coupling_layers // 2), dtype="float32"
        )
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.layers_list = [Coupling(2) for _ in range(num_coupling_layers)]

    @property
    def metrics(self):
        """List of the model's metrics.

        We make sure the loss tracker is listed as part of `model.metrics`
        so that `fit()` and `evaluate()` are able to reset the loss tracker
        at the start of each epoch and at the start of an `evaluate()` call.
        """
        return [self.loss_tracker]

    def call(self, x, training=True):
        log_det_inv = 0
        direction = 1
        if training:
            direction = -1
        # `direction` selects the pass: -1 maps data to the latent space
        # (used during training), +1 maps latent samples back to data space.
        # The inverse pass visits the coupling layers in reverse order.
        for i in range(self.num_coupling_layers)[::direction]:
            x_masked = x * self.masks[i]
            reversed_mask = 1 - self.masks[i]
            s, t = self.layers_list[i](x_masked)
            s *= reversed_mask
            t *= reversed_mask
            # `gate` is 0 on the forward pass and -1 on the inverse pass, so
            # the same expression computes y = x * exp(s) + t and its inverse
            # x = (y - t) * exp(-s), plus the signed log-determinant below.
            gate = (direction - 1) / 2
            x = (
                reversed_mask
                * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))
                + x_masked
            )
            log_det_inv += gate * tf.reduce_sum(s, [1])

        return x, log_det_inv

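    # `log_loss` below implements the change-of-variables formula
    #
    #     log p_X(x) = log p_Z(f(x)) + log |det J_f(x)|,
    #
    # where f maps data to the latent space: `logdet` accumulates the
    # per-layer log-determinant terms (-sum(s) for each coupling layer),
    # and the loss is the mean negative log-likelihood.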
    def log_loss(self, x):
        y, logdet = self(x)
        log_likelihood = self.distribution.log_prob(y) + logdet
        return -tf.reduce_mean(log_likelihood)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self.log_loss(data)

        g = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(g, self.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        loss = self.log_loss(data)
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}

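
# Hypothetical smoke test (not part of the original module): the flow is
# bijective, so mapping data to the latent space and back should recover
# the input up to floating-point error.
def _demo_invertibility():
    model = RealNVP(num_coupling_layers=6)
    x = tf.random.normal((4, 2))
    z, _ = model(x, training=True)       # data -> latent (inverse pass)
    x_rec, _ = model(z, training=False)  # latent -> data (forward pass)
    tf.debugging.assert_near(x, x_rec, atol=1e-4)
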
def load_model():
    return RealNVP(num_coupling_layers=6)
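
# A minimal end-to-end sketch, assuming a dataset of 2-D points (e.g.
# sklearn's `make_moons` output, normalized). The optimizer choice and the
# fit hyperparameters are illustrative assumptions, not part of this module.
def _demo_training(data):
    model = load_model()
    # `train_step` computes the loss internally, so `compile` only needs
    # an optimizer.
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4))
    model.fit(data, batch_size=256, epochs=300, validation_split=0.2)

    # To sample from the learned density, draw latent points and run the
    # forward (training=False) pass.
    z = model.distribution.sample(1000)
    samples, _ = model(z, training=False)
    return samples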