File size: 5,456 Bytes
91f8c72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os

import keras.regularizers
import tensorflow as tf
from keras.layers import InputLayer, Conv2D, Flatten, BatchNormalization, Dense, UpSampling2D, Reshape, Dropout, Add
import keras.backend as tfkbk
import numpy as np
from blocks import ResidualBlock
from keras.layers import LeakyReLU, PReLU

# Spatial size of the single-channel images accepted by the encoder; a
# trailing channel axis of 1 is appended when building its InputLayer.
INPUT_SHAPE = (64, 64)
# Size of the latent space. The encoder's final Dense layer emits twice this
# many values (split into mean and log-variance); the decoder consumes vectors
# of exactly this length.
LATENT_DIM = 512


def get_encoder():
    """Build the convolutional encoder as a ``Sequential`` model named "encoder".

    The network has three stages with 32, 64 and 128 filters. Each stage is a
    3x3 PReLU convolution followed by a stride-2 3x3 PReLU convolution, so a
    64x64 input is reduced to 8x8 before flattening. A final PReLU dense layer
    (with an L2 activity regularizer) emits ``2 * LATENT_DIM`` values, which
    ``CVAE.call`` splits into the posterior mean and log-variance.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential(name="encoder")
    model.add(InputLayer(input_shape=(*INPUT_SHAPE, 1)))

    # Per stage: a stride-1 conv, then a stride-2 conv that halves height/width.
    for n_filters in (32, 64, 128):
        for step in (1, 2):
            model.add(Conv2D(n_filters, 3, activation=PReLU(), padding='same',
                             strides=step, kernel_initializer='he_uniform'))

    model.add(Flatten())
    # NOTE(review): 10e-6 is 1e-5, not 1e-6 -- confirm the intended weight.
    model.add(Dense(LATENT_DIM * 2, activation=PReLU(),
                    activity_regularizer=tf.keras.regularizers.L2(10e-6)))
    return model


def get_decoder():
    """Build the functional decoder mapping a ``LATENT_DIM`` vector to an image.

    Two ReLU dense layers produce an 8x8x16 feature map. That map is then
    upsampled three times (8 -> 16 -> 32 -> 64). Each stage applies a
    LeakyReLU 3x3 convolution followed by two ``ResidualBlock`` layers. A
    final 3x3 convolution with no activation yields one output channel, so the
    outputs are unbounded (logits for sigmoid-based losses).

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    latent = tf.keras.layers.Input(shape=[LATENT_DIM, ])

    h = latent
    for _ in range(2):
        h = Dense(8 * 8 * 16, activation='relu')(h)
    h = Reshape(target_shape=(8, 8, 16))(h)

    # (filters, residual-block layer names) per upsampling stage. The names
    # skip res3/res6 but are kept verbatim: they become layer names, which
    # saved-weight loading presumably relies on.
    stages = (
        (128, ("res1", "res2")),
        (64, ("res4", "res5")),
        (32, ("res7", "res8")),
    )
    for n_filters, block_names in stages:
        h = UpSampling2D(2)(h)
        h = Conv2D(n_filters, 3, activation=LeakyReLU(), padding='same',
                   kernel_initializer='he_uniform')(h)
        for block_name in block_names:
            h = ResidualBlock(n_filters, 3, seed=42, name=block_name,
                              padding="reflect")(h)

    h = Conv2D(1, 3, padding='same', kernel_initializer='he_uniform')(h)

    return tf.keras.Model(inputs=latent, outputs=h)


