Spaces:
Runtime error
Runtime error
| import os | |
| import keras.regularizers | |
| import tensorflow as tf | |
| from keras.layers import InputLayer, Conv2D, Flatten, BatchNormalization, Dense, UpSampling2D, Reshape, Dropout, Add | |
| import keras.backend as tfkbk | |
| import numpy as np | |
| from blocks import ResidualBlock | |
| from keras.layers import LeakyReLU, PReLU | |
# Width of the latent vector fed to the decoder; the encoder's final Dense
# layer emits twice this many units.
LATENT_DIM = 512

# Spatial size (two dimensions) of the single-channel encoder input.
INPUT_SHAPE = (64, 64)
def get_encoder():
    """Build the convolutional encoder as a Keras ``Sequential`` model.

    The model takes a single-channel input of spatial size ``INPUT_SHAPE``
    and applies six 3x3 convolutions with PReLU activations: pairs of 32, 64
    and 128 filters, where the second convolution of each pair uses stride 2.
    The feature map is then flattened and projected by a Dense layer to
    ``LATENT_DIM * 2`` units (``CVAE.call`` splits that output in two halves).

    Returns:
        tf.keras.Sequential: The assembled model, named ``"encoder"``.
    """
    # (filters, strides) for each convolution, in order.
    conv_specs = (
        (32, 1), (32, 2),
        (64, 1), (64, 2),
        (128, 1), (128, 2),
    )

    stack = [InputLayer(input_shape=(*INPUT_SHAPE, 1))]
    stack.extend(
        Conv2D(filters, 3, activation=PReLU(), padding='same', strides=strides,
               kernel_initializer='he_uniform')
        for filters, strides in conv_specs
    )
    stack.append(Flatten())
    # The L2 activity regularizer contributes to ``encoder.losses``.
    stack.append(Dense(LATENT_DIM * 2, activation=PReLU(),
                       activity_regularizer=tf.keras.regularizers.L2(10e-6)))

    return tf.keras.Sequential(stack, name="encoder")
def get_decoder():
    """Build the decoder as a functional Keras model.

    A latent vector of length ``LATENT_DIM`` goes through two ReLU Dense
    layers of ``8 * 8 * 16`` units and is reshaped to ``(8, 8, 16)``. Three
    stages follow, each made of a 2x upsampling, a 3x3 LeakyReLU convolution
    and two ``ResidualBlock`` layers (128, 64 and 32 filters respectively).
    A final 3x3 convolution with one filter and no activation produces the
    output.

    Returns:
        tf.keras.Model: Model mapping the latent input to the decoded tensor.
    """

    def upsample_stage(tensor, filters, block_names):
        # One decoder stage: upsample, convolve, then two residual blocks.
        tensor = UpSampling2D(2)(tensor)
        tensor = Conv2D(filters, 3, activation=LeakyReLU(), padding='same',
                        kernel_initializer='he_uniform')(tensor)
        for block_name in block_names:
            tensor = ResidualBlock(filters, 3, seed=42, name=block_name,
                                   padding="reflect")(tensor)
        return tensor

    latent = tf.keras.layers.Input(shape=[LATENT_DIM, ])

    hidden = latent
    for _ in range(2):
        hidden = Dense(8 * 8 * 16, activation='relu')(hidden)
    hidden = Reshape(target_shape=(8, 8, 16))(hidden)

    # Block names skip "res3" and "res6"; they are kept exactly as they were.
    stages = (
        (128, ("res1", "res2")),
        (64, ("res4", "res5")),
        (32, ("res7", "res8")),
    )
    for filters, block_names in stages:
        hidden = upsample_stage(hidden, filters, block_names)

    decoded = Conv2D(1, 3, padding='same', kernel_initializer='he_uniform')(hidden)
    return tf.keras.Model(inputs=latent, outputs=decoded)
class CVAE(tf.keras.Model):
    """Variational autoencoder built from a separate encoder and decoder.

    The encoder output is split in half along axis 1 into a mean and a
    log-variance. A latent sample drawn from them is fed to the decoder.
    When the model is called with a truthy ``training`` flag, the loss
    selected by ``loss_fun`` is registered through ``add_loss``.
    """

    def __init__(self, encoder: tf.keras.models.Model, decoder: tf.keras.models.Model,
                 latent_dim, kl_weight=1, loss_fun='bce', include_regularization: bool = False):
        """Store the sub-models and the loss configuration.

        Args:
            encoder: Model whose output is split into two equal halves along
                axis 1 (mean and log-variance).
            decoder: Model mapping a latent sample to a reconstruction.
            latent_dim: Size of the latent vector, used when sampling noise.
            kl_weight: Multiplier applied to the KL term ('mse'/'bce' only).
            loss_fun: One of ``'elbo'``, ``'mse'`` or ``'bce'``.
            include_regularization: If true, the sum of ``encoder.losses`` is
                added to the training loss.
        """
        super().__init__()
        self.kl_weight = kl_weight
        self.latent_dim = latent_dim
        self.loss_fun = loss_fun
        self.encoder = encoder
        self.decoder = decoder
        # Last computed loss terms; only refreshed by the 'mse'/'bce' branch.
        self.kl_loss = 0
        self.reconstruction_loss = 0
        self.include_regularization = include_regularization

    def call(self, inputs, training=None, mask=None):
        """Encode, sample and decode; register the loss when training.

        Args:
            inputs: Batch fed to the encoder and used as reconstruction target.
            training: When truthy, the VAE loss is computed and added via
                ``add_loss``.
            mask: Unused.

        Returns:
            The decoder output for the sampled latent vectors.

        Raises:
            ValueError: If training and ``loss_fun`` is not ``'elbo'``,
                ``'mse'`` or ``'bce'``.
        """
        z_mean, z_log_var = tf.split(self.encoder(inputs), num_or_size_splits=2, axis=1)
        z = self.sampling(z_mean, z_log_var, self.latent_dim)
        outputs = self.decoder(z)

        if training:
            if self.loss_fun == 'elbo':
                # Single-sample Monte Carlo ELBO; decoder outputs are treated
                # as logits here.
                cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=outputs, labels=inputs)
                logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
                logpz = self.log_normal_pdf(z, 0., 0.)
                logqz_x = self.log_normal_pdf(z, z_mean, z_log_var)
                vae_loss = -tf.reduce_mean(logpx_z + logpz - logqz_x)
            else:
                # Closed-form KL divergence between N(z_mean, exp(z_log_var))
                # and a standard normal, scaled by kl_weight.
                kl_loss = 1 + z_log_var - tf.math.square(z_mean) - tf.math.exp(z_log_var)
                kl_loss = tf.math.reduce_sum(kl_loss, axis=-1)
                kl_loss *= -0.5 * self.kl_weight
                self.kl_loss = kl_loss

                if self.loss_fun == 'mse':
                    reconstruction_loss = tf.keras.metrics.mean_squared_error(
                        tfkbk.flatten(inputs), tfkbk.flatten(outputs))
                elif self.loss_fun == 'bce':
                    # NOTE(review): binary_crossentropy is called with its
                    # default from_logits=False, while the 'elbo' branch
                    # treats the same decoder output as logits. Confirm the
                    # decoder emits probabilities when 'bce' is used.
                    reconstruction_loss = tf.keras.metrics.binary_crossentropy(
                        tfkbk.flatten(inputs), tfkbk.flatten(outputs))
                else:
                    raise ValueError(
                        f"Unsupported loss_fun {self.loss_fun!r}; "
                        "expected 'elbo', 'mse' or 'bce'.")

                # Scale the per-element mean by height * width. The previous
                # code multiplied shape[1] by itself, which is only correct
                # for square inputs.
                reconstruction_loss *= (inputs.shape[1] * inputs.shape[2])
                self.reconstruction_loss = reconstruction_loss
                vae_loss = tf.math.reduce_mean(reconstruction_loss + kl_loss)

            if self.include_regularization:
                vae_loss += tf.math.reduce_sum(self.encoder.losses)
            self.add_loss(vae_loss)

        return outputs

    @staticmethod
    def sampling(z_mean, z_log_var, latent_dim):
        """Draw a latent sample as ``z_mean + exp(0.5 * z_log_var) * epsilon``.

        Declared static because it takes no ``self`` but is invoked through
        the instance in ``call``.

        Args:
            z_mean: Mean tensor; its first dimension is taken as batch size.
            z_log_var: Log-variance tensor matching ``z_mean``.
            latent_dim: Second dimension of the generated noise.

        Returns:
            The sampled latent tensor.
        """
        batch = tf.shape(z_mean)[0]
        epsilon = tf.keras.backend.random_normal(shape=(batch, latent_dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    @staticmethod
    def log_normal_pdf(sample, mean, logvar, raxis=1):
        """Return the log-density of a diagonal Gaussian, summed over ``raxis``.

        Declared static because it takes no ``self`` but is invoked through
        the instance in ``call``.

        Args:
            sample: Points at which the density is evaluated.
            mean: Mean of the Gaussian (tensor or scalar).
            logvar: Log-variance of the Gaussian (tensor or scalar).
            raxis: Axis over which the per-dimension terms are summed.

        Returns:
            Tensor of summed log-densities.
        """
        log2pi = tf.math.log(2. * np.pi)
        return tf.reduce_sum(
            -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
            axis=raxis)