# -*- coding: utf-8 -*- """Misc. functions for AVRA inference pipeline.""" from __future__ import division import torch import numpy as np import os import tempfile from collections import OrderedDict from model.model import AVRA_rnn, VGG_bl import nibabel import glob # Optional: only needed for --use-fsl def _get_fsl(): import nipype.interfaces.fsl as fsl return fsl def load_mri(file): a = nibabel.load(file) a = nibabel.as_closest_canonical(a) a = np.array(a.dataobj, dtype=np.float32) return a def load_settings_from_model(path, args): checkpoint = torch.load(path, map_location=torch.device('cpu')) args.size_x, args.size_y, args.size_z = checkpoint['size_x'], checkpoint['size_y'], checkpoint['size_z'] args.arch = checkpoint['arch'] args.mse = checkpoint['mse'] if 'offset_x' in checkpoint.keys(): args.offset_x, args.offset_y, args.offset_z = checkpoint['offset_x'], checkpoint['offset_y'], checkpoint['offset_z'] args.nc = checkpoint['nc'] else: args.nc = 1 if args.vrs == 'mta': args.offset_x, args.offset_y, args.offset_z = 0, 0, 4 elif args.vrs == 'gca-f': args.offset_x, args.offset_y, args.offset_z = 0, 5, 14 elif args.vrs == 'pa': args.offset_x, args.offset_y, args.offset_z = 0, -25, 5 return args def load_model(path, args): rating_scale = args.vrs if rating_scale == 'mta': classes = [0, 1, 2, 3, 4] else: classes = [0, 1, 2, 3] checkpoint = torch.load(path, map_location='cpu') mse = checkpoint['mse'] if mse: output_dim = 1 else: output_dim = np.size(classes) try: d = checkpoint['depth'] h = checkpoint['width'] x, y, z = h, h, d except Exception: x = checkpoint['size_x'] y = checkpoint['size_y'] is_data_dp = False model = AVRA_rnn([x, y, 1]) new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items(): if k[:6] == 'module': is_data_dp = True name = k[7:] new_state_dict[name] = v if is_data_dp: model.load_state_dict(new_state_dict) else: model.load_state_dict(checkpoint['state_dict']) return model.to(args.device) def _get_guid(path_to_img, guid): if guid: return guid native_img = os.path.basename(path_to_img) if 'nii.gz' in native_img: return os.path.splitext(os.path.splitext(native_img)[0])[0] return os.path.splitext(native_img)[0] def native_to_tal_python(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True): """ AC-PC alignment using Python only: nibabel (reorient) + SimpleITK (rigid registration). No FSL required. Uses nilearn's ICBM152 2009 1mm T1 as reference. """ import SimpleITK as sitk from nilearn import datasets as nilearn_datasets guid = _get_guid(path_to_img, guid) tal_img = guid + '_mni_dof_' + str(dof) + '.nii' tal_img_path = os.path.join(output_folder, tal_img) if os.path.exists(tal_img_path) and not force_new_transform: return # 1) Reorient to canonical (RAS) with nibabel img_nib = nibabel.load(path_to_img) if img_nib.ndim == 4: data = np.asarray(img_nib.dataobj, dtype=np.float32)[..., 0] img_nib = nibabel.Nifti1Image(data, img_nib.affine) canonical = nibabel.as_closest_canonical(img_nib) fd, tmp_reorient = tempfile.mkstemp(suffix='.nii.gz', prefix='avra_reorient_') os.close(fd) try: nibabel.save(canonical, tmp_reorient) moving_sitk = sitk.ReadImage(tmp_reorient, sitk.sitkFloat32) finally: if remove_tmp_files and os.path.exists(tmp_reorient): os.remove(tmp_reorient) # 2) Load 1mm reference template (nilearn ICBM152 2009) template_bunch = nilearn_datasets.fetch_icbm152_2009(verbose=0) template_path = template_bunch['t1'] fixed_sitk = sitk.ReadImage(template_path, sitk.sitkFloat32) # 3) Rigid registration (6 DOF) with SimpleITK initial_tx = sitk.CenteredTransformInitializer( fixed_sitk, moving_sitk, sitk.Euler3DTransform(), sitk.CenteredTransformInitializerFilter.GEOMETRY, ) R = sitk.ImageRegistrationMethod() R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) R.SetOptimizerAsRegularStepGradientDescent( learningRate=1.0, minStep=1e-4, numberOfIterations=500, ) R.SetOptimizerScalesFromPhysicalShift() R.SetInitialTransform(initial_tx, inPlace=False) R.SetInterpolator(sitk.sitkLinear) tx = R.Execute(fixed_sitk, moving_sitk) # 4) Resample moving to fixed grid and save as NIfTI resampled = sitk.Resample(moving_sitk, fixed_sitk, tx, sitk.sitkLinear, 0.0) sitk.WriteImage(resampled, tal_img_path) def native_to_tal_fsl(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True): """AC-PC alignment using FSL (requires FSL installed and FSLDIR set).""" from shutil import copyfile fsl = _get_fsl() native_img = os.path.basename(path_to_img) guid = _get_guid(path_to_img, guid) tal_img = guid + '_mni_dof_' + str(dof) + '.nii' bet_img = guid + '_bet.nii' bet_img_cp = guid + '_bet_cp.nii' tmp_img = guid + '_tmp.nii' tmp_img_path = os.path.join(output_folder, tmp_img) tal_img_path = os.path.join(output_folder, tal_img) bet_img_path = os.path.join(output_folder, bet_img) bet_img_path_cp = os.path.join(output_folder, bet_img_cp) xfm_path = os.path.join(output_folder, guid + '_mni_dof_' + str(dof) + '.mat') xfm_path_cp = os.path.join(output_folder, guid + '_mni_dof_' + str(dof) + '_cp.mat') xfm_path2 = os.path.join(output_folder, guid + '_mni_dof_' + str(dof) + '_2.mat') try: fsl_path = os.environ['FSLDIR'] except KeyError: fsl_path = '/usr/local/fsl' print('FSLDIR not set. Using: ' + fsl_path) template_img = os.path.join(fsl_path, 'data', 'standard', 'MNI152_T1_1mm.nii.gz') tal_img_exist = os.path.exists(tal_img_path) xfm_exist = os.path.exists(xfm_path) fsl_1 = fsl.FLIRT() fsl_2 = fsl.FLIRT() fsl_pre = fsl.Reorient2Std() if not tal_img_exist or force_new_transform: fsl_pre.inputs.in_file = path_to_img fsl_pre.inputs.out_file = tmp_img_path fsl_pre.inputs.output_type = 'NIFTI' fsl_pre.run() btr = fsl.BET() btr.inputs.in_file = tmp_img_path btr.inputs.frac = 0.7 btr.inputs.out_file = bet_img_path btr.inputs.output_type = 'NIFTI' btr.inputs.robust = True btr.run() fsl_1.inputs.in_file = bet_img_path fsl_1.inputs.reference = template_img fsl_1.inputs.out_file = bet_img_path fsl_1.inputs.output_type = 'NIFTI' fsl_1.inputs.dof = dof fsl_1.inputs.out_matrix_file = xfm_path fsl_1.run() with open(xfm_path, 'r') as f: l = [[num for num in line.split(' ')] for line in f] matrix_1 = np.zeros((4, 4)) for m in range(4): for n in range(4): matrix_1[m, n] = float(l[m][n]) dist_1 = np.sum(np.square(np.diag(matrix_1) - 1)) dist_lim = .01 translate_lim = 30 if dist_1 > dist_lim or matrix_1[2, 3] > translate_lim: copyfile(bet_img_path, bet_img_path_cp) copyfile(xfm_path, xfm_path_cp) fsl_1.inputs.in_file = tmp_img_path fsl_1.run() with open(xfm_path, 'r') as f: l = [[num for num in line.split(' ')] for line in f] matrix_2 = np.zeros((4, 4)) for m in range(4): for n in range(4): matrix_2[m, n] = float(l[m][n]) dist_2 = np.sum(np.square(np.diag(matrix_2) - 1)) if (dist_1 < dist_lim and dist_2 < dist_lim): if matrix_1[2, 3] < matrix_2[2, 3]: xfm_path = xfm_path_cp elif dist_1 < dist_2: xfm_path = xfm_path_cp fsl_2.inputs.in_file = tmp_img_path fsl_2.inputs.reference = template_img fsl_2.inputs.out_file = tal_img_path fsl_2.inputs.output_type = 'NIFTI' fsl_2.inputs.in_matrix_file = xfm_path fsl_2.inputs.apply_xfm = True fsl_2.inputs.out_matrix_file = xfm_path2 fsl_2.run() if remove_tmp_files: for img in [tmp_img_path, bet_img_path, xfm_path2, bet_img_path_cp, xfm_path_cp]: if os.path.exists(img): os.remove(img)