| import tensorflow as tf
|
| import os
|
| import time
|
| import matplotlib.pyplot as plt
|
| from tf_dataset import get_dataset
|
| from tf_models import Generator, Discriminator
|
|
|
|
|
# Training hyper-parameters and dataset location.
EPOCHS: int = 10  # number of passes over the training split in main()
LAMBDA: int = 10  # weight of the cycle-consistency loss; the identity loss uses half of it
DATA_PATH: str = "data/horse2zebra"  # dataset root handed to get_dataset()
|
|
|
# The four CycleGAN networks. In train_step, generator_g maps domain X to Y
# (saved as "HtoZ") and generator_f maps Y to X (saved as "ZtoH").
# discriminator_x / discriminator_y score images from X / Y respectively.
# Tuple assignment evaluates left to right, so creation order is g, f, x, y.
generator_g, generator_f = Generator(), Generator()
discriminator_x, discriminator_y = Discriminator(), Discriminator()

# Shared adversarial loss. from_logits=True makes the loss treat discriminator
# outputs as raw logits (no sigmoid applied beforehand).
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
|
|
|
def discriminator_loss(real, generated):
    """Return the adversarial loss for one discriminator.

    ``real`` and ``generated`` are the discriminator's outputs on real and on
    generated images. Real outputs are scored against label 1 and generated
    outputs against label 0. The average of the two terms is returned.
    """
    real_term = loss_obj(tf.ones_like(real), real)
    fake_term = loss_obj(tf.zeros_like(generated), generated)
    return 0.5 * (real_term + fake_term)
|
|
|
def generator_loss(generated):
    """Return the adversarial loss for one generator.

    ``generated`` is a discriminator's output on generated images. The loss is
    low when the discriminator scores those images as real (label 1).
    """
    target = tf.ones_like(generated)
    return loss_obj(target, generated)
|
|
|
def calc_cycle_loss(real_image, cycled_image, weight=None):
    """Return the weighted L1 cycle-consistency loss.

    Args:
        real_image: the original input image tensor.
        cycled_image: the image after a round trip through both generators.
        weight: multiplier applied to the mean absolute error. ``None`` (the
            default) uses the module-level ``LAMBDA``, read at call time, so
            existing callers behave exactly as before.

    Returns:
        ``weight * mean(|real_image - cycled_image|)``.
    """
    if weight is None:
        weight = LAMBDA
    return weight * tf.reduce_mean(tf.abs(real_image - cycled_image))
|
|
|
def identity_loss(real_image, same_image, weight=None):
    """Return the weighted L1 identity loss.

    Args:
        real_image: an image already in the generator's output domain.
        same_image: the generator's output when fed ``real_image``.
        weight: multiplier applied to the mean absolute error. ``None`` (the
            default) uses ``LAMBDA * 0.5`` with ``LAMBDA`` read at call time,
            which matches the original hard-coded weighting.

    Returns:
        ``weight * mean(|real_image - same_image|)``.
    """
    if weight is None:
        # Same evaluation order as the original ``LAMBDA * 0.5 * loss``.
        weight = LAMBDA * 0.5
    return weight * tf.reduce_mean(tf.abs(real_image - same_image))
|
|
|
def _make_adam():
    """Build an Adam optimizer with the hyper-parameters shared by all four networks."""
    return tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


# One optimizer instance per network (creation order: g, f, x, y).
generator_g_optimizer = _make_adam()
generator_f_optimizer = _make_adam()
discriminator_x_optimizer = _make_adam()
discriminator_y_optimizer = _make_adam()
|
|
|
@tf.function
def train_step(real_x, real_y):
    """Run one CycleGAN optimisation step on one batch from each domain.

    generator_g maps domain X to Y and generator_f maps Y to X.
    discriminator_x and discriminator_y score images from X and Y. All four
    networks are updated from a single shared forward pass.

    Args:
        real_x: batch of images from domain X.
        real_y: batch of images from domain Y.

    Returns:
        A dict with keys ``"gen_g"``, ``"gen_f"``, ``"disc_x"`` and ``"disc_y"``
        holding the total loss each network was optimised on, for logging.
        Callers that ignore the return value are unaffected.
    """
    # persistent=True because the same tape is queried for four gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Cycle X -> Y -> X.
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        # Cycle Y -> X -> Y.
        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # Identity mappings: each generator is fed an image already in its
        # output domain.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        # Both cycles contribute to both generators' objectives.
        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Gradients are taken after leaving the tape context, so the backward pass
    # itself is not recorded on the persistent tape.
    generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)
    discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)

    # Every gradient is computed before any variable is updated, so all four
    # updates are based on the same forward pass.
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))

    return {
        "gen_g": total_gen_g_loss,
        "gen_f": total_gen_f_loss,
        "disc_x": disc_x_loss,
        "disc_y": disc_y_loss,
    }
|
|
|
def main(data_path=None, epochs=None):
    """Train the CycleGAN networks and save the two generators' weights.

    Args:
        data_path: dataset root passed to ``get_dataset``. ``None`` (default)
            uses the module-level ``DATA_PATH``.
        epochs: number of passes over the "train" split. ``None`` (default)
            uses the module-level ``EPOCHS``.

    After every epoch the generators' weights are written to
    ``GeneratorHtoZ_epoch_<epoch>.h5`` and ``GeneratorZtoH_epoch_<epoch>.h5``.
    After the final epoch they are also written to ``GeneratorHtoZ.h5`` and
    ``GeneratorZtoH.h5``. Discriminator weights are not saved.
    """
    if data_path is None:
        data_path = DATA_PATH
    if epochs is None:
        epochs = EPOCHS

    # batch(1): one (x, y) image pair per training step.
    train_dataset = get_dataset(data_path, "train").batch(1)

    for epoch in range(epochs):
        # perf_counter is monotonic, so system clock adjustments cannot skew
        # the measured epoch duration (time.time() can jump).
        start = time.perf_counter()
        print(f"Epoch {epoch} starting...")

        # Python's enumerate yields plain ints, so the progress check below
        # does not rely on the truthiness of an eager tensor.
        for n, (image_x, image_y) in enumerate(train_dataset):
            train_step(image_x, image_y)
            if n % 100 == 0:
                print('.', end='', flush=True)

        print(f"\nTime for epoch {epoch} is {time.perf_counter() - start} sec")

        # NOTE(review): Keras 3 (TF >= 2.16) requires save_weights filenames to
        # end in ".weights.h5"; confirm the installed version accepts ".h5".
        generator_g.save_weights(f"GeneratorHtoZ_epoch_{epoch}.h5")
        generator_f.save_weights(f"GeneratorZtoH_epoch_{epoch}.h5")

    generator_g.save_weights("GeneratorHtoZ.h5")
    generator_f.save_weights("GeneratorZtoH.h5")


if __name__ == "__main__":
    main()
|
|
|