Commit
·
0978cbc
1
Parent(s):
f4c45ef
Created convenient function for loading models
Browse files
DeepDeformationMapRegistration/main.py
CHANGED
|
@@ -7,9 +7,9 @@ import subprocess
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
|
| 10 |
-
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 11 |
-
parentdir = os.path.dirname(currentdir)
|
| 12 |
-
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 13 |
|
| 14 |
import tensorflow as tf
|
| 15 |
|
|
@@ -29,6 +29,7 @@ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimila
|
|
| 29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
| 30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
| 31 |
from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
|
|
|
|
| 32 |
|
| 33 |
from importlib.util import find_spec
|
| 34 |
|
|
@@ -284,39 +285,7 @@ def main():
|
|
| 284 |
LOGGER.info(f'Using model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
| 285 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
| 286 |
|
| 287 |
-
|
| 288 |
-
# network = tf.keras.models.load_model(MODEL_FILE,
|
| 289 |
-
# {'VxmDense': vxm.networks.VxmDense,
|
| 290 |
-
# # 'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
| 291 |
-
# 'AdamAccumulated': AdamAccumulated
|
| 292 |
-
# },
|
| 293 |
-
# compile=False)
|
| 294 |
-
# except ValueError as e:
|
| 295 |
-
# enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
| 296 |
-
# dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
| 297 |
-
# nb_features = [enc_features, dec_features]
|
| 298 |
-
# if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
| 299 |
-
# network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
|
| 300 |
-
# nb_unet_features=nb_features,
|
| 301 |
-
# int_steps=0,
|
| 302 |
-
# int_downsize=1,
|
| 303 |
-
# seg_downsize=1)
|
| 304 |
-
# else:
|
| 305 |
-
# network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
|
| 306 |
-
# nb_unet_features=nb_features,
|
| 307 |
-
# int_steps=0)
|
| 308 |
-
# network.load_weights(MODEL_FILE, by_name=True)
|
| 309 |
-
|
| 310 |
-
enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
| 311 |
-
dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
| 312 |
-
nb_features = [enc_features, dec_features]
|
| 313 |
-
network = vxm.networks.VxmDense(inshape=C.IMG_SHAPE[:-1],
|
| 314 |
-
nb_unet_features=nb_features,
|
| 315 |
-
int_steps=0)
|
| 316 |
-
network.load_weights(MODEL_FILE, by_name=True)
|
| 317 |
-
network.trainable = False
|
| 318 |
-
|
| 319 |
-
registration_model = network.get_registration_model()
|
| 320 |
deb_model = network.apply_transform
|
| 321 |
|
| 322 |
LOGGER.info('Computing registration')
|
|
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
|
| 10 |
+
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 11 |
+
# parentdir = os.path.dirname(currentdir)
|
| 12 |
+
# sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 13 |
|
| 14 |
import tensorflow as tf
|
| 15 |
|
|
|
|
| 29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
| 30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
| 31 |
from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
|
| 32 |
+
from DeepDeformationMapRegistration.networks import load_model
|
| 33 |
|
| 34 |
from importlib.util import find_spec
|
| 35 |
|
|
|
|
| 285 |
LOGGER.info(f'Using model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
| 286 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
| 287 |
|
| 288 |
+
network, registration_model = load_model(MODEL_FILE, False, True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
deb_model = network.apply_transform
|
| 290 |
|
| 291 |
LOGGER.info('Computing registration')
|
DeepDeformationMapRegistration/networks.py
CHANGED
|
@@ -1,14 +1,31 @@
|
|
| 1 |
import os, sys
|
| 2 |
-
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 3 |
-
parentdir = os.path.dirname(currentdir)
|
| 4 |
-
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 5 |
-
|
| 6 |
-
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
| 7 |
|
| 8 |
import tensorflow as tf
|
| 9 |
import voxelmorph as vxm
|
| 10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
| 11 |
from tensorflow.keras.layers import UpSampling3D
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class WeaklySupervised(LoadableModel):
|
|
|
|
| 1 |
import os, sys
|
| 2 |
+
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 3 |
+
# parentdir = os.path.dirname(currentdir)
|
| 4 |
+
# sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 5 |
+
#
|
| 6 |
+
# PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
| 7 |
|
| 8 |
import tensorflow as tf
|
| 9 |
import voxelmorph as vxm
|
| 10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
| 11 |
from tensorflow.keras.layers import UpSampling3D
|
| 12 |
+
from DeepDeformationMapRegistration.utils.constants import ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
|
| 16 |
+
assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
|
| 17 |
+
assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
|
| 18 |
+
|
| 19 |
+
ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
|
| 20 |
+
nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
|
| 21 |
+
int_steps=0)
|
| 22 |
+
ret_val.load_weights(weights_file_path, by_name=True)
|
| 23 |
+
ret_val.trainable = trainable
|
| 24 |
+
|
| 25 |
+
if return_registration_model:
|
| 26 |
+
ret_val = (ret_val, ret_val.get_registration_model())
|
| 27 |
+
|
| 28 |
+
return ret_val
|
| 29 |
|
| 30 |
|
| 31 |
class WeaklySupervised(LoadableModel):
|
DeepDeformationMapRegistration/utils/constants.py
CHANGED
|
@@ -196,8 +196,8 @@ DROPOUT = True
|
|
| 196 |
DROPOUT_RATE = 0.2
|
| 197 |
MAX_DATA_SIZE = (1000, 1000, 1)
|
| 198 |
PLATEAU_THR = 0.01 # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
|
| 199 |
-
ENCODER_FILTERS = [
|
| 200 |
-
|
| 201 |
# SSIM
|
| 202 |
SSIM_FILTER_SIZE = 11 # Size of Gaussian filter
|
| 203 |
SSIM_FILTER_SIGMA = 1.5 # Width of Gaussian filter
|
|
@@ -205,7 +205,7 @@ SSIM_K1 = 0.01 # Def. 0.01
|
|
| 205 |
SSIM_K2 = 0.03 # Recommended values 0 < K2 < 0.4
|
| 206 |
MAX_VALUE = 1.0 # Maximum intensity values
|
| 207 |
|
| 208 |
-
#
|
| 209 |
EPS = 1e-8
|
| 210 |
EPS_tf = tf.constant(EPS, dtype=tf.float32)
|
| 211 |
LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))
|
|
|
|
| 196 |
DROPOUT_RATE = 0.2
|
| 197 |
MAX_DATA_SIZE = (1000, 1000, 1)
|
| 198 |
PLATEAU_THR = 0.01 # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
|
| 199 |
+
ENCODER_FILTERS = [32, 64, 128, 256, 512, 1024]
|
| 200 |
+
DECODER_FILTERS = ENCODER_FILTERS[::-1] + [16, 16]
|
| 201 |
# SSIM
|
| 202 |
SSIM_FILTER_SIZE = 11 # Size of Gaussian filter
|
| 203 |
SSIM_FILTER_SIGMA = 1.5 # Width of Gaussian filter
|
|
|
|
| 205 |
SSIM_K2 = 0.03 # Recommended values 0 < K2 < 0.4
|
| 206 |
MAX_VALUE = 1.0 # Maximum intensity values
|
| 207 |
|
| 208 |
+
# Mathematics constants
|
| 209 |
EPS = 1e-8
|
| 210 |
EPS_tf = tf.constant(EPS, dtype=tf.float32)
|
| 211 |
LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))
|