Commit
·
8c1dc9d
1
Parent(s):
371bf06
Removed NCC, training only on Hausdorff and DICE
Browse files
TrainingScripts/Train_3d_weaklySupervised.py
CHANGED
|
@@ -17,7 +17,7 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
| 17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
| 18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
| 19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
| 20 |
-
from DeepDeformationMapRegistration.losses import
|
| 21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
| 22 |
|
| 23 |
|
|
@@ -49,24 +49,24 @@ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=
|
|
| 49 |
# Losses and loss weights
|
| 50 |
|
| 51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
| 52 |
-
|
| 53 |
def dice_loss(y_true, y_pred):
|
| 54 |
# Dice().loss returns -Dice score
|
| 55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
| 56 |
|
| 57 |
-
multiLoss = UncertaintyWeighting(num_loss_fns=
|
| 58 |
num_reg_fns=1,
|
| 59 |
-
loss_fns=[
|
| 60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
| 61 |
-
prior_loss_w=[1., 1.],
|
| 62 |
prior_reg_w=[0.01],
|
| 63 |
name='MultiLossLayer')
|
| 64 |
-
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
|
| 65 |
-
vxm_model.references.pred_segm, vxm_model.references.pred_segm,
|
| 66 |
grad,
|
| 67 |
vxm_model.references.pos_flow])
|
| 68 |
|
| 69 |
-
full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
|
| 70 |
|
| 71 |
# Compile the model
|
| 72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|
|
|
|
| 17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
| 18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
| 19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
| 20 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistance
|
| 21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
| 22 |
|
| 23 |
|
|
|
|
| 49 |
# Losses and loss weights
|
| 50 |
|
| 51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
| 52 |
+
fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
|
| 53 |
def dice_loss(y_true, y_pred):
|
| 54 |
# Dice().loss returns -Dice score
|
| 55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
| 56 |
|
| 57 |
+
multiLoss = UncertaintyWeighting(num_loss_fns=2,
|
| 58 |
num_reg_fns=1,
|
| 59 |
+
loss_fns=[HausdorffDistance(3, 5).loss, dice_loss],
|
| 60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
| 61 |
+
prior_loss_w=[1., 1., 1.],
|
| 62 |
prior_reg_w=[0.01],
|
| 63 |
name='MultiLossLayer')
|
| 64 |
+
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
|
| 65 |
+
vxm_model.references.pred_segm, vxm_model.references.pred_segm,
|
| 66 |
grad,
|
| 67 |
vxm_model.references.pos_flow])
|
| 68 |
|
| 69 |
+
full_model = tf.keras.Model(inputs=vxm_model.inputs + [fix_img, grad], outputs=vxm_model.outputs + [loss])
|
| 70 |
|
| 71 |
# Compile the model
|
| 72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|