class CVAE(tf.keras.Model):
    """Variational autoencoder built from a separate encoder and decoder model.

    The encoder's output is split in half along axis 1 into the posterior mean
    and log-variance. A latent sample is drawn with the reparameterization
    trick and decoded. When the model is called with a truthy ``training`` it
    registers the VAE loss through ``add_loss``, so ``compile``/``fit`` can be
    used without passing an explicit loss.

    Args:
        encoder: model whose output concatenates ``[z_mean, z_log_var]`` along
            axis 1 (so its width should be ``2 * latent_dim``).
        decoder: model mapping a latent sample to a reconstruction with the
            same shape as the inputs. The ``'elbo'`` loss, and ``'bce'`` with
            ``from_logits=True``, interpret its output as per-pixel logits.
        latent_dim: size of the latent space; used as the width of the
            sampled noise.
        kl_weight: multiplier on the KL term. Only used by ``'mse'``/``'bce'``.
        loss_fun: one of ``'elbo'``, ``'mse'`` or ``'bce'``.
        include_regularization: if True, add ``sum(encoder.losses)`` to the
            VAE loss.
        from_logits: for ``'bce'`` only. If True, decoder outputs are treated
            as logits; if False, as probabilities already in [0, 1]. The
            default matches the decoder built by ``get_decoder``, which has no
            output activation.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: if ``loss_fun`` is not one of the supported names.
    """

    _LOSS_FUNCTIONS = ('elbo', 'mse', 'bce')

    def __init__(self, encoder: tf.keras.models.Model, decoder: tf.keras.models.Model,
                 latent_dim, kl_weight=1, loss_fun='bce', include_regularization: bool = False,
                 from_logits: bool = True, **kwargs):
        super(CVAE, self).__init__(**kwargs)
        # Fail at construction rather than on the first training step.
        if loss_fun not in self._LOSS_FUNCTIONS:
            raise ValueError(
                f"Unknown loss_fun {loss_fun!r}; expected one of {self._LOSS_FUNCTIONS}")
        self.kl_weight = kl_weight
        self.latent_dim = latent_dim
        self.loss_fun = loss_fun
        self.encoder = encoder
        self.decoder = decoder
        # Latest loss terms from a training call on the 'mse'/'bce' path:
        # per-sample KL (shape (batch,)) and the scaled reconstruction term.
        self.kl_loss = 0
        self.reconstruction_loss = 0
        self.include_regularization = include_regularization
        self.from_logits = from_logits

    def call(self, inputs, training=None, mask=None):
        """Encode, sample and decode ``inputs``; add the VAE loss when training.

        Returns:
            The raw decoder output for ``inputs``.
        """
        z_mean, z_log_var = tf.split(self.encoder(inputs), num_or_size_splits=2, axis=1)
        z = self.sampling(z_mean, z_log_var, self.latent_dim)
        outputs = self.decoder(z)

        if training:
            if self.loss_fun == 'elbo':
                vae_loss = self._elbo_loss(inputs, outputs, z, z_mean, z_log_var)
            else:
                vae_loss = self._kl_reconstruction_loss(inputs, outputs, z_mean, z_log_var)

            if self.include_regularization:
                # NOTE(review): Keras already folds sub-layer losses (including
                # the encoder's activity regularizer) into the training loss
                # via `Model.losses`, so this likely counts them twice -- confirm.
                vae_loss += tf.math.reduce_sum(self.encoder.losses)

            self.add_loss(vae_loss)
        return outputs

    def _elbo_loss(self, inputs, logits, z, z_mean, z_log_var):
        """Return the negative single-sample Monte Carlo ELBO, averaged over the batch.

        ``logits`` are treated as Bernoulli logits over the pixels of
        ``inputs``; the reduction over axes 1-3 assumes rank-4 image batches.
        """
        cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=inputs)
        logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
        logpz = self.log_normal_pdf(z, 0., 0.)  # standard-normal prior
        logqz_x = self.log_normal_pdf(z, z_mean, z_log_var)
        return -tf.reduce_mean(logpx_z + logpz - logqz_x)

    def _kl_reconstruction_loss(self, inputs, outputs, z_mean, z_log_var):
        """Return the mean of (weighted analytic KL + MSE/BCE reconstruction loss).

        Also stores the KL and reconstruction terms on ``self`` for inspection.
        """
        # Closed-form KL(q(z|x) || N(0, I)) per sample, scaled by kl_weight.
        kl_loss = 1 + z_log_var - tf.math.square(z_mean) - tf.math.exp(z_log_var)
        kl_loss = tf.math.reduce_sum(kl_loss, axis=-1)
        kl_loss *= -0.5 * self.kl_weight
        self.kl_loss = kl_loss

        # Flattening includes the batch axis, so each loss below is a single
        # mean over every element in the batch.
        y_true = tfkbk.flatten(inputs)
        y_pred = tfkbk.flatten(outputs)
        if self.loss_fun == 'mse':
            reconstruction_loss = tf.keras.metrics.mean_squared_error(y_true, y_pred)
        elif self.loss_fun == 'bce':
            reconstruction_loss = tf.keras.metrics.binary_crossentropy(
                y_true, y_pred, from_logits=self.from_logits)
        else:
            raise ValueError(
                f"Unknown loss_fun {self.loss_fun!r}; expected one of {self._LOSS_FUNCTIONS}")

        # Turn the per-element mean into a per-sample sum so it is on the same
        # scale as the summed KL term. Use every non-batch dim (H * W * C), not
        # H * H, so non-square or multi-channel inputs are scaled correctly.
        elements_per_sample = tf.reduce_prod(tf.shape(inputs)[1:])
        reconstruction_loss *= tf.cast(elements_per_sample, reconstruction_loss.dtype)
        self.reconstruction_loss = reconstruction_loss
        return tf.math.reduce_mean(reconstruction_loss + kl_loss)

    @staticmethod
    def sampling(z_mean, z_log_var, latent_dim):
        """Draw ``z = z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, I)``.

        ``eps`` has shape ``(batch, latent_dim)``, where ``batch`` is the
        dynamic leading dimension of ``z_mean``.
        """
        batch = tf.shape(z_mean)[0]
        epsilon = tf.keras.backend.random_normal(shape=(batch, latent_dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    @staticmethod
    def log_normal_pdf(sample, mean, logvar, raxis=1):
        """Return the diagonal-Gaussian log density of ``sample``, summed over ``raxis``.

        ``logvar`` is the log-variance; ``mean`` and ``logvar`` broadcast
        against ``sample``.
        """
        log2pi = tf.math.log(2. * np.pi)
        return tf.reduce_sum(
            -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
            axis=raxis)