import gc

import numpy as np
import cv2
import matplotlib.pyplot as plt  # used by DCGAN.visualize_samples
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetV2S
from tensorflow.keras.layers import (
    Dense, Flatten, Conv2D, Activation, BatchNormalization,
    MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D,
    Dropout, Input, concatenate, add, Conv2DTranspose, Lambda,
    SpatialDropout2D, Cropping2D, UpSampling2D, LeakyReLU,
    ZeroPadding2D, Reshape, Concatenate, Multiply, Permute, Add
)

from .contour import get_contours_v2
from .modules import (
    MultipleTrackers, DropBlockNoise, squeeze_excite_block, spatial_squeeze_excite_block,
    channel_spatial_squeeze_excite, DoubleConv, UpSampling2D_block, Conv2DTranspose_block,
    PixelShuffle_block
)
from .utils import mae


IMAGE_SIZE = 224


def adjust_pretrained_weights(model_cls, input_size, name=None):
    """Create a single-channel encoder initialized from ImageNet weights.

    The RGB kernel of the first convolution is collapsed to one channel by
    summing over the channel axis; every other weight is copied unchanged.
    """
    weights_model = model_cls(weights='imagenet',
                              include_top=False,
                              input_shape=(*input_size, 3))
    target_model = model_cls(weights=None,
                             include_top=False,
                             input_shape=(*input_size, 1))
    weights = weights_model.get_weights()
    # Collapse the stem-conv kernel from (H, W, 3, filters) to (H, W, 1, filters).
    weights[0] = np.sum(weights[0], axis=2, keepdims=True)
    target_model.set_weights(weights)

    del weights_model
    tf.keras.backend.clear_session()
    gc.collect()
    if name:
        target_model._name = name
    return target_model


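# A minimal usage sketch (variable names are illustrative only; the call mirrors
# how get_efficient_unet() below builds its grayscale encoder):
#
#   gray_encoder = adjust_pretrained_weights(EfficientNetV2S,
#                                            (IMAGE_SIZE, IMAGE_SIZE),
#                                            name='gray_encoder')
#   features = gray_encoder.predict(np.zeros((1, IMAGE_SIZE, IMAGE_SIZE, 1)))
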
def get_efficient_unet(name=None,
                       option='full',
                       input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
                       encoder_weights=None,
                       block_type='conv-transpose',
                       output_activation='sigmoid',
                       kernel_initializer='glorot_uniform'):
    """Build a U-Net with an EfficientNetV2S encoder.

    `option` selects the return value: 'encoder' returns only the encoder,
    'model' returns only the U-Net, and 'full' returns (model, encoder).
    """
    if encoder_weights == 'imagenet':
        encoder = adjust_pretrained_weights(EfficientNetV2S, input_shape[:-1], name)
    elif encoder_weights is None:
        encoder = EfficientNetV2S(weights=None,
                                  include_top=False,
                                  input_shape=input_shape)
        if name:
            encoder._name = name
    else:
        raise ValueError(f'Unsupported encoder_weights: {encoder_weights}')

    if option == 'encoder':
        return encoder

    # Collect the skip connections from the end of each encoder stage.
    MBConvBlocks = []
    skip_candidates = ['1b', '2d', '3d', '4f']
    for mbblock_nr in skip_candidates:
        mbblock = encoder.get_layer('block{}_add'.format(mbblock_nr)).output
        MBConvBlocks.append(mbblock)

    head = encoder.get_layer('top_activation').output
    blocks = MBConvBlocks + [head]

    if block_type == 'upsampling':
        UpBlock = UpSampling2D_block
    elif block_type == 'conv-transpose':
        UpBlock = Conv2DTranspose_block
    elif block_type == 'pixel-shuffle':
        UpBlock = PixelShuffle_block
    else:
        raise ValueError(f'Unsupported block_type: {block_type}')

    # Decoder: upsample from the head and merge the skip connections, deepest first.
    o = blocks.pop()
    o = UpBlock(512, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(256, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(128, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(64, initializer=kernel_initializer, skip=blocks.pop())(o)
    o = UpBlock(32, initializer=kernel_initializer, skip=None)(o)
    o = Conv2D(input_shape[-1], (1, 1), padding='same', activation=output_activation,
               kernel_initializer=kernel_initializer)(o)

    model = Model(encoder.input, o, name=name)

    if option == 'full':
        return model, encoder
    elif option == 'model':
        return model
    else:
        raise ValueError(f'Unsupported option: {option}')


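# A minimal usage sketch (variable names are illustrative only; the call mirrors
# the ones made by DCGAN.__init__ below):
#
#   unet, unet_encoder = get_efficient_unet(name='unet',
#                                           option='full',
#                                           encoder_weights='imagenet',
#                                           block_type='conv-transpose')
#   restored = unet.predict(np.zeros((1, IMAGE_SIZE, IMAGE_SIZE, 1)))
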
class DCGAN:
    """GAN for image restoration.

    The generator is an EfficientNetV2S U-Net mapping a masked image to a
    reconstruction; the discriminator is an EfficientNetV2S encoder with a
    global-average-pooling head and a sigmoid real/fake score.
    """

    def __init__(self,
                 input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
                 architecture='two-stage',
                 pretrain_weights=None,
                 output_activation='sigmoid',
                 block_type='conv-transpose',
                 kernel_initializer='glorot_uniform',
                 noise=None,
                 C=1.):

        self.C = C

        kwargs = dict(input_shape=input_shape,
                      output_activation=output_activation,
                      encoder_weights=pretrain_weights,
                      block_type=block_type,
                      kernel_initializer=kernel_initializer)

        if architecture == 'two-stage':
            # Separate encoders for the generator and the discriminator.
            encoder = get_efficient_unet(name='dcgan_disc',
                                         option='encoder',
                                         **kwargs)
            self.generator = get_efficient_unet(name='dcgan_gen', option='model', **kwargs)
        elif architecture == 'shared':
            # Generator and discriminator share one encoder.
            self.generator, encoder = get_efficient_unet(name='dcgan', option='full', **kwargs)
        else:
            raise ValueError(f'Unsupported architecture: {architecture}')

        gpooling = GlobalAveragePooling2D()(encoder.output)
        prediction = Dense(1, activation='sigmoid')(gpooling)
        self.discriminator = Model(encoder.input, prediction, name='dcgan_disc')

        tf.keras.backend.clear_session()
        _ = gc.collect()

        if noise:
            # Corrupt the generator input (e.g. with a DropBlockNoise layer) before restoration.
            gen_inputs = self.generator.input
            corrupted_inputs = noise(gen_inputs)
            outputs = self.generator(corrupted_inputs)
            self.generator = Model(gen_inputs, outputs, name='dcgan_gen')

            tf.keras.backend.clear_session()
            _ = gc.collect()

        if output_activation == 'tanh':
            # Map [0, 1] inputs to [-1, 1] and map the tanh outputs back to [0, 1].
            self.process_input = layers.Lambda(lambda img: (img*2.-1.), name='dcgan_normalize')
            self.process_output = layers.Lambda(lambda img: (img*0.5+0.5), name='dcgan_denormalize')
            gen_inputs = self.generator.input
            process_inputs = self.process_input(gen_inputs)
            process_inputs = self.generator(process_inputs)
            gen_outputs = self.process_output(process_inputs)
            self.generator = Model(gen_inputs, gen_outputs, name='dcgan_gen')

            disc_inputs = self.discriminator.input
            process_inputs = self.process_input(disc_inputs)
            disc_outputs = self.discriminator(process_inputs)
            self.discriminator = Model(disc_inputs, disc_outputs, name='dcgan_disc')

            tf.keras.backend.clear_session()
            _ = gc.collect()

    def summary(self):
        self.generator.summary()
        self.discriminator.summary()

    def compile(self,
                generator_optimizer=Adam(5e-4, 0.5),
                discriminator_optimizer=Adam(5e-4),
                reconstruction_loss=mae,
                discriminative_loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                reconstruction_metrics=[],
                discriminative_metrics=[]):

        self.discriminator_optimizer = discriminator_optimizer
        self.discriminator.compile(optimizer=self.discriminator_optimizer)

        self.generator_optimizer = generator_optimizer
        self.generator.compile(optimizer=self.generator_optimizer)

        self.loss = discriminative_loss
        self.reconstruction_loss = reconstruction_loss
        self.d_loss_tracker = tf.keras.metrics.Mean()
        self.g_loss_tracker = tf.keras.metrics.Mean()
        self.g_recon_tracker = tf.keras.metrics.Mean()
        self.g_disc_tracker = tf.keras.metrics.Mean()

        self.g_metric_trackers = [(tf.keras.metrics.Mean(), metric) for metric in reconstruction_metrics]
        self.d_metric_trackers = [(tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), tf.keras.metrics.Mean(), metric)
                                  for metric in discriminative_metrics]

        all_trackers = [self.d_loss_tracker, self.g_loss_tracker, self.g_recon_tracker, self.g_disc_tracker] + \
                       [tracker for tracker, _ in self.g_metric_trackers] + \
                       [tracker for t in self.d_metric_trackers for tracker in t[:-1]]
        self.all_trackers = MultipleTrackers(all_trackers)

    def discriminator_loss(self, real_output, fake_output):
        # Average of the loss on real images (target 1) and generated images (target 0).
        real_loss = self.loss(tf.ones_like(real_output), real_output)
        fake_loss = self.loss(tf.zeros_like(fake_output), fake_output)
        total_loss = 0.5*(real_loss + fake_loss)
        return total_loss

    def generator_loss(self, fake_output):
        # Non-saturating generator loss: push discriminator outputs on fakes towards 1.
        return self.loss(tf.ones_like(fake_output), fake_output)

    @tf.function
    def train_step(self, images):
        masked, original = images
        n_samples = tf.shape(original)[0]

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(masked, training=True)

            real_output = self.discriminator(original, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            gen_disc_loss = self.generator_loss(fake_output)
            recon_loss = self.reconstruction_loss(original, generated_images)
            # Total generator objective: C * reconstruction + adversarial term.
            gen_loss = self.C*recon_loss + gen_disc_loss
            disc_loss = self.discriminator_loss(real_output, fake_output)

        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        # Weight each scalar by the batch size when updating the running means.
        self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
        self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
        self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
        self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))

        logs = {'d_loss': self.d_loss_tracker.result()}

        for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
            v_real = metric(tf.ones_like(real_output), real_output)
            v_fake = metric(tf.zeros_like(fake_output), fake_output)
            v = 0.5*(v_real + v_fake)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
            fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))

            metric_name = metric.__name__
            logs['d_' + metric_name] = tracker.result()
            logs['d_real_' + metric_name] = real_tracker.result()
            logs['d_fake_' + metric_name] = fake_tracker.result()

        logs['g_loss'] = self.g_loss_tracker.result()
        logs['g_recon'] = self.g_recon_tracker.result()
        logs['g_disc'] = self.g_disc_tracker.result()

        for tracker, metric in self.g_metric_trackers:
            v = metric(original, generated_images)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            logs['g_' + metric.__name__] = tracker.result()

        return logs

    @tf.function
    def val_step(self, images):
        masked, original = images
        n_samples = tf.shape(original)[0]

        generated_images = self.generator(masked, training=False)

        real_output = self.discriminator(original, training=False)
        fake_output = self.discriminator(generated_images, training=False)

        gen_disc_loss = self.generator_loss(fake_output)
        recon_loss = self.reconstruction_loss(original, generated_images)
        gen_loss = self.C*recon_loss + gen_disc_loss
        disc_loss = self.discriminator_loss(real_output, fake_output)

        self.d_loss_tracker.update_state(tf.repeat([[disc_loss]], repeats=n_samples, axis=0))
        self.g_loss_tracker.update_state(tf.repeat([[gen_loss]], repeats=n_samples, axis=0))
        self.g_recon_tracker.update_state(tf.repeat([[recon_loss]], repeats=n_samples, axis=0))
        self.g_disc_tracker.update_state(tf.repeat([[gen_disc_loss]], repeats=n_samples, axis=0))

        logs = {'val_d_loss': self.d_loss_tracker.result()}

        for tracker, real_tracker, fake_tracker, metric in self.d_metric_trackers:
            v_real = metric(tf.ones_like(real_output), real_output)
            v_fake = metric(tf.zeros_like(fake_output), fake_output)
            v = 0.5*(v_real + v_fake)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            real_tracker.update_state(tf.repeat([[v_real]], repeats=n_samples, axis=0))
            fake_tracker.update_state(tf.repeat([[v_fake]], repeats=n_samples, axis=0))

            metric_name = metric.__name__
            logs['val_d_' + metric_name] = tracker.result()
            logs['val_d_real_' + metric_name] = real_tracker.result()
            logs['val_d_fake_' + metric_name] = fake_tracker.result()

        logs['val_g_loss'] = self.g_loss_tracker.result()
        logs['val_g_recon'] = self.g_recon_tracker.result()
        logs['val_g_disc'] = self.g_disc_tracker.result()

        for tracker, metric in self.g_metric_trackers:
            v = metric(original, generated_images)
            tracker.update_state(tf.repeat([[v]], repeats=n_samples, axis=0))
            logs['val_g_' + metric.__name__] = tracker.result()

        return logs

    def fit(self,
            trainset,
            valset=None,
            trainsize=-1,
            valsize=-1,
            epochs=1,
            display_per_epochs=5,
            generator_callbacks=[],
            discriminator_callbacks=[]):

        print('🌊🐉 Start Training 🐉🌊')
        gen_callback_tracker = tf.keras.callbacks.CallbackList(
            generator_callbacks, add_history=True, model=self.generator
        )
        disc_callback_tracker = tf.keras.callbacks.CallbackList(
            discriminator_callbacks, add_history=True, model=self.discriminator
        )
        callbacks_tracker = MultipleTrackers([gen_callback_tracker, disc_callback_tracker])

        logs = {}
        callbacks_tracker.on_train_begin(logs=logs)

        for epoch in range(epochs):
            print(f'Epoch {epoch+1}/{epochs}:')
            callbacks_tracker.on_epoch_begin(epoch, logs=logs)

            batches = tqdm(trainset,
                           desc="Train",
                           total=trainsize,
                           unit="step",
                           position=0,
                           leave=True)

            for batch, image_batch in enumerate(batches):
                callbacks_tracker.on_batch_begin(batch, logs=logs)
                callbacks_tracker.on_train_batch_begin(batch, logs=logs)

                train_logs = {k: v.numpy() for k, v in self.train_step(image_batch).items()}
                logs.update(train_logs)

                callbacks_tracker.on_train_batch_end(batch, logs=logs)
                callbacks_tracker.on_batch_end(batch, logs=logs)
                batches.set_postfix({'d_loss': train_logs['d_loss'],
                                     'g_loss': train_logs['g_loss']})

            stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items()
                              if 'val_' not in k and 'loss' not in k)
            print('Train:', stats)

            batches.close()
            if valset:
                self.all_trackers.reset_state()

                batches = tqdm(valset,
                               desc="Valid",
                               total=valsize,
                               unit="step",
                               position=0,
                               leave=True)

                for batch, image_batch in enumerate(batches):
                    callbacks_tracker.on_batch_begin(batch, logs=logs)
                    callbacks_tracker.on_test_batch_begin(batch, logs=logs)
                    val_logs = {k: v.numpy() for k, v in self.val_step(image_batch).items()}
                    logs.update(val_logs)

                    callbacks_tracker.on_test_batch_end(batch, logs=logs)
                    callbacks_tracker.on_batch_end(batch, logs=logs)

                    batches.set_postfix({'val_d_loss': val_logs['val_d_loss'],
                                         'val_g_loss': val_logs['val_g_loss']})

                stats = ", ".join("{}={:.3g}".format(k, v) for k, v in logs.items()
                                  if 'val_' in k and 'loss' not in k)
                print('Valid:', stats)

                batches.close()

            if epoch % display_per_epochs == 0:
                print('-'*128)
                self.visualize_samples((image_batch[0][:2], image_batch[1][:2]))

            self.all_trackers.reset_state()

            callbacks_tracker.on_epoch_end(epoch, logs=logs)

            _ = gc.collect()

            if self.generator.stop_training or self.discriminator.stop_training:
                break
            print('-'*128)

        callbacks_tracker.on_train_end(logs=logs)
        tf.keras.backend.clear_session()
        _ = gc.collect()

        # Split the recorded history into generator-only and discriminator-only views.
        gen_history = None
        for cb in gen_callback_tracker:
            if isinstance(cb, tf.keras.callbacks.History):
                gen_history = cb
                gen_history.history = {k: v for k, v in cb.history.items() if 'd_' not in k}

        disc_history = None
        for cb in disc_callback_tracker:
            if isinstance(cb, tf.keras.callbacks.History):
                disc_history = cb
                disc_history.history = {k: v for k, v in cb.history.items() if 'g_' not in k}

        return {'generator': gen_history,
                'discriminator': disc_history}

    def visualize_samples(self, samples, figsize=(12, 2)):
        x, y = samples
        y_pred = self.generator.predict(x[:2], verbose=0)
        fig, axs = plt.subplots(1, 6, figsize=figsize)
        for i in range(2):
            pos = 3*i
            axs[pos].imshow(x[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos].set_title('Masked')
            axs[pos].axis('off')
            axs[pos+1].imshow(y[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos+1].set_title('Original')
            axs[pos+1].axis('off')
            axs[pos+2].imshow(y_pred[i], cmap='gray', vmin=0., vmax=1.)
            axs[pos+2].set_title('Predicted')
            axs[pos+2].axis('off')
        plt.show()

        del y_pred
        _ = gc.collect()
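

# A minimal end-to-end sketch of how this class is intended to be used. The
# dataset and step-count names (train_ds, val_ds, train_steps, val_steps) are
# illustrative assumptions, not part of this module: the datasets should yield
# (masked, original) pairs scaled to [0, 1] with shape
# (batch, IMAGE_SIZE, IMAGE_SIZE, 1).
#
#   gan = DCGAN(architecture='two-stage', pretrain_weights='imagenet')
#   gan.compile(generator_optimizer=Adam(5e-4, 0.5),
#               discriminator_optimizer=Adam(5e-4))
#   histories = gan.fit(train_ds, valset=val_ds,
#                       trainsize=train_steps, valsize=val_steps,
#                       epochs=50, display_per_epochs=5)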