Commit
·
a3cbfc7
1
Parent(s):
cd051bb
Added flag to return the results in the original input image resolution
Browse files
DeepDeformationMapRegistration/main.py
CHANGED
|
@@ -6,6 +6,7 @@ import argparse
|
|
| 6 |
import subprocess
|
| 7 |
import logging
|
| 8 |
import time
|
|
|
|
| 9 |
|
| 10 |
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 11 |
# parentdir = os.path.dirname(currentdir)
|
|
@@ -180,15 +181,17 @@ def main():
|
|
| 180 |
default=None)
|
| 181 |
parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
|
| 182 |
default='UW-NSD')
|
| 183 |
-
parser.add_argument('
|
| 184 |
parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
|
|
|
|
|
|
|
|
|
|
| 185 |
args = parser.parse_args()
|
| 186 |
|
| 187 |
assert os.path.exists(args.fixed), 'Fixed image not found'
|
| 188 |
assert os.path.exists(args.moving), 'Moving image not found'
|
| 189 |
assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
|
| 190 |
assert args.anatomy in C.ANATOMIES.keys(), 'Invalid anatomy option'
|
| 191 |
-
|
| 192 |
if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
|
| 193 |
if args.clear_outputdir:
|
| 194 |
erase = 'y'
|
|
@@ -217,6 +220,12 @@ def main():
|
|
| 217 |
LOGGER.setLevel('DEBUG')
|
| 218 |
LOGGER.debug('DEBUG MODE ENABLED')
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
# Load the file and preprocess it
|
| 221 |
LOGGER.info('Loading image files')
|
| 222 |
fixed_image_or = nib.load(args.fixed)
|
|
@@ -295,6 +304,7 @@ def main():
|
|
| 295 |
time_disp_map_start = time.time()
|
| 296 |
# disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
| 297 |
p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
|
|
|
| 298 |
time_disp_map_end = time.time()
|
| 299 |
LOGGER.info('\t... done')
|
| 300 |
debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
|
|
@@ -303,33 +313,38 @@ def main():
|
|
| 303 |
# pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
| 304 |
# fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
| 305 |
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
LOGGER.info('Applying displacement map...')
|
| 319 |
time_pred_img_start = time.time()
|
| 320 |
-
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([
|
| 321 |
time_pred_img_end = time.time()
|
| 322 |
LOGGER.info('\t... done')
|
| 323 |
|
| 324 |
LOGGER.info('Computing metrics...')
|
| 325 |
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
| 326 |
-
{'fix_img:0':
|
| 327 |
ssim = np.mean(ssim)
|
| 328 |
ms_ssim = ms_ssim[0]
|
| 329 |
pred_image = pred_image[0, ...]
|
| 330 |
|
| 331 |
save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
|
| 332 |
-
np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'),
|
| 333 |
LOGGER.info('Predicted image (full image) and displacement map saved in: '.format(args.outputdir))
|
| 334 |
LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
|
| 335 |
LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
|
|
@@ -340,15 +355,15 @@ def main():
|
|
| 340 |
LOGGER.info('MSE: {:.03f}'.format(mse))
|
| 341 |
LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
| 342 |
|
| 343 |
-
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
| 344 |
-
|
| 345 |
-
ssim = np.mean(ssim)
|
| 346 |
-
ms_ssim = ms_ssim[0]
|
| 347 |
-
LOGGER.info('\nSimilarity metrics (ROI)\n------------------')
|
| 348 |
-
LOGGER.info('SSIM: {:.03f}'.format(ssim))
|
| 349 |
-
LOGGER.info('NCC: {:.03f}'.format(ncc))
|
| 350 |
-
LOGGER.info('MSE: {:.03f}'.format(mse))
|
| 351 |
-
LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
| 352 |
|
| 353 |
del registration_model
|
| 354 |
LOGGER.info('Done')
|
|
|
|
| 6 |
import subprocess
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
+
import warnings
|
| 10 |
|
| 11 |
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 12 |
# parentdir = os.path.dirname(currentdir)
|
|
|
|
| 181 |
default=None)
|
| 182 |
parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
|
| 183 |
default='UW-NSD')
|
| 184 |
+
parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
|
| 185 |
parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
|
| 186 |
+
parser.add_argument('--original-resolution', action='store_true',
|
| 187 |
+
help='Re-scale the displacement map to the originla resolution and apply it to the original moving image. WARNING: longer processing time',
|
| 188 |
+
default=False)
|
| 189 |
args = parser.parse_args()
|
| 190 |
|
| 191 |
assert os.path.exists(args.fixed), 'Fixed image not found'
|
| 192 |
assert os.path.exists(args.moving), 'Moving image not found'
|
| 193 |
assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
|
| 194 |
assert args.anatomy in C.ANATOMIES.keys(), 'Invalid anatomy option'
|
|
|
|
| 195 |
if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
|
| 196 |
if args.clear_outputdir:
|
| 197 |
erase = 'y'
|
|
|
|
| 220 |
LOGGER.setLevel('DEBUG')
|
| 221 |
LOGGER.debug('DEBUG MODE ENABLED')
|
| 222 |
|
| 223 |
+
if args.original_resolution:
|
| 224 |
+
LOGGER.info('The results will be rescaled back to the original image resolution. '
|
| 225 |
+
'Expect longer post-processing times.')
|
| 226 |
+
else:
|
| 227 |
+
LOGGER.info(f'The results will NOT be rescaled. Output shape will be {C.IMG_SHAPE[:3]}.')
|
| 228 |
+
|
| 229 |
# Load the file and preprocess it
|
| 230 |
LOGGER.info('Loading image files')
|
| 231 |
fixed_image_or = nib.load(args.fixed)
|
|
|
|
| 304 |
time_disp_map_start = time.time()
|
| 305 |
# disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
| 306 |
p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
| 307 |
+
disp_map = np.squeeze(disp_map)
|
| 308 |
time_disp_map_end = time.time()
|
| 309 |
LOGGER.info('\t... done')
|
| 310 |
debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
|
|
|
|
| 313 |
# pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
| 314 |
# fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
| 315 |
|
| 316 |
+
if args.original_resolution:
|
| 317 |
+
# Up sample the displacement map to the full res
|
| 318 |
+
LOGGER.info('Scaling displacement map...')
|
| 319 |
+
trf = np.eye(4)
|
| 320 |
+
np.fill_diagonal(trf, 1/zoom_factors)
|
| 321 |
+
disp_map = resize_displacement_map(disp_map, None, trf)
|
| 322 |
+
debug_save_image(disp_map, 'disp_map_1_upsampled', args.outputdir, args.debug)
|
| 323 |
+
disp_map_or = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
|
| 324 |
+
debug_save_image(np.squeeze(disp_map_or), 'disp_map_2_padded', args.outputdir, args.debug)
|
| 325 |
+
disp_map_or = gaussian_filter(disp_map_or, 5)
|
| 326 |
+
debug_save_image(np.squeeze(disp_map_or), 'disp_map_3_smoothed', args.outputdir, args.debug)
|
| 327 |
+
LOGGER.info('\t... done')
|
| 328 |
+
|
| 329 |
+
moving_image = moving_image_or
|
| 330 |
+
fixed_image = fixed_image_or
|
| 331 |
+
disp_map = disp_map_or
|
| 332 |
|
| 333 |
LOGGER.info('Applying displacement map...')
|
| 334 |
time_pred_img_start = time.time()
|
| 335 |
+
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
| 336 |
time_pred_img_end = time.time()
|
| 337 |
LOGGER.info('\t... done')
|
| 338 |
|
| 339 |
LOGGER.info('Computing metrics...')
|
| 340 |
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
| 341 |
+
{'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': pred_image})
|
| 342 |
ssim = np.mean(ssim)
|
| 343 |
ms_ssim = ms_ssim[0]
|
| 344 |
pred_image = pred_image[0, ...]
|
| 345 |
|
| 346 |
save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
|
| 347 |
+
np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map)
|
| 348 |
LOGGER.info('Predicted image (full image) and displacement map saved in: '.format(args.outputdir))
|
| 349 |
LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
|
| 350 |
LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
|
|
|
|
| 355 |
LOGGER.info('MSE: {:.03f}'.format(mse))
|
| 356 |
LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
| 357 |
|
| 358 |
+
# ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
| 359 |
+
# {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': p})
|
| 360 |
+
# ssim = np.mean(ssim)
|
| 361 |
+
# ms_ssim = ms_ssim[0]
|
| 362 |
+
# LOGGER.info('\nSimilarity metrics (ROI)\n------------------')
|
| 363 |
+
# LOGGER.info('SSIM: {:.03f}'.format(ssim))
|
| 364 |
+
# LOGGER.info('NCC: {:.03f}'.format(ncc))
|
| 365 |
+
# LOGGER.info('MSE: {:.03f}'.format(mse))
|
| 366 |
+
# LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
| 367 |
|
| 368 |
del registration_model
|
| 369 |
LOGGER.info('Done')
|