| """ | |
| Title: CycleGAN | |
| Author: [A_K_Nain](https://twitter.com/A_K_Nain) | |
| Date created: 2020/08/12 | |
| Last modified: 2024/09/30 | |
| Description: Implementation of CycleGAN. | |
| Accelerator: GPU | |
| """ | |
| """ | |
| ## CycleGAN | |
| CycleGAN is a model that aims to solve the image-to-image translation | |
| problem. The goal of the image-to-image translation problem is to learn the | |
| mapping between an input image and an output image using a training set of | |
| aligned image pairs. However, obtaining paired examples isn't always feasible. | |
| CycleGAN tries to learn this mapping without requiring paired input-output images, | |
| using cycle-consistent adversarial networks. | |
| - [Paper](https://arxiv.org/abs/1703.10593) | |
| - [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) | |
| """ | |
| """ | |
| ## Setup | |
| """ | |
| import os | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| import keras | |
| from keras import layers, ops | |
| import tensorflow_datasets as tfds | |
| tfds.disable_progress_bar() | |
| autotune = tf.data.AUTOTUNE | |
| os.environ["KERAS_BACKEND"] = "tensorflow" | |
| """ | |
| ## Prepare the dataset | |
| In this example, we will be using the | |
| [horse to zebra](https://www.tensorflow.org/datasets/catalog/cycle_gan#cycle_ganhorse2zebra) | |
| dataset. | |
| """ | |
# Load the horse-zebra dataset using tensorflow-datasets.
dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
test_horses, test_zebras = dataset["testA"], dataset["testB"]

# Define the standard image size.
orig_img_size = (286, 286)
# Size of the random crops to be used during training.
input_img_size = (256, 256, 3)
# Weights initializer for the layers.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# Gamma initializer for instance normalization.
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

buffer_size = 256
batch_size = 1


def normalize_img(img):
    img = ops.cast(img, dtype=tf.float32)
    # Map values to the range [-1, 1]
    return (img / 127.5) - 1.0


def preprocess_train_image(img, label):
    # Random flip
    img = tf.image.random_flip_left_right(img)
    # Resize to the original size first
    img = ops.image.resize(img, [*orig_img_size])
    # Random crop to 256x256
    img = tf.image.random_crop(img, size=[*input_img_size])
    # Normalize the pixel values to the range [-1, 1]
    img = normalize_img(img)
    return img


def preprocess_test_image(img, label):
    # Only resizing and normalization for the test images.
    img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])
    img = normalize_img(img)
    return img
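"""
As a quick, illustrative sanity check (an addition to the original example), we
can run a dummy image through the training preprocessing and confirm that it
comes out as a 256x256x3 tensor with values in [-1, 1]:
"""

# A black 300x300 dummy image should yield a 256x256 crop of all -1.0 values.
_dummy_img = preprocess_train_image(tf.zeros((300, 300, 3), dtype=tf.uint8), label=None)
print(_dummy_img.shape, float(ops.min(_dummy_img)), float(ops.max(_dummy_img)))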
| """ | |
| ## Create `Dataset` objects | |
| """ | |
| # Apply the preprocessing operations to the training data | |
| train_horses = ( | |
| train_horses.map(preprocess_train_image, num_parallel_calls=autotune) | |
| .cache() | |
| .shuffle(buffer_size) | |
| .batch(batch_size) | |
| ) | |
| train_zebras = ( | |
| train_zebras.map(preprocess_train_image, num_parallel_calls=autotune) | |
| .cache() | |
| .shuffle(buffer_size) | |
| .batch(batch_size) | |
| ) | |
| # Apply the preprocessing operations to the test data | |
| test_horses = ( | |
| test_horses.map(preprocess_test_image, num_parallel_calls=autotune) | |
| .cache() | |
| .shuffle(buffer_size) | |
| .batch(batch_size) | |
| ) | |
| test_zebras = ( | |
| test_zebras.map(preprocess_test_image, num_parallel_calls=autotune) | |
| .cache() | |
| .shuffle(buffer_size) | |
| .batch(batch_size) | |
| ) | |
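"""
Illustrative check (an addition to the original example): after mapping,
caching, shuffling, and batching, each pipeline yields float32 batches of shape
`(batch_size, 256, 256, 3)`:
"""

print(train_horses.element_spec)  # TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, ...)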
| """ | |
| ## Visualize some samples | |
| """ | |
| _, ax = plt.subplots(4, 2, figsize=(10, 15)) | |
| for i, samples in enumerate(zip(train_horses.take(4), train_zebras.take(4))): | |
| horse = (((samples[0][0] * 127.5) + 127.5).numpy()).astype(np.uint8) | |
| zebra = (((samples[1][0] * 127.5) + 127.5).numpy()).astype(np.uint8) | |
| ax[i, 0].imshow(horse) | |
| ax[i, 1].imshow(zebra) | |
| plt.show() | |
| """ | |
| ## Building blocks used in the CycleGAN generators and discriminators | |
| """ | |
| class ReflectionPadding2D(layers.Layer): | |
| """Implements Reflection Padding as a layer. | |
| Args: | |
| padding(tuple): Amount of padding for the | |
| spatial dimensions. | |
| Returns: | |
| A padded tensor with the same type as the input tensor. | |
| """ | |
| def __init__(self, padding=(1, 1), **kwargs): | |
| self.padding = tuple(padding) | |
| super().__init__(**kwargs) | |
| def call(self, input_tensor, mask=None): | |
| padding_width, padding_height = self.padding | |
| padding_tensor = [ | |
| [0, 0], | |
| [padding_height, padding_height], | |
| [padding_width, padding_width], | |
| [0, 0], | |
| ] | |
| return ops.pad(input_tensor, padding_tensor, mode="REFLECT") | |
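"""
A brief, illustrative usage check (not in the original example): reflection
padding of `(3, 3)` grows each spatial dimension by 6, so a 256x256 input
becomes 262x262 while the batch and channel dimensions are unchanged:
"""

print(ReflectionPadding2D(padding=(3, 3))(tf.zeros((1, 256, 256, 3))).shape)  # (1, 262, 262, 3)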


def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    dim = x.shape[-1]
    input_tensor = x

    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
        x
    )
    x = activation(x)

    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
        x
    )
    x = layers.add([input_tensor, x])
    return x


def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
        x
    )
    if activation:
        x = activation(x)
    return x


def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    x = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )(x)
    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
        x
    )
    if activation:
        x = activation(x)
    return x
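"""
As an illustrative shape check (an addition to the original example): with
stride 2, `downsample` halves the spatial resolution while `upsample` doubles
it, which is what lets the generator recover the input resolution later on:
"""

print(downsample(tf.zeros((1, 256, 256, 3)), filters=64, activation=None).shape)  # (1, 128, 128, 64)
print(upsample(tf.zeros((1, 128, 128, 64)), filters=32, activation=None).shape)  # (1, 256, 256, 32)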
| """ | |
| ## Build the generators | |
| The generator consists of downsampling blocks: nine residual blocks | |
| and upsampling blocks. The structure of the generator is the following: | |
| ``` | |
| c7s1-64 ==> Conv block with `relu` activation, filter size of 7 | |
| d128 ====| | |
| |-> 2 downsampling blocks | |
| d256 ====| | |
| R256 ====| | |
| R256 | | |
| R256 | | |
| R256 | | |
| R256 |-> 9 residual blocks | |
| R256 | | |
| R256 | | |
| R256 | | |
| R256 ====| | |
| u128 ====| | |
| |-> 2 upsampling blocks | |
| u64 ====| | |
| c7s1-3 => Last conv block with `tanh` activation, filter size of 7. | |
| ``` | |
| """ | |


def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
        x
    )
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(x, filters=filters, activation=layers.Activation("relu"))

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(x, filters, activation=layers.Activation("relu"))

    # Final block
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = keras.models.Model(img_input, x, name=name)
    return model
| """ | |
| ## Build the discriminators | |
| The discriminators implement the following architecture: | |
| `C64->C128->C256->C512` | |
| """ | |


def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    for num_downsample_block in range(num_downsampling):
        num_filters *= 2
        if num_downsample_block < 2:
            x = downsample(
                x,
                filters=num_filters,
                activation=layers.LeakyReLU(0.2),
                kernel_size=(4, 4),
                strides=(2, 2),
            )
        else:
            x = downsample(
                x,
                filters=num_filters,
                activation=layers.LeakyReLU(0.2),
                kernel_size=(4, 4),
                strides=(1, 1),
            )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model


# Get the generators
gen_G = get_resnet_generator(name="generator_G")
gen_F = get_resnet_generator(name="generator_F")

# Get the discriminators
disc_X = get_discriminator(name="discriminator_X")
disc_Y = get_discriminator(name="discriminator_Y")
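"""
An illustrative check (an addition to the original example): the generators are
shape-preserving image-to-image maps, while each discriminator is a PatchGAN
that outputs a grid of per-patch real/fake scores rather than a single scalar:
"""

_dummy_batch = tf.zeros((1, *input_img_size))
print(gen_G(_dummy_batch).shape)   # (1, 256, 256, 3)
print(disc_X(_dummy_batch).shape)  # (1, 32, 32, 1) -- one score per image patch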
| """ | |
| ## Build the CycleGAN model | |
| We will override the `train_step()` method of the `Model` class | |
| for training via `fit()`. | |
| """ | |


class CycleGan(keras.Model):
    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        super().__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def call(self, inputs):
        return (
            self.disc_X(inputs),
            self.disc_Y(inputs),
            self.gen_G(inputs),
            self.gen_F(inputs),
        )

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        super().compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = keras.losses.MeanAbsoluteError()

    def train_step(self, batch_data):
        # x is a horse and y is a zebra
        real_x, real_y = batch_data

        # For CycleGAN, we need to calculate different
        # kinds of losses for the generators and discriminators.
        # We will perform the following steps here:
        #
        # 1. Pass real images through the generators and get the generated images
        # 2. Pass the generated images back to the generators to check if we
        #    can predict the original image from the generated image.
        # 3. Do an identity mapping of the real images using the generators.
        # 4. Pass the generated images in 1) to the corresponding discriminators.
        # 5. Calculate the generators' total loss (adversarial + cycle + identity)
        # 6. Calculate the discriminators' loss
        # 7. Update the weights of the generators
        # 8. Update the weights of the discriminators
        # 9. Return the losses in a dictionary

        with tf.GradientTape(persistent=True) as tape:
            # Horse to fake zebra
            fake_y = self.gen_G(real_x, training=True)
            # Zebra to fake horse -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (Horse to fake zebra to fake horse): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (Zebra to fake horse to fake zebra): y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)
            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Update the weights of the generators
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }
| """ | |
| ## Create a callback that periodically saves generated images | |
| """ | |
| class GANMonitor(keras.callbacks.Callback): | |
| """A callback to generate and save images after each epoch""" | |
| def __init__(self, num_img=4): | |
| self.num_img = num_img | |
| def on_epoch_end(self, epoch, logs=None): | |
| _, ax = plt.subplots(4, 2, figsize=(12, 12)) | |
| for i, img in enumerate(test_horses.take(self.num_img)): | |
| prediction = self.model.gen_G(img)[0].numpy() | |
| prediction = (prediction * 127.5 + 127.5).astype(np.uint8) | |
| img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8) | |
| ax[i, 0].imshow(img) | |
| ax[i, 1].imshow(prediction) | |
| ax[i, 0].set_title("Input image") | |
| ax[i, 1].set_title("Translated image") | |
| ax[i, 0].axis("off") | |
| ax[i, 1].axis("off") | |
| prediction = keras.utils.array_to_img(prediction) | |
| prediction.save( | |
| "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1) | |
| ) | |
| plt.show() | |
| plt.close() | |
| """ | |
| ## Train the end-to-end model | |
| """ | |
| # Loss function for evaluating adversarial loss | |
| adv_loss_fn = keras.losses.MeanSquaredError() | |
| # Define the loss function for the generators | |
| def generator_loss_fn(fake): | |
| fake_loss = adv_loss_fn(ops.ones_like(fake), fake) | |
| return fake_loss | |
| # Define the loss function for the discriminators | |
| def discriminator_loss_fn(real, fake): | |
| real_loss = adv_loss_fn(ops.ones_like(real), real) | |
| fake_loss = adv_loss_fn(ops.zeros_like(fake), fake) | |
| return (real_loss + fake_loss) * 0.5 | |
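"""
These are the least-squares GAN losses. As an illustrative check (not part of
the original example), a discriminator output of all ones means the generator
has fully fooled it (loss 0), while all zeros gives the maximum loss of 1:
"""

print(float(generator_loss_fn(tf.ones((1, 32, 32, 1)))))   # 0.0
print(float(generator_loss_fn(tf.zeros((1, 32, 32, 1)))))  # 1.0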


# Create cycle gan model
cycle_gan_model = CycleGan(
    generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)

# Compile the model
cycle_gan_model.compile(
    gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

# Callbacks
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath, save_weights_only=True
)

# Here we train the model for 90 epochs. Each epoch takes around
# 7 minutes on a single P100-backed machine.
cycle_gan_model.fit(
    tf.data.Dataset.zip((train_horses, train_zebras)),
    epochs=90,
    callbacks=[plotter, model_checkpoint_callback],
)
| """ | |
| Test the performance of the model. | |
| """ | |
| # Once the weights are loaded, we will take a few samples from the test data and check the model's performance. | |
| # Load the checkpoints | |
| cycle_gan_model.load_weights(checkpoint_filepath) | |
| print("Weights loaded successfully") | |
| _, ax = plt.subplots(4, 2, figsize=(10, 15)) | |
| for i, img in enumerate(test_horses.take(4)): | |
| prediction = cycle_gan_model.gen_G(img, training=False)[0].numpy() | |
| prediction = (prediction * 127.5 + 127.5).astype(np.uint8) | |
| img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8) | |
| ax[i, 0].imshow(img) | |
| ax[i, 1].imshow(prediction) | |
| ax[i, 0].set_title("Input image") | |
| ax[i, 0].set_title("Input image") | |
| ax[i, 1].set_title("Translated image") | |
| ax[i, 0].axis("off") | |
| ax[i, 1].axis("off") | |
| prediction = keras.utils.array_to_img(prediction) | |
| prediction.save("predicted_img_{i}.png".format(i=i)) | |
| plt.tight_layout() | |
| plt.show() | |