Commit
·
5b0dbe4
1
Parent(s):
61f0e36
SpatialTransfomer was embedded into a Keras model, and it is downloaded from the repo releases
Browse files
DeepDeformationMapRegistration/layers/SpatialTransformer.py
CHANGED
|
@@ -3,6 +3,9 @@ import tensorflow.keras.backend as K
|
|
| 3 |
import tensorflow as tf
|
| 4 |
import neurite as ne
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class SpatialTransformer(kl.Layer):
|
| 8 |
"""
|
|
@@ -184,3 +187,15 @@ class SpatialTransformer(kl.Layer):
|
|
| 184 |
# test single
|
| 185 |
return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import tensorflow as tf
|
| 4 |
import neurite as ne
|
| 5 |
|
| 6 |
+
import h5py
|
| 7 |
+
from DeepDeformationMapRegistration.utils.constants import IMG_SHAPE, DISP_MAP_SHAPE
|
| 8 |
+
|
| 9 |
|
| 10 |
class SpatialTransformer(kl.Layer):
|
| 11 |
"""
|
|
|
|
| 187 |
# test single
|
| 188 |
return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
|
| 189 |
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
|
| 192 |
+
output_file = './spatialtransformer.h5'
|
| 193 |
+
|
| 194 |
+
in_dm = tf.keras.Input(DISP_MAP_SHAPE)
|
| 195 |
+
in_image = tf.keras.Input(IMG_SHAPE)
|
| 196 |
+
pred = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([in_image, in_dm])
|
| 197 |
+
|
| 198 |
+
model = tf.keras.Model(inputs=[in_image, in_dm], outputs=pred)
|
| 199 |
+
|
| 200 |
+
model.save(output_file)
|
| 201 |
+
print(f"SpatialTransformer layer saved in: {output_file}")
|
DeepDeformationMapRegistration/main.py
CHANGED
|
@@ -19,7 +19,7 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
| 19 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
| 20 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
| 21 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
| 22 |
-
from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model
|
| 23 |
from DeepDeformationMapRegistration.utils.logger import LOGGER
|
| 24 |
|
| 25 |
from importlib.util import find_spec
|
|
@@ -279,8 +279,11 @@ def main():
|
|
| 279 |
|
| 280 |
LOGGER.info(f'Getting model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
| 281 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
|
|
|
| 282 |
|
| 283 |
network, registration_model = load_model(MODEL_FILE, False, True)
|
|
|
|
|
|
|
| 284 |
|
| 285 |
LOGGER.info('Computing registration')
|
| 286 |
with sess.as_default():
|
|
@@ -297,7 +300,8 @@ def main():
|
|
| 297 |
|
| 298 |
LOGGER.info('Applying displacement map...')
|
| 299 |
time_pred_img_start = time.time()
|
| 300 |
-
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
|
|
|
| 301 |
time_pred_img_end = time.time()
|
| 302 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
| 303 |
pred_image = pred_image[0, ...]
|
|
|
|
| 19 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
| 20 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
| 21 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
| 22 |
+
from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model, get_spatialtransformer_model
|
| 23 |
from DeepDeformationMapRegistration.utils.logger import LOGGER
|
| 24 |
|
| 25 |
from importlib.util import find_spec
|
|
|
|
| 279 |
|
| 280 |
LOGGER.info(f'Getting model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
| 281 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
| 282 |
+
ST_MODEL_FILE = get_spatialtransformer_model()
|
| 283 |
|
| 284 |
network, registration_model = load_model(MODEL_FILE, False, True)
|
| 285 |
+
spatialtransformer_model = tf.keras.models.load_model(ST_MODEL_FILE,
|
| 286 |
+
custom_objects={'SpatialTransformer': SpatialTransformer})
|
| 287 |
|
| 288 |
LOGGER.info('Computing registration')
|
| 289 |
with sess.as_default():
|
|
|
|
| 300 |
|
| 301 |
LOGGER.info('Applying displacement map...')
|
| 302 |
time_pred_img_start = time.time()
|
| 303 |
+
# pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
| 304 |
+
pred_image = spatialtransformer_model.predict([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]])
|
| 305 |
time_pred_img_end = time.time()
|
| 306 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
| 307 |
pred_image = pred_image[0, ...]
|
DeepDeformationMapRegistration/utils/model_utils.py
CHANGED
|
@@ -63,3 +63,16 @@ def load_model(weights_file_path: str, trainable: bool = False, return_registrat
|
|
| 63 |
ret_val = (ret_val, ret_val.get_registration_model())
|
| 64 |
|
| 65 |
return ret_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
ret_val = (ret_val, ret_val.get_registration_model())
|
| 64 |
|
| 65 |
return ret_val
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_spatialtransformer_model():
|
| 69 |
+
url = 'https://github.com/jpdefrutos/DDMR/releases/download/spatialtransformer_model_v0/spatialtransformer.h5'
|
| 70 |
+
file_path = os.path.join(os.getcwd(), 'models', 'spatialtransformer.h5')
|
| 71 |
+
if not os.path.exists(file_path):
|
| 72 |
+
LOGGER.info(f'Model not found. Downloading from {url}... ')
|
| 73 |
+
os.makedirs(os.path.split(file_path)[0], exist_ok=True)
|
| 74 |
+
download(url, file_path)
|
| 75 |
+
LOGGER.info(f'... downloaded model. Stored in {file_path}')
|
| 76 |
+
else:
|
| 77 |
+
LOGGER.info(f'Found model: {file_path}')
|
| 78 |
+
return file_path
|