import gc

import numpy as np
import cv2
from tqdm import tqdm
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetV2S
from tensorflow.keras.layers import (
    Dense, Flatten, Conv2D, Activation, BatchNormalization, MaxPooling2D,
    AveragePooling2D, GlobalAveragePooling2D, Dropout, Input, concatenate,
    add, Conv2DTranspose, Lambda, SpatialDropout2D, Cropping2D, UpSampling2D,
    LeakyReLU, ZeroPadding2D, Reshape, Concatenate, Multiply, Permute, Add
)

from .contour import get_contours_v2
from .modules import (
    MultipleTrackers, DropBlockNoise, squeeze_excite_block,
    spatial_squeeze_excite_block, channel_spatial_squeeze_excite,
    DoubleConv, UpSampling2D_block, Conv2DTranspose_block, PixelShuffle_block
)
from .utils import mae

IMAGE_SIZE = 224


def adjust_pretrained_weights(model_cls, input_size, name=None):
    """Build a single-channel encoder initialized from 3-channel imagenet weights.

    Instantiates ``model_cls`` twice — once with imagenet weights and a
    ``(*input_size, 3)`` input, once uninitialized with a ``(*input_size, 1)``
    input — then transfers the weights, collapsing the RGB stem kernel down
    to one channel.

    Args:
        model_cls: a ``tf.keras.applications`` model constructor
            (e.g. ``EfficientNetV2S``).
        input_size: spatial size tuple ``(height, width)``.
        name: optional name assigned to the returned model.

    Returns:
        The single-channel model with transferred weights.
    """
    weights_model = model_cls(weights='imagenet', include_top=False,
                              input_shape=(*input_size, 3))
    target_model = model_cls(weights=None, include_top=False,
                             input_shape=(*input_size, 1))
    weights = weights_model.get_weights()
    # Sum the first conv kernel over the input-channel axis so the pretrained
    # RGB filters remain meaningful for grayscale input.
    weights[0] = np.sum(weights[0], axis=2, keepdims=True)
    target_model.set_weights(weights)
    del weights_model
    tf.keras.backend.clear_session()
    gc.collect()
    if name:
        target_model._name = name
    return target_model


def get_efficient_unet(name=None, option='full',
                       input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
                       encoder_weights=None, block_type='conv-transpose',
                       output_activation='sigmoid',
                       kernel_initializer='glorot_uniform'):
    """Construct a U-Net with an EfficientNetV2S encoder.

    Args:
        name: model name.
        option: ``'encoder'`` returns only the encoder; ``'model'`` returns
            only the U-Net; ``'full'`` returns ``(unet, encoder)``.
        input_shape: input tensor shape, channels last.
        encoder_weights: ``'imagenet'`` (adapted to the input channel count
            via :func:`adjust_pretrained_weights`) or ``None`` for random init.
        block_type: decoder upsampling block — ``'upsampling'``,
            ``'conv-transpose'`` or ``'pixel-shuffle'``.
        output_activation: activation of the final 1x1 conv.
        kernel_initializer: initializer for decoder/output conv kernels.

    Returns:
        Depending on ``option``: the encoder, the U-Net model, or the
        ``(model, encoder)`` pair.

    Raises:
        ValueError: on unrecognized ``encoder_weights``, ``block_type``
            or ``option``.
    """
    if encoder_weights == 'imagenet':
        encoder = adjust_pretrained_weights(EfficientNetV2S, input_shape[:-1], name)
    elif encoder_weights is None:
        encoder = EfficientNetV2S(weights=None, include_top=False,
                                  input_shape=input_shape)
        encoder._name = name
    else:
        raise ValueError(encoder_weights)

    if option == 'encoder':
        return encoder

    # Skip connections tapped at the end of selected MBConv stages.
    MBConvBlocks = []
    skip_candidates = ['1b', '2d', '3d', '4f']
    for mbblock_nr in skip_candidates:
        mbblock = encoder.get_layer('block{}_add'.format(mbblock_nr)).output
        MBConvBlocks.append(mbblock)
    head = encoder.get_layer('top_activation').output
    blocks = MBConvBlocks + [head]

    if block_type == 'upsampling':
        UpBlock = UpSampling2D_block
    elif block_type == 'conv-transpose':
        UpBlock = Conv2DTranspose_block
    elif block_type == 'pixel-shuffle':
        UpBlock = PixelShuffle_block
    else:
        raise ValueError(block_type)

    # Decoder: successively upsample the head, merging skip connections
    # from deepest to shallowest.
    o = blocks.pop()
    o = UpBlock(512, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(256, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(128, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(64, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(32, initializer=kernel_initializer, skip=None)(o)
    o = Conv2D(input_shape[-1], (1, 1), padding='same',
               activation=output_activation,
               kernel_initializer=kernel_initializer)(o)

    model = Model(encoder.input, o, name=name)
    if option == 'full':
        return model, encoder
    elif option == 'model':
        return model
    else:
        raise ValueError(option)


class DCGAN():
    """Image-inpainting DCGAN built on EfficientNet-U-Net components.

    The generator is a U-Net mapping masked images to reconstructions; the
    discriminator is an EfficientNet encoder with a sigmoid head. Training
    combines a reconstruction loss (weighted by ``C``) with the adversarial
    loss.
    """

    def __init__(self, input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
                 architecture='two-stage', pretrain_weights=None,
                 output_activation='sigmoid', block_type='conv-transpose',
                 kernel_initializer='glorot_uniform', noise=None, C=1.):
        """Build generator and discriminator.

        Args:
            input_shape: image tensor shape, channels last.
            architecture: ``'two-stage'`` builds independent generator and
                discriminator encoders; ``'shared'`` reuses the generator's
                encoder for the discriminator.
            pretrain_weights: passed as ``encoder_weights`` to
                :func:`get_efficient_unet` (``'imagenet'`` or ``None``).
            output_activation: generator output activation; ``'tanh'``
                additionally wraps both models with [0,1] <-> [-1,1]
                normalization layers.
            block_type: decoder block type (see :func:`get_efficient_unet`).
            kernel_initializer: decoder/output kernel initializer.
            noise: optional callable layer that corrupts the generator input
                (e.g. DropBlockNoise); applied inside the generator graph.
            C: weight of the reconstruction loss in the generator objective.
        """
        self.C = C
        # Build
        kwargs = dict(input_shape=input_shape,
                      output_activation=output_activation,
                      encoder_weights=pretrain_weights,
                      block_type=block_type,
                      kernel_initializer=kernel_initializer)
        if architecture == 'two-stage':
            encoder = get_efficient_unet(name='dcgan_disc', option='encoder', **kwargs)
            self.generator = get_efficient_unet(name='dcgan_gen', option='model', **kwargs)
        elif architecture == 'shared':
            self.generator, encoder = get_efficient_unet(name='dcgan', option='full', **kwargs)
        else:
            # FIX: corrected message typo ('Unsupport' -> 'Unsupported').
            raise ValueError(f'Unsupported architecture: {architecture}')
        gpooling = GlobalAveragePooling2D()(encoder.output)
        prediction = Dense(1, activation='sigmoid')(gpooling)
        self.discriminator = Model(encoder.input, prediction, name='dcgan_disc')
        tf.keras.backend.clear_session()
        _ = gc.collect()

        if noise:
            # Bake the corruption layer into the generator graph so training
            # sees noisy inputs transparently.
            gen_inputs = self.generator.input
            corrupted_inputs = noise(gen_inputs)
            outputs = self.generator(corrupted_inputs)
            self.generator = Model(gen_inputs, outputs, name='dcgan_gen')
            tf.keras.backend.clear_session()
            _ = gc.collect()

        if output_activation == 'tanh':
            # Wrap both models so callers keep working in [0, 1] while the
            # networks internally operate in [-1, 1].
            self.process_input = layers.Lambda(lambda img: (img*2.-1.),
                                               name='dcgan_normalize')
            self.process_output = layers.Lambda(lambda img: (img*0.5+0.5),
                                                name='dcgan_denormalize')
            gen_inputs = self.generator.input
            process_inputs = self.process_input(gen_inputs)
            process_inputs = self.generator(process_inputs)
            gen_outputs = self.process_output(process_inputs)
            self.generator = Model(gen_inputs, gen_outputs, name='dcgan_gen')
            disc_inputs = self.discriminator.input
            process_inputs = self.process_input(disc_inputs)
            disc_outputs = self.discriminator(process_inputs)
            self.discriminator = Model(disc_inputs, disc_outputs, name='dcgan_disc')
            tf.keras.backend.clear_session()
            _ = gc.collect()

    def summary(self):
        """Print generator and discriminator summaries."""
        self.generator.summary()
        self.discriminator.summary()

    def compile(self, generator_optimizer=None, discriminator_optimizer=None,
                reconstruction_loss=mae, discriminative_loss=None,
                reconstruction_metrics=None, discriminative_metrics=None):
        """Attach optimizers, losses and metric trackers.

        Args:
            generator_optimizer: defaults to ``Adam(5e-4, 0.5)``.
            discriminator_optimizer: defaults to ``Adam(5e-4)``.
            reconstruction_loss: pixel loss between original and generated.
            discriminative_loss: defaults to binary cross-entropy on
                probabilities (``from_logits=False``).
            reconstruction_metrics: callables ``metric(y_true, y_pred)``
                tracked on the generator output.
            discriminative_metrics: callables tracked separately on real,
                fake and averaged discriminator outputs.

        FIX: the original used mutable defaults evaluated once at class
        definition — notably shared ``Adam`` instances across every DCGAN
        object (shared slot variables / step counters) and shared metric
        lists. Defaults are now created per call.
        """
        if generator_optimizer is None:
            generator_optimizer = Adam(5e-4, 0.5)
        if discriminator_optimizer is None:
            discriminator_optimizer = Adam(5e-4)
        if discriminative_loss is None:
            discriminative_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        if reconstruction_metrics is None:
            reconstruction_metrics = []
        if discriminative_metrics is None:
            discriminative_metrics = []

        self.discriminator_optimizer = discriminator_optimizer
        self.discriminator.compile(optimizer=self.discriminator_optimizer)
        self.generator_optimizer = generator_optimizer
        self.generator.compile(optimizer=self.generator_optimizer)
        self.loss = discriminative_loss
        self.reconstruction_loss = reconstruction_loss

        self.d_loss_tracker = tf.keras.metrics.Mean()
        self.g_loss_tracker = tf.keras.metrics.Mean()
        self.g_recon_tracker = tf.keras.metrics.Mean()
        self.g_disc_tracker = tf.keras.metrics.Mean()
        self.g_metric_trackers = [(tf.keras.metrics.Mean(), metric)
                                  for metric in reconstruction_metrics]
        # (avg, real, fake) trackers per discriminative metric.
        self.d_metric_trackers = [(tf.keras.metrics.Mean(),
                                   tf.keras.metrics.Mean(),
                                   tf.keras.metrics.Mean(), metric)
                                  for metric in discriminative_metrics]
        all_trackers = [self.d_loss_tracker, self.g_loss_tracker,
                        self.g_recon_tracker, self.g_disc_tracker] + \
                       [tracker for tracker, _ in self.g_metric_trackers] + \
                       [tracker for t in self.d_metric_trackers for tracker in t[:-1]]
        self.all_trackers = MultipleTrackers(all_trackers)

    def discriminator_loss(self, real_output, fake_output):
        """Mean of BCE(real vs 1) and BCE(fake vs 0)."""
        real_loss = self.loss(tf.ones_like(real_output), real_output)
        fake_loss = self.loss(tf.zeros_like(fake_output), fake_output)
        total_loss = 0.5*(real_loss + fake_loss)
        return total_loss

    def generator_loss(self, fake_output):
        """Adversarial loss: BCE of fake predictions against the 'real' label."""
        return self.loss(tf.ones_like(fake_output), fake_output)

    @tf.function
    def train_step(self, images):
        """One optimization step for both networks.

        Args:
            images: ``(masked, original)`` batch pair.

        Returns:
            dict of tracked metric tensors for this step.
        """
        masked, original = images
        n_samples = tf.shape(original)[0]
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(masked, training=True)
            real_output = self.discriminator(original, training=True)
            fake_output = self.discriminator(generated_images, training=True)
            gen_disc_loss = self.generator_loss(fake_output)
            recon_loss = self.reconstruction_loss(original, generated_images)
            # Generator objective: weighted reconstruction + adversarial term.
            gen_loss = self.C*recon_loss + gen_disc_loss
            disc_loss = self.discriminator_loss(real_output, fake_output)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients_of_generator, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(
            zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        # Repeat the scalar per sample so Mean trackers weight by batch size.
        self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
        self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
        self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
        self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))

        logs = {'d_loss': self.d_loss_tracker.result()}
        for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
            v_real = metric(tf.ones_like(real_output), real_output)
            v_fake = metric(tf.zeros_like(fake_output), fake_output)
            v = 0.5*(v_real + v_fake)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
            fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))
            metric_name = metric.__name__
            logs['d_' + metric_name] = tracker.result()
            logs['d_real_' + metric_name] = real_tracker.result()
            logs['d_fake_' + metric_name] = fake_tracker.result()
        logs['g_loss'] = self.g_loss_tracker.result()
        logs['g_recon'] = self.g_recon_tracker.result()
        logs['g_disc'] = self.g_disc_tracker.result()
        for tracker, metric in self.g_metric_trackers:
            v = metric(original, generated_images)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            logs['g_' + metric.__name__] = tracker.result()
        return logs

    @tf.function
    def val_step(self, images):
        """Forward-only evaluation step; mirrors ``train_step`` with
        ``val_``-prefixed log keys and no gradient updates."""
        masked, original = images
        n_samples = tf.shape(original)[0]
        generated_images = self.generator(masked, training=False)
        real_output = self.discriminator(original, training=False)
        fake_output = self.discriminator(generated_images, training=False)
        gen_disc_loss = self.generator_loss(fake_output)
        recon_loss = self.reconstruction_loss(original, generated_images)
        gen_loss = self.C*recon_loss + gen_disc_loss
        disc_loss = self.discriminator_loss(real_output, fake_output)

        self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
        self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
        self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
        self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))

        logs = {'val_d_loss': self.d_loss_tracker.result()}
        for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
            v_real = metric(tf.ones_like(real_output), real_output)
            v_fake = metric(tf.zeros_like(fake_output), fake_output)
            v = 0.5*(v_real + v_fake)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
            fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))
            metric_name = metric.__name__
            logs['val_d_' + metric_name] = tracker.result()
            logs['val_d_real_' + metric_name] = real_tracker.result()
            logs['val_d_fake_' + metric_name] = fake_tracker.result()
        logs['val_g_loss'] = self.g_loss_tracker.result()
        logs['val_g_recon'] = self.g_recon_tracker.result()
        logs['val_g_disc'] = self.g_disc_tracker.result()
        for tracker, metric in self.g_metric_trackers:
            v = metric(original, generated_images)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            logs['val_g_' + metric.__name__] = tracker.result()
        return logs

    def fit(self, trainset, valset=None, trainsize=-1, valsize=-1, epochs=1,
            display_per_epochs=5, generator_callbacks=None,
            discriminator_callbacks=None):
        """Custom training loop over ``(masked, original)`` batches.

        Args:
            trainset: iterable of training batch pairs.
            valset: optional iterable of validation batch pairs.
            trainsize / valsize: step counts for the progress bars
                (-1 = unknown).
            epochs: number of epochs.
            display_per_epochs: visualize samples every N epochs
                (only when ``valset`` is given).
            generator_callbacks / discriminator_callbacks: Keras callbacks
                attached to the respective model.
                FIX: defaults were mutable lists shared across calls;
                now ``None`` sentinels.

        Returns:
            dict with 'generator' and 'discriminator' History callbacks
            (each filtered to its own metric keys).
        """
        if generator_callbacks is None:
            generator_callbacks = []
        if discriminator_callbacks is None:
            discriminator_callbacks = []
        print('🌊🐉 Start Training 🐉🌊')
        gen_callback_tracker = tf.keras.callbacks.CallbackList(
            generator_callbacks, add_history=True, model=self.generator
        )
        disc_callback_tracker = tf.keras.callbacks.CallbackList(
            discriminator_callbacks, add_history=True, model=self.discriminator
        )
        callbacks_tracker = MultipleTrackers([gen_callback_tracker, disc_callback_tracker])

        logs = {}
        callbacks_tracker.on_train_begin(logs=logs)
        for epoch in range(epochs):
            print(f'Epochs {epoch+1}/{epochs}:')
            callbacks_tracker.on_epoch_begin(epoch, logs=logs)

            batches = tqdm(trainset, desc="Train", total=trainsize,
                           unit="step", position=0, leave=True)
            for batch, image_batch in enumerate(batches):
                callbacks_tracker.on_batch_begin(batch, logs=logs)
                callbacks_tracker.on_train_batch_begin(batch, logs=logs)
                train_logs = {k: v.numpy() for k, v in self.train_step(image_batch).items()}
                logs.update(train_logs)
                callbacks_tracker.on_train_batch_end(batch, logs=logs)
                callbacks_tracker.on_batch_end(batch, logs=logs)
                batches.set_postfix({'d_loss': train_logs['d_loss'],
                                     'g_loss': train_logs['g_loss']})
            # Presentation: losses live on the tqdm bar, other metrics here.
            stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items()
                              if 'val_' not in k and 'loss' not in k)
            print('Train:', stats)
            batches.close()

            if valset:
                self.all_trackers.reset_state()
                batches = tqdm(valset, desc="Valid", total=valsize,
                               unit="step", position=0, leave=True)
                for batch, image_batch in enumerate(batches):
                    callbacks_tracker.on_batch_begin(batch, logs=logs)
                    callbacks_tracker.on_test_batch_begin(batch, logs=logs)
                    val_logs = {k: v.numpy() for k, v in self.val_step(image_batch).items()}
                    logs.update(val_logs)
                    callbacks_tracker.on_test_batch_end(batch, logs=logs)
                    callbacks_tracker.on_batch_end(batch, logs=logs)
                    # Presentation
                    batches.set_postfix({'val_d_loss': val_logs['val_d_loss'],
                                         'val_g_loss': val_logs['val_g_loss']})
                stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items()
                                  if 'val_' in k and 'loss' not in k)
                print('Valid:', stats)
                batches.close()
                if epoch % display_per_epochs == 0:
                    print('-'*128)
                    # Show two samples from the last validation batch.
                    self.visualize_samples((image_batch[0][:2], image_batch[1][:2]))

            self.all_trackers.reset_state()
            callbacks_tracker.on_epoch_end(epoch, logs=logs)
            # tf.keras.backend.clear_session()
            _ = gc.collect()
            if self.generator.stop_training or self.discriminator.stop_training:
                break

        print('-'*128)
        callbacks_tracker.on_train_end(logs=logs)
        tf.keras.backend.clear_session()
        _ = gc.collect()

        # Split the combined logs back into per-model histories.
        gen_history = None
        for cb in gen_callback_tracker:
            if isinstance(cb, tf.keras.callbacks.History):
                gen_history = cb
                gen_history.history = {k: v for k, v in cb.history.items() if 'd_' not in k}
        disc_history = None
        for cb in disc_callback_tracker:
            if isinstance(cb, tf.keras.callbacks.History):
                disc_history = cb
                disc_history.history = {k: v for k, v in cb.history.items() if 'g_' not in k}
        return {'generator': gen_history, 'discriminator': disc_history}

    def visualize_samples(self, samples, figsize=(12, 2)):
        """Plot two (masked, original, predicted) triplets side by side.

        Args:
            samples: ``(x, y)`` pair of at least two masked/original images.
            figsize: matplotlib figure size.
        """
        # FIX: matplotlib was never imported at module level, so the original
        # raised NameError on 'plt'. Local import keeps matplotlib optional
        # for non-visualizing users of this module.
        import matplotlib.pyplot as plt
        x, y = samples
        y_pred = self.generator.predict(x[:2], verbose=0)
        fig, axs = plt.subplots(1, 6, figsize=figsize)
        for i in range(2):
            pos = 3*i
            axs[pos].imshow(x[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos].set_title('Masked')
            axs[pos].axis('off')
            axs[pos+1].imshow(y[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos+1].set_title('Original')
            axs[pos+1].axis('off')
            axs[pos+2].imshow(y_pred[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos+2].set_title('Predicted')
            axs[pos+2].axis('off')
        plt.show()
        # tf.keras.backend.clear_session()
        del y_pred
        _ = gc.collect()