import os
import sys

# Make the sibling packages at the repository root importable
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.keras.utils import Progbar
from tensorflow.python.framework.errors import InvalidArgumentError
|
|
import voxelmorph as vxm
from datetime import datetime

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.misc import function_decorator
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
from DeepDeformationMapRegistration.losses import NCC, HausdorffDistanceErosion, GeneralizedDICEScore, StructuralSimilarity_simplified
from DeepDeformationMapRegistration.layers import AugmentationLayer
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
|
from Brain_study.data_generator import BatchGenerator
from Brain_study.utils import SummaryDictionary, named_logs

import warnings
import re
|
|
|
|
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd',
                 acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
                 unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
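    """Train a semi-supervised VoxelMorph registration network on brain volumes.

    Parameter roles, as used below:
      dataset_folder      -- training dataset directory (required)
      validation_folder   -- validation dataset directory
      output_folder       -- destination for checkpoints/tensorboard/history (required)
      gpu_num             -- CUDA device index exported via CUDA_VISIBLE_DEVICES
      lr                  -- learning rate
      rw                  -- regularization weight on the gradient of the flow field
      simil               -- similarity loss: 'ssim', 'ms_ssim', 'ncc', a combination such
                             as 'ms_ssim__ncc' or 'ssim__mse'; anything else falls back to MSE
      segm                -- segmentation loss: 'hd', 'dice' or 'dice_macro'
      acc_gradients       -- gradient accumulation steps (batch size is forced to 1 if > 1)
      image_size          -- side length of the cubic network input/output volume
      unet, head          -- encoder features and extra decoder head features of the U-Net
      resume              -- path of a previous run to resume from, or None
    """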
|
|
    assert dataset_folder is not None and output_folder is not None, 'dataset_folder and output_folder are required'
|
|
    os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)
    C.GPU_NUM = str(gpu_num)
|
|
    if batch_size != 1 and acc_gradients != 1:
        warnings.warn('Both batch_size and acc_gradients are set: gradient accumulation takes '
                      'precedence and the effective batch size is forced to 1.')
|
|
    if resume is not None:
        try:
            assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
            output_folder = resume
            resume = True
        except AssertionError:
            output_folder = output_folder + '_' + datetime.now().strftime('%H%M%S-%d%m%Y')
            resume = False
    else:
        resume = False
    os.makedirs(output_folder, exist_ok=True)
    # Append to the previous log when resuming instead of overwriting it
    log_file = open(os.path.join(output_folder, 'log.txt'), 'a' if resume else 'w')
    C.TRAINING_DATASET = dataset_folder
    C.VALIDATION_DATASET = validation_folder
    C.ACCUM_GRADIENT_STEP = acc_gradients
    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
    C.EARLY_STOP_PATIENCE = early_stop_patience
    C.LEARNING_RATE = lr
    C.LIMIT_NUM_SAMPLES = None
    C.EPOCHS = max_epochs
|
|
| aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \ |
| "GPU: {}\n" \ |
| "BATCH SIZE: {}\n" \ |
| "LR: {}\n" \ |
| "SIMILARITY: {}\n" \ |
| "REG. WEIGHT: {}\n" \ |
| "EPOCHS: {:d}\n" \ |
| "ACCUM. GRAD: {}\n" \ |
| "EARLY STOP PATIENCE: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), |
| C.TRAINING_DATASET, |
| C.VALIDATION_DATASET, |
| C.GPU_NUM, |
| C.BATCH_SIZE, |
| C.LEARNING_RATE, |
| simil, |
| rw, |
| C.EPOCHS, |
| C.ACCUM_GRADIENT_STEP, |
| C.EARLY_STOP_PATIENCE) |
| log_file.write(aux) |
| print(aux) |
|
|
    # Build the training/validation batch generators
    data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
                                    C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
                                    directory_val=C.VALIDATION_DATASET)
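    # get_data_shape() is assumed to return per-output shapes in the order
    # (augmentation input, image, segmentation); the [0], [1], [2] indexing below relies on that.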
|
|
    train_generator = data_generator.get_train_generator()
|
|
    validation_generator = data_generator.get_validation_generator()
|
|
    image_input_shape = train_generator.get_data_shape()[1][:-1]
    image_output_shape = [image_size] * 3
|
|
    nb_labels = len(train_generator.get_segmentation_labels())
|
|
    # TF1-style session configuration (the manual training loop below runs in graph mode)
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    sess = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)
|
|
    # On-the-fly augmentation: a small graph that maps a raw batch to augmented tensors
    input_layer_augm = Input(shape=train_generator.get_data_shape()[0], name='input_augmentation')
    augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP,
                                   max_deformation=C.MAX_AUG_DEF,
                                   max_rotation=C.MAX_AUG_ANGLE,
                                   num_control_points=C.NUM_CONTROL_PTS_AUG,
                                   num_augmentations=C.NUM_AUGMENTATIONS,
                                   gamma_augmentation=C.GAMMA_AUGMENTATION,
                                   brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
                                   in_img_shape=image_input_shape,
                                   out_img_shape=image_output_shape,
                                   only_image=False,    # also transform the segmentations
                                   only_resize=False,
                                   trainable=False)
    augm_model = Model(inputs=input_layer_augm, outputs=augm_layer(input_layer_augm))
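    # As used in the loops below, augm_model.predict(batch) yields
    # (fix_img, mov_img, fix_seg, mov_seg), already resampled to image_output_shape.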
|
|
    # U-Net feature specification: [encoder_features, decoder_features]
    enc_features = unet
    dec_features = enc_features[::-1] + head
    nb_features = [enc_features, dec_features]
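    # With the defaults this yields enc_features = [16, 32, 64, 128, 256] and
    # dec_features = [256, 128, 64, 32, 16, 16, 16] (mirrored encoder plus the extra head).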
|
|
    network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                     nb_labels=nb_labels,
                                                     nb_unet_features=nb_features,
                                                     int_steps=0,
                                                     int_downsize=1,
                                                     seg_downsize=1)
    network.summary(line_length=C.SUMMARY_LINE_LENGTH)
    network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)
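    # The network takes (moving image, fixed image, moving segmentation) and, matching the
    # target tuples used below, outputs (moved image, flow field, moved segmentation).
    # int_steps=0 disables flow integration; int_downsize=1 and seg_downsize=1 keep the
    # flow and warped segmentation at full resolution.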
|
|
    resume_epoch = 0
    if resume:
        cp_dir = os.path.join(output_folder, 'checkpoints')
        cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
        if len(cp_file_list):
            cp_file_list.sort()  # zero-padded epoch numbers sort correctly as strings
            checkpoint_file = cp_file_list[-1]
            if os.path.exists(checkpoint_file):
                network.load_weights(checkpoint_file, by_name=True)
                print('Loaded checkpoint file: ' + checkpoint_file)
                try:
                    # Checkpoints are named 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'
                    resume_epoch = int(re.match(r'checkpoint\.(\d+)-.*\.h5', os.path.split(checkpoint_file)[-1])[1])
                except TypeError:
                    # The file name did not encode an epoch number
                    resume_epoch = 0
                print('Resuming from epoch: {:d}'.format(resume_epoch))
        else:
            warnings.warn('Checkpoint file NOT found. Training from scratch')
|
|
    # Similarity loss selection
    SSIM_KER_SIZE = 5
    MS_SSIM_WEIGHTS = np.asarray(_MSSSIM_WEIGHTS[:3])
    MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)  # renormalize after truncating to three scales
    if simil == 'ssim':
        loss_simil = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
    elif simil == 'ms_ssim':
        loss_simil = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
    elif simil == 'ncc':
        loss_simil = NCC(image_input_shape).loss
    elif simil == 'ms_ssim__ncc' or simil == 'ncc__ms_ssim':
        @function_decorator('MS_SSIM_NCC__loss')
        def loss_simil(y_true, y_pred):
            return MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred) + NCC(image_input_shape).loss(y_true, y_pred)
    elif simil == 'ms_ssim__mse' or simil == 'mse__ms_ssim':
        @function_decorator('MS_SSIM_MSE__loss')
        def loss_simil(y_true, y_pred):
            return MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred) + vxm.losses.MSE().loss(y_true, y_pred)
    elif simil == 'ssim__ncc' or simil == 'ncc__ssim':
        @function_decorator('SSIM_NCC__loss')
        def loss_simil(y_true, y_pred):
            return StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred) + NCC(image_input_shape).loss(y_true, y_pred)
    elif simil == 'ssim__mse' or simil == 'mse__ssim':
        @function_decorator('SSIM_MSE__loss')
        def loss_simil(y_true, y_pred):
            return StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred) + vxm.losses.MSE().loss(y_true, y_pred)
    else:
        loss_simil = vxm.losses.MSE().loss
|
|
    if segm == 'hd':
        loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
    elif segm == 'dice':
        loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
    elif segm == 'dice_macro':
        loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
    else:
        raise ValueError('No valid value for segm')
|
|
    losses = {'transformer': loss_simil,
              'seg_transformer': loss_segm,
              'flow': vxm.losses.Grad('l2').loss}
    loss_weights = {'transformer': 1.,
                    'seg_transformer': 1.,
                    'flow': rw}  # regularization weight, logged above as REG. WEIGHT
    metrics = {'transformer': [vxm.losses.MSE().loss,
                               NCC(image_input_shape).metric,
                               StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
                               MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric],
               'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
                                   ]}
    metrics_weights = {'transformer': 1,
                       'seg_transformer': 1,
                       'flow': rw}
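    # Note that network.compile() below only consumes losses, loss_weights and metrics;
    # metrics_weights is bookkeeping only, and no metrics are attached to 'flow'.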
|
|
    # Output directories
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
|
|
    callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
                                          save_best_only=True, monitor='val_loss', verbose=1, mode='min')
    callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'),
                                          save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
    # 'batch_size' and 'write_grads' are TF1-era TensorBoard arguments (removed in TF2)
    callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
                                       batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
                                       update_freq='epoch',
                                       write_graph=True, write_grads=True)
    callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
    callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
|
|
    optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
    network.compile(optimizer=optimizer,
                    loss=losses,
                    loss_weights=loss_weights,
                    metrics=metrics)
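    # The epoch/batch loop below is hand-written (train_on_batch instead of model.fit),
    # so every callback must be attached with set_model() and each on_train_begin /
    # on_epoch_* / on_train_batch_* / on_train_end hook has to be invoked manually.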
|
|
    callback_tensorboard.set_model(network)
    callback_best_model.set_model(network)
    callback_save_model.set_model(network)
    callback_early_stop.set_model(network)
    callback_lr.set_model(network)
    summary = SummaryDictionary(network, C.BATCH_SIZE, C.ACCUM_GRADIENT_STEP)
    names = network.metrics_names
    log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
    with sess.as_default():
        callback_tensorboard.on_train_begin()
        callback_early_stop.on_train_begin()
        callback_best_model.on_train_begin()
        callback_save_model.on_train_begin()
        callback_lr.on_train_begin()
        for epoch in range(resume_epoch, C.EPOCHS):
            callback_tensorboard.on_epoch_begin(epoch)
            callback_early_stop.on_epoch_begin(epoch)
            callback_best_model.on_epoch_begin(epoch)
            callback_save_model.on_epoch_begin(epoch)
            callback_lr.on_epoch_begin(epoch)
            print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
            print('TRAINING')

            log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(train_generator), width=30, verbose=1)
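            # One manual training step: augment the batch, screen it for NaN/Inf,
            # run train_on_batch, then notify the callbacks and the progress bar.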
            for step, (in_batch, _) in enumerate(train_generator, 1):
                callback_best_model.on_train_batch_begin(step)
                callback_save_model.on_train_batch_begin(step)
                callback_early_stop.on_train_batch_begin(step)
                callback_lr.on_train_batch_begin(step)
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)

                    # Screen for corrupted samples BEFORE sanitizing: otherwise the
                    # nan_to_num() calls would mask the NaN/Inf values being tested for
                    if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
                        msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
                                                                                    np.unique(mov_img))
                        print(msg)
                        log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
                    np.nan_to_num(fix_img, copy=False)
                    np.nan_to_num(mov_img, copy=False)
                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue
|
|
                # y: targets for (moved image, flow, moved segmentation); the Grad loss on
                # 'flow' ignores its y_true, so fix_img is passed there as a placeholder
                ret = network.train_on_batch(x=(mov_img, fix_img, mov_seg),
                                             y=(fix_img, fix_img, fix_seg))

                if np.isnan(ret).any():
                    # Dump the offending tensors before halting
                    os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
                    save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
                    save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
                    pred_img, dm, pred_seg = network.predict((mov_img, fix_img, mov_seg))
                    save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
                    save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
                    log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
                    raise ValueError('CORRUPTION ERROR: Halting training')
|
|
                summary.on_train_batch_end(ret)
                callback_best_model.on_train_batch_end(step, named_logs(network, ret))
                callback_save_model.on_train_batch_end(step, named_logs(network, ret))
                callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
                callback_lr.on_train_batch_end(step, named_logs(network, ret))
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            print('End of epoch {}: '.format(epoch), ret, '\n')
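            # Progbar accumulates {name: [running_sum, count]} pairs in its (private)
            # _values attribute; dividing the two gives the epoch mean of each loss/metric.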
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]
|
|
            print('\nVALIDATION')
            log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(validation_generator, 1):
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                ret = network.test_on_batch(x=(mov_img, fix_img, mov_seg),
                                            y=(fix_img, fix_img, fix_seg))
                summary.on_validation_batch_end(ret)
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]
|
|
            train_generator.on_epoch_end()
            validation_generator.on_epoch_end()
            epoch_summary = summary.on_epoch_end()
            callback_tensorboard.on_epoch_end(epoch, epoch_summary)
            callback_early_stop.on_epoch_end(epoch, epoch_summary)
            callback_best_model.on_epoch_end(epoch, epoch_summary)
            callback_save_model.on_epoch_end(epoch, epoch_summary)
            callback_lr.on_epoch_end(epoch, epoch_summary)
|
|
        callback_tensorboard.on_train_end()
        callback_save_model.on_train_end()
        callback_best_model.on_train_end()
        callback_early_stop.on_train_end()
        callback_lr.on_train_end()
    log_file.close()
|
|
if __name__ == '__main__':
    os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
|
|
    # Keyword arguments make the bindings explicit (passed positionally, these three values
    # would land on dataset_folder, validation_folder and output_folder). validation_folder=None
    # assumes BatchGenerator falls back to splitting the training set via C.TRAINING_PERC.
    launch_train(dataset_folder='/mnt/EncryptedData1/Users/javier/Brain_study/ERASE',
                 validation_folder=None,
                 output_folder='TrainOutput/THESIS/UW_None_mse_ssim_haus',
                 gpu_num=0)
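    # Illustrative invocation with an explicit validation set (paths are hypothetical):
    # launch_train(dataset_folder='/data/brain/train', validation_folder='/data/brain/val',
    #              output_folder='TrainOutput/example_run', gpu_num=0,
    #              simil='ms_ssim__ncc', segm='dice')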
|
|