Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """Misc. functions for AVRA inference pipeline.""" | |
| from __future__ import division | |
| import torch | |
| import numpy as np | |
| import os | |
| import tempfile | |
| from collections import OrderedDict | |
| from model.model import AVRA_rnn, VGG_bl | |
| import nibabel | |
| import glob | |
# Optional dependency: nipype's FSL wrappers are only needed for --use-fsl,
# so they are imported on demand instead of at module import time.
def _get_fsl():
    """Import and return the ``nipype.interfaces.fsl`` module on demand."""
    from nipype.interfaces import fsl as fsl_module
    return fsl_module
def load_mri(file):
    """Load an image with nibabel and return its voxels as a float32 array.

    The image is first reoriented with ``nibabel.as_closest_canonical`` so the
    returned array's axes follow the closest canonical (RAS+) orientation.

    Args:
        file: path to any image format ``nibabel.load`` understands.

    Returns:
        A new ``np.ndarray`` of dtype float32 holding the reoriented voxel data.
    """
    canonical_img = nibabel.as_closest_canonical(nibabel.load(file))
    voxels = np.array(canonical_img.dataobj, dtype=np.float32)
    return voxels
def load_settings_from_model(path, args):
    """Copy the input-geometry settings stored in a checkpoint onto ``args``.

    The checkpoint always supplies ``size_x``/``size_y``/``size_z``, ``arch``
    and ``mse``. If it also stores ``offset_x``, the offsets and ``nc`` are
    taken from it. Otherwise ``nc`` is set to 1 and the offsets fall back to
    fixed values chosen by ``args.vrs`` ('mta', 'gca-f' or 'pa'). For any
    other ``args.vrs`` the offsets are left untouched.

    Args:
        path: checkpoint file readable by ``torch.load``; mapped onto the CPU.
        args: mutable namespace (e.g. from argparse), updated in place.

    Returns:
        The same ``args`` object.
    """
    # NOTE: depending on the torch version / weights_only setting, torch.load
    # may unpickle arbitrary objects -- only pass trusted checkpoint files.
    ckpt = torch.load(path, map_location=torch.device('cpu'))
    args.size_x, args.size_y, args.size_z = ckpt['size_x'], ckpt['size_y'], ckpt['size_z']
    args.arch = ckpt['arch']
    args.mse = ckpt['mse']

    if 'offset_x' not in ckpt:
        # Checkpoints without stored offsets: single input channel and
        # hard-coded offsets per visual rating scale.
        args.nc = 1
        fallback_offsets = {
            'mta': (0, 0, 4),
            'gca-f': (0, 5, 14),
            'pa': (0, -25, 5),
        }
        if args.vrs in fallback_offsets:
            args.offset_x, args.offset_y, args.offset_z = fallback_offsets[args.vrs]
        return args

    args.offset_x, args.offset_y, args.offset_z = ckpt['offset_x'], ckpt['offset_y'], ckpt['offset_z']
    args.nc = ckpt['nc']
    return args
def load_model(path, args):
    """Build an ``AVRA_rnn`` model and load its weights from a checkpoint.

    The in-plane input size comes from the checkpoint. When both ``depth``
    and ``width`` are stored, ``width`` is used for x and y. Otherwise
    ``size_x``/``size_y`` are used. Parameter names carrying a ``module.``
    prefix (as written by a ``torch.nn.DataParallel`` wrapper) are loaded
    with that prefix stripped.

    Args:
        path: checkpoint file readable by ``torch.load`` with a
            ``state_dict`` entry; it is loaded onto the CPU first.
        args: namespace providing ``device``, the target passed to ``.to``.

    Returns:
        The model with loaded weights, moved to ``args.device``.
    """
    # NOTE: torch.load may unpickle arbitrary objects (torch-version
    # dependent) -- only load trusted checkpoint files.
    checkpoint = torch.load(path, map_location='cpu')
    # Explicit key checks instead of a blanket `except Exception`, which
    # would also hide unrelated errors.
    if 'depth' in checkpoint and 'width' in checkpoint:
        x = y = checkpoint['width']
    else:
        x = checkpoint['size_x']
        y = checkpoint['size_y']
    # NOTE(review): the model is always AVRA_rnn with its default output head;
    # neither the checkpoint's 'mse' flag nor args.arch is consulted here.
    model = AVRA_rnn([x, y, 1])

    state_dict = checkpoint['state_dict']
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        # DataParallel checkpoints typically prefix parameter names with
        # 'module.'; strip it only where present so other keys survive intact.
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
    model.load_state_dict(state_dict)
    return model.to(args.device)
| def _get_guid(path_to_img, guid): | |
| if guid: | |
| return guid | |
| native_img = os.path.basename(path_to_img) | |
| if 'nii.gz' in native_img: | |
| return os.path.splitext(os.path.splitext(native_img)[0])[0] | |
| return os.path.splitext(native_img)[0] | |
def native_to_tal_python(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True):
    """AC-PC alignment using Python only: nibabel (reorient) + SimpleITK (rigid registration).

    No FSL required. The image is reoriented to the closest canonical (RAS)
    orientation with nibabel. It is then rigidly registered (SimpleITK
    ``Euler3DTransform``, Mattes mutual information) to nilearn's ICBM152 2009
    1mm T1 template and resampled onto the template grid.

    Args:
        path_to_img: input image path (anything ``nibabel.load`` reads). For
            4D input only the first volume is used.
        force_new_transform: recompute even if the output file already exists.
        dof: only 6 (rigid) is implemented. Any other value triggers a
            ``UserWarning`` and a rigid registration is still performed. The
            value is still used in the output filename.
        output_folder: directory the aligned image is written to.
        guid: identifier used in the output filename; derived from the input
            filename when empty.
        remove_tmp_files: delete the intermediate reoriented image, which is
            created in the system temp directory.

    Returns:
        None. The result is written to
        ``<output_folder>/<guid>_mni_dof_<dof>.nii``.
    """
    import warnings

    import SimpleITK as sitk
    from nilearn import datasets as nilearn_datasets

    guid = _get_guid(path_to_img, guid)
    tal_img = guid + '_mni_dof_' + str(dof) + '.nii'
    tal_img_path = os.path.join(output_folder, tal_img)
    if os.path.exists(tal_img_path) and not force_new_transform:
        return

    if dof != 6:
        # Only a rigid transform is implemented: say so rather than silently
        # writing a rigid result under a "_dof_<dof>" filename.
        warnings.warn(
            'native_to_tal_python only supports rigid (6 DOF) registration; '
            'got dof=%s, performing a 6 DOF registration instead.' % dof,
            stacklevel=2,
        )

    # 1) Reorient to canonical (RAS) with nibabel
    img_nib = nibabel.load(path_to_img)
    if img_nib.ndim == 4:
        # Keep only the first volume of a 4D series.
        data = np.asarray(img_nib.dataobj, dtype=np.float32)[..., 0]
        img_nib = nibabel.Nifti1Image(data, img_nib.affine)
    canonical = nibabel.as_closest_canonical(img_nib)

    # Hand the reoriented image to SimpleITK through a temporary NIfTI file.
    fd, tmp_reorient = tempfile.mkstemp(suffix='.nii.gz', prefix='avra_reorient_')
    os.close(fd)
    try:
        nibabel.save(canonical, tmp_reorient)
        moving_sitk = sitk.ReadImage(tmp_reorient, sitk.sitkFloat32)
    finally:
        if remove_tmp_files and os.path.exists(tmp_reorient):
            os.remove(tmp_reorient)

    # 2) Load 1mm reference template (nilearn ICBM152 2009)
    template_bunch = nilearn_datasets.fetch_icbm152_2009(verbose=0)
    template_path = template_bunch['t1']
    fixed_sitk = sitk.ReadImage(template_path, sitk.sitkFloat32)

    # 3) Rigid registration (6 DOF) with SimpleITK, initialised by aligning
    #    the geometric centres of the two images.
    initial_tx = sitk.CenteredTransformInitializer(
        fixed_sitk, moving_sitk,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    R.SetOptimizerAsRegularStepGradientDescent(
        learningRate=1.0,
        minStep=1e-4,
        numberOfIterations=500,
    )
    R.SetOptimizerScalesFromPhysicalShift()
    R.SetInitialTransform(initial_tx, inPlace=False)
    R.SetInterpolator(sitk.sitkLinear)
    tx = R.Execute(fixed_sitk, moving_sitk)

    # 4) Resample moving to fixed grid (background 0.0) and save as NIfTI
    resampled = sitk.Resample(moving_sitk, fixed_sitk, tx, sitk.sitkLinear, 0.0)
    sitk.WriteImage(resampled, tal_img_path)
def _read_flirt_matrix(path):
    """Read the 4x4 affine matrix written by FLIRT (``out_matrix_file``)."""
    # Whitespace-agnostic parsing: FLIRT .mat files may separate values with
    # several spaces and leave trailing spaces, so splitting on a single ' '
    # produces empty fields that float() rejects.
    return np.loadtxt(path, dtype=float)[:4, :4]


def _diag_distance(matrix):
    """Return the sum of squared deviations of ``matrix``'s diagonal from 1."""
    return float(np.sum(np.square(np.diag(matrix) - 1)))


def native_to_tal_fsl(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True):
    """AC-PC alignment using FSL (requires nipype, FSL installed and FSLDIR set).

    Steps:
      1. Reorient the input (``Reorient2Std``).
      2. Skull-strip it (BET, ``frac=0.7``, robust).
      3. Register the brain to FSL's ``MNI152_T1_1mm`` template with FLIRT,
         using ``dof`` degrees of freedom.
      4. If that matrix looks implausible, register the full head as well and
         keep the more plausible of the two matrices. "Implausible" means the
         diagonal deviates from 1 by more than the limit, or the z
         translation exceeds 30.
      5. Apply the chosen matrix to the reoriented full-head image.

    Args:
        path_to_img: input image path.
        force_new_transform: recompute even if the output image already exists.
        dof: degrees of freedom passed to FLIRT.
        output_folder: directory for outputs and intermediate files.
        guid: identifier used in file names; derived from the input filename
            when empty.
        remove_tmp_files: delete the intermediate files once done (also on
            failure).

    Returns:
        None. Writes ``<guid>_mni_dof_<dof>.nii`` (aligned image) and
        ``<guid>_mni_dof_<dof>.mat`` (the applied matrix) to ``output_folder``.
    """
    from shutil import copyfile
    fsl = _get_fsl()
    guid = _get_guid(path_to_img, guid)
    stem = guid + '_mni_dof_' + str(dof)
    tmp_img_path = os.path.join(output_folder, guid + '_tmp.nii')
    tal_img_path = os.path.join(output_folder, stem + '.nii')
    bet_img_path = os.path.join(output_folder, guid + '_bet.nii')
    bet_img_path_cp = os.path.join(output_folder, guid + '_bet_cp.nii')
    xfm_path = os.path.join(output_folder, stem + '.mat')
    xfm_path_cp = os.path.join(output_folder, stem + '_cp.mat')
    xfm_path2 = os.path.join(output_folder, stem + '_2.mat')
    try:
        fsl_path = os.environ['FSLDIR']
    except KeyError:
        fsl_path = '/usr/local/fsl'
        print('FSLDIR not set. Using: ' + fsl_path)
    template_img = os.path.join(fsl_path, 'data', 'standard', 'MNI152_T1_1mm.nii.gz')

    if os.path.exists(tal_img_path) and not force_new_transform:
        return

    try:
        reorient = fsl.Reorient2Std()
        reorient.inputs.in_file = path_to_img
        reorient.inputs.out_file = tmp_img_path
        reorient.inputs.output_type = 'NIFTI'
        reorient.run()

        btr = fsl.BET()
        btr.inputs.in_file = tmp_img_path
        btr.inputs.frac = 0.7
        btr.inputs.out_file = bet_img_path
        btr.inputs.output_type = 'NIFTI'
        btr.inputs.robust = True
        btr.run()

        # First attempt: register the skull-stripped brain.
        flirt = fsl.FLIRT()
        flirt.inputs.in_file = bet_img_path
        flirt.inputs.reference = template_img
        flirt.inputs.out_file = bet_img_path
        flirt.inputs.output_type = 'NIFTI'
        flirt.inputs.dof = dof
        flirt.inputs.out_matrix_file = xfm_path
        flirt.run()
        matrix_1 = _read_flirt_matrix(xfm_path)
        dist_1 = _diag_distance(matrix_1)

        dist_lim = .01
        translate_lim = 30
        if dist_1 > dist_lim or matrix_1[2, 3] > translate_lim:
            # Heuristic fallback: keep the first result aside, then register
            # the full (non-skull-stripped) head into the same xfm_path.
            copyfile(bet_img_path, bet_img_path_cp)
            copyfile(xfm_path, xfm_path_cp)
            flirt.inputs.in_file = tmp_img_path
            flirt.run()
            matrix_2 = _read_flirt_matrix(xfm_path)
            dist_2 = _diag_distance(matrix_2)
            if dist_1 < dist_lim and dist_2 < dist_lim:
                # Both diagonals acceptable: prefer the smaller z translation.
                keep_first = matrix_1[2, 3] < matrix_2[2, 3]
            else:
                # Otherwise prefer the diagonal closer to 1.
                keep_first = dist_1 < dist_2
            if keep_first:
                # Restore the first matrix at the canonical path so the kept
                # .mat file matches the transform actually applied below.
                copyfile(xfm_path_cp, xfm_path)

        # Apply the chosen matrix to the reoriented full-head image.
        apply_xfm = fsl.FLIRT()
        apply_xfm.inputs.in_file = tmp_img_path
        apply_xfm.inputs.reference = template_img
        apply_xfm.inputs.out_file = tal_img_path
        apply_xfm.inputs.output_type = 'NIFTI'
        apply_xfm.inputs.in_matrix_file = xfm_path
        apply_xfm.inputs.apply_xfm = True
        apply_xfm.inputs.out_matrix_file = xfm_path2
        apply_xfm.run()
    finally:
        if remove_tmp_files:
            for tmp_path in (tmp_img_path, bet_img_path, xfm_path2, bet_img_path_cp, xfm_path_cp):
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)