Commit
·
2458333
1
Parent(s):
8c1dc9d
Train only on segmentation data
Browse files
EvaluationScripts/Evaluate_3d_weaklySupervised.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 3 |
+
parentdir = os.path.dirname(currentdir)
|
| 4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 5 |
+
|
| 6 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import tensorflow as tf
|
| 10 |
+
import voxelmorph as vxm
|
| 11 |
+
import neurite as ne
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
|
| 14 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
| 15 |
+
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
| 16 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
| 17 |
+
from DeepDeformationMapRegistration.utils.nifty_utils import save_nifti
|
| 18 |
+
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
| 19 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
|
| 20 |
+
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
| 24 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
|
| 25 |
+
|
| 26 |
+
C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_vessels'
|
| 27 |
+
C.BATCH_SIZE = 2
|
| 28 |
+
C.LIMIT_NUM_SAMPLES = None
|
| 29 |
+
C.EPOCHS = 10000
|
| 30 |
+
|
| 31 |
+
# Load data
|
| 32 |
+
# Build data generator
|
| 33 |
+
|
| 34 |
+
data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
|
| 35 |
+
1 - C.TRAINING_PERC, voxelmorph=True, segmentations=True)
|
| 36 |
+
|
| 37 |
+
train_generator = data_generator.get_generator('train')
|
| 38 |
+
validation_generator = data_generator.get_generator('validation')
|
| 39 |
+
|
| 40 |
+
data_folder = '../train_3d_multiloss_segm_haus_dice_ncc_grad_203925-29012021'
|
| 41 |
+
|
| 42 |
+
# Build model
|
| 43 |
+
in_shape = train_generator.get_input_shape()[1:-1]
|
| 44 |
+
enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
|
| 45 |
+
dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
|
| 46 |
+
nb_features = [enc_features, dec_features]
|
| 47 |
+
vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
|
| 48 |
+
vxm_model.load_weights(os.path.join(data_folder, 'checkpoints', 'best_model.h5'), by_name=True)
|
| 49 |
+
|
| 50 |
+
# Get some samples and plot them
|
| 51 |
+
sample = validation_generator[0]
|
| 52 |
+
|
| 53 |
+
samp_id = 1
|
| 54 |
+
pred_img, pred_seg, pred_flow = vxm_model.predict([sample[0][0][samp_id, ...][np.newaxis, ...],
|
| 55 |
+
sample[0][1][samp_id, ...][np.newaxis, ...],
|
| 56 |
+
sample[0][2][samp_id, ...][np.newaxis, ...]])
|
| 57 |
+
|
| 58 |
+
save_nifti(np.squeeze(pred_img), os.path.join(data_folder, 'pred_img.nii.gz'))
|
| 59 |
+
save_nifti(np.squeeze(pred_seg), os.path.join(data_folder, 'pred_seg.nii.gz'))
|
| 60 |
+
save_nifti(sample[0][0][samp_id, ...], os.path.join(data_folder, 'mov_seg.nii.gz'))
|
| 61 |
+
save_nifti(sample[0][1][samp_id, ...], os.path.join(data_folder, 'fix_seg.nii.gz'))
|
| 62 |
+
save_nifti(sample[0][2][samp_id, ...], os.path.join(data_folder, 'mov_img.nii.gz'))
|
| 63 |
+
save_nifti(sample[0][-2][samp_id, ...], os.path.join(data_folder, 'fix_img.nii.gz'))
|
TrainingScripts/Train_3d_weaklySupervised.py
CHANGED
|
@@ -17,7 +17,7 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
| 17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
| 18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
| 19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
| 20 |
-
from DeepDeformationMapRegistration.losses import
|
| 21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
| 22 |
|
| 23 |
|
|
@@ -49,16 +49,16 @@ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=
|
|
| 49 |
# Losses and loss weights
|
| 50 |
|
| 51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
| 52 |
-
fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
|
| 53 |
def dice_loss(y_true, y_pred):
|
| 54 |
# Dice().loss returns -Dice score
|
| 55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
| 56 |
|
| 57 |
multiLoss = UncertaintyWeighting(num_loss_fns=2,
|
| 58 |
num_reg_fns=1,
|
| 59 |
-
loss_fns=[
|
| 60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
| 61 |
-
prior_loss_w=[1., 1
|
| 62 |
prior_reg_w=[0.01],
|
| 63 |
name='MultiLossLayer')
|
| 64 |
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
|
|
@@ -66,7 +66,7 @@ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
|
|
| 66 |
grad,
|
| 67 |
vxm_model.references.pos_flow])
|
| 68 |
|
| 69 |
-
full_model = tf.keras.Model(inputs=vxm_model.inputs + [
|
| 70 |
|
| 71 |
# Compile the model
|
| 72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|
|
|
|
| 17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
| 18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
| 19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
| 20 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
|
| 21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
| 22 |
|
| 23 |
|
|
|
|
| 49 |
# Losses and loss weights
|
| 50 |
|
| 51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
| 52 |
+
# fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
|
| 53 |
def dice_loss(y_true, y_pred):
|
| 54 |
# Dice().loss returns -Dice score
|
| 55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
| 56 |
|
| 57 |
multiLoss = UncertaintyWeighting(num_loss_fns=2,
|
| 58 |
num_reg_fns=1,
|
| 59 |
+
loss_fns=[HausdorffDistanceErosion(3, 5).loss, dice_loss],
|
| 60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
| 61 |
+
prior_loss_w=[1., 1.],
|
| 62 |
prior_reg_w=[0.01],
|
| 63 |
name='MultiLossLayer')
|
| 64 |
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
|
|
|
|
| 66 |
grad,
|
| 67 |
vxm_model.references.pos_flow])
|
| 68 |
|
| 69 |
+
full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
|
| 70 |
|
| 71 |
# Compile the model
|
| 72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|