salmasoma
Set up inference-only HyperClinical Streamlit app with runtime HF asset download
278bf2b
# -*- coding: utf-8 -*-
"""Misc. functions for AVRA inference pipeline."""
from __future__ import division
import torch
import numpy as np
import os
import tempfile
from collections import OrderedDict
from model.model import AVRA_rnn, VGG_bl
import nibabel
import glob
# Optional: only needed for --use-fsl
def _get_fsl():
import nipype.interfaces.fsl as fsl
return fsl
def load_mri(file):
    """Load an MRI volume with nibabel and return its voxels as a float32 array.

    The image is first reoriented with ``nibabel.as_closest_canonical`` so the
    returned array axes follow the closest canonical orientation.
    """
    canonical_img = nibabel.as_closest_canonical(nibabel.load(file))
    return np.array(canonical_img.dataobj, dtype=np.float32)
def load_settings_from_model(path, args):
    """Copy model/input settings stored in a checkpoint onto ``args``.

    Always sets ``size_x/size_y/size_z``, ``arch`` and ``mse`` from the
    checkpoint. If the checkpoint stores crop offsets (``offset_x`` key), the
    offsets and ``nc`` are taken from it as well. Otherwise ``nc`` defaults to 1
    and the offsets fall back to hard-coded per-rating-scale values chosen by
    ``args.vrs`` ('mta', 'gca-f', 'pa'). For any other ``vrs`` the offsets are
    left unset.

    Returns the same (mutated) ``args`` object.
    """
    ckpt = torch.load(path, map_location=torch.device('cpu'))
    args.size_x, args.size_y, args.size_z = ckpt['size_x'], ckpt['size_y'], ckpt['size_z']
    args.arch = ckpt['arch']
    args.mse = ckpt['mse']

    if 'offset_x' in ckpt:
        args.offset_x, args.offset_y, args.offset_z = ckpt['offset_x'], ckpt['offset_y'], ckpt['offset_z']
        args.nc = ckpt['nc']
        return args

    # Checkpoint without stored offsets: use the built-in defaults per rating scale.
    args.nc = 1
    default_offsets = {
        'mta': (0, 0, 4),
        'gca-f': (0, 5, 14),
        'pa': (0, -25, 5),
    }
    if args.vrs in default_offsets:
        args.offset_x, args.offset_y, args.offset_z = default_offsets[args.vrs]
    return args
def load_model(path, args):
    """Build an ``AVRA_rnn`` model, load its weights from ``path`` and move it to ``args.device``.

    The in-plane input size comes from the checkpoint. If it stores both
    ``depth`` and ``width``, ``width`` is used for both x and y. Otherwise
    ``size_x``/``size_y`` are used.

    State-dict keys prefixed with ``'module.'`` (as saved from a
    DataParallel-wrapped model) have that prefix stripped before loading.
    """
    checkpoint = torch.load(path, map_location='cpu')

    if 'depth' in checkpoint and 'width' in checkpoint:
        x = y = checkpoint['width']
    else:
        x = checkpoint['size_x']
        y = checkpoint['size_y']

    # NOTE(review): the architecture is always AVRA_rnn and checkpoint['arch'] /
    # checkpoint['mse'] are not consulted here -- confirm the checkpoint matches.
    model = AVRA_rnn([x, y, 1])

    state_dict = checkpoint['state_dict']
    prefix = 'module.'
    if any(k.startswith(prefix) for k in state_dict):
        # Strip the DataParallel prefix only from keys that actually carry it.
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
    model.load_state_dict(state_dict)
    return model.to(args.device)
def _get_guid(path_to_img, guid):
if guid:
return guid
native_img = os.path.basename(path_to_img)
if 'nii.gz' in native_img:
return os.path.splitext(os.path.splitext(native_img)[0])[0]
return os.path.splitext(native_img)[0]
def native_to_tal_python(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True):
    """
    AC-PC alignment using Python only: nibabel (reorient) + SimpleITK (linear registration).
    No FSL required. Uses nilearn's ICBM152 2009 1mm T1 as reference.

    Parameters
    ----------
    path_to_img : str
        Input image in native space. For a 4D image only the first volume is used.
    force_new_transform : bool
        Re-register even if the output image already exists.
    dof : int
        Degrees of freedom of the linear transform: 6 (rigid, Euler angles),
        7 (rigid + isotropic scale) or 12 (affine). Any other value falls back
        to a rigid (6 DOF) transform with a warning. Also used in the output name.
    output_folder : str
        Folder that receives ``<guid>_mni_dof_<dof>.nii``.
    guid : str
        Output name prefix; derived from the input file name when empty.
    remove_tmp_files : bool
        Delete the intermediate reoriented image written to the system temp dir.

    Returns None. The registered image, resampled onto the template grid, is
    written to disk. Nothing is done if it already exists and
    ``force_new_transform`` is False.
    """
    import warnings

    import SimpleITK as sitk
    from nilearn import datasets as nilearn_datasets

    guid = _get_guid(path_to_img, guid)
    tal_img = guid + '_mni_dof_' + str(dof) + '.nii'
    tal_img_path = os.path.join(output_folder, tal_img)
    if os.path.exists(tal_img_path) and not force_new_transform:
        return

    # Map the requested degrees of freedom to a SimpleITK transform type.
    transform_factories = {
        6: sitk.Euler3DTransform,
        7: sitk.Similarity3DTransform,
        12: lambda: sitk.AffineTransform(3),
    }
    make_transform = transform_factories.get(dof)
    if make_transform is None:
        warnings.warn(
            'native_to_tal_python supports dof 6, 7 or 12; got %r, using rigid (6 DOF).' % (dof,)
        )
        make_transform = sitk.Euler3DTransform

    # 1) Reorient to canonical (RAS) with nibabel, then hand over to SimpleITK via a temp file.
    img_nib = nibabel.load(path_to_img)
    if img_nib.ndim == 4:
        data = np.asarray(img_nib.dataobj, dtype=np.float32)[..., 0]
        img_nib = nibabel.Nifti1Image(data, img_nib.affine)
    canonical = nibabel.as_closest_canonical(img_nib)
    fd, tmp_reorient = tempfile.mkstemp(suffix='.nii.gz', prefix='avra_reorient_')
    os.close(fd)
    try:
        nibabel.save(canonical, tmp_reorient)
        moving_sitk = sitk.ReadImage(tmp_reorient, sitk.sitkFloat32)
    finally:
        if remove_tmp_files and os.path.exists(tmp_reorient):
            os.remove(tmp_reorient)

    # 2) Load 1mm reference template (nilearn ICBM152 2009)
    template_bunch = nilearn_datasets.fetch_icbm152_2009(verbose=0)
    template_path = template_bunch['t1']
    fixed_sitk = sitk.ReadImage(template_path, sitk.sitkFloat32)

    # 3) Linear registration with SimpleITK, initialised by aligning image geometry centres.
    initial_tx = sitk.CenteredTransformInitializer(
        fixed_sitk, moving_sitk,
        make_transform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
    registration = sitk.ImageRegistrationMethod()
    registration.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration.SetOptimizerAsRegularStepGradientDescent(
        learningRate=1.0,
        minStep=1e-4,
        numberOfIterations=500,
    )
    registration.SetOptimizerScalesFromPhysicalShift()
    registration.SetInitialTransform(initial_tx, inPlace=False)
    registration.SetInterpolator(sitk.sitkLinear)
    tx = registration.Execute(fixed_sitk, moving_sitk)

    # 4) Resample moving to fixed grid (background 0) and save as NIfTI
    resampled = sitk.Resample(moving_sitk, fixed_sitk, tx, sitk.sitkLinear, 0.0)
    sitk.WriteImage(resampled, tal_img_path)
def native_to_tal_fsl(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True):
    """AC-PC alignment using FSL (requires FSL installed and FSLDIR set).

    Pipeline:
      1. Reorient2Std on the input image.
      2. BET (frac 0.7, robust).
      3. FLIRT of the brain-extracted image to ``MNI152_T1_1mm`` with ``dof``.
      4. If that matrix looks implausible (diagonal far from identity, or
         z-translation entry ``[2, 3]`` above 30), FLIRT is re-run on the
         whole-head image and the better of the two matrices is kept.
      5. The chosen matrix is applied to the reoriented whole-head image.

    Writes ``<guid>_mni_dof_<dof>.nii`` and the matching
    ``<guid>_mni_dof_<dof>.mat`` into ``output_folder`` and returns None. Does
    nothing if the output image exists and ``force_new_transform`` is False.
    Intermediate files are deleted when ``remove_tmp_files`` is True.
    """
    from shutil import copyfile
    fsl = _get_fsl()

    guid = _get_guid(path_to_img, guid)
    stem = guid + '_mni_dof_' + str(dof)
    tal_img_path = os.path.join(output_folder, stem + '.nii')
    xfm_path = os.path.join(output_folder, stem + '.mat')
    xfm_path_cp = os.path.join(output_folder, stem + '_cp.mat')
    xfm_path2 = os.path.join(output_folder, stem + '_2.mat')
    tmp_img_path = os.path.join(output_folder, guid + '_tmp.nii')
    bet_img_path = os.path.join(output_folder, guid + '_bet.nii')
    bet_img_path_cp = os.path.join(output_folder, guid + '_bet_cp.nii')

    try:
        fsl_path = os.environ['FSLDIR']
    except KeyError:
        fsl_path = '/usr/local/fsl'
        print('FSLDIR not set. Using: ' + fsl_path)
    template_img = os.path.join(fsl_path, 'data', 'standard', 'MNI152_T1_1mm.nii.gz')

    if os.path.exists(tal_img_path) and not force_new_transform:
        return

    # 1) Reorient to standard orientation.
    fsl_pre = fsl.Reorient2Std()
    fsl_pre.inputs.in_file = path_to_img
    fsl_pre.inputs.out_file = tmp_img_path
    fsl_pre.inputs.output_type = 'NIFTI'
    fsl_pre.run()

    # 2) Brain extraction.
    btr = fsl.BET()
    btr.inputs.in_file = tmp_img_path
    btr.inputs.frac = 0.7
    btr.inputs.out_file = bet_img_path
    btr.inputs.output_type = 'NIFTI'
    btr.inputs.robust = True
    btr.run()

    # 3) Estimate the transform from the brain-extracted image.
    fsl_1 = fsl.FLIRT()
    fsl_1.inputs.in_file = bet_img_path
    fsl_1.inputs.reference = template_img
    fsl_1.inputs.out_file = bet_img_path
    fsl_1.inputs.output_type = 'NIFTI'
    fsl_1.inputs.dof = dof
    fsl_1.inputs.out_matrix_file = xfm_path
    fsl_1.run()
    matrix_1 = _read_flirt_matrix(xfm_path)
    dist_1 = _diag_deviation(matrix_1)
    dist_lim = .01
    translate_lim = 30

    # 4) Implausible first estimate: retry on the whole-head image and keep the better one.
    if dist_1 > dist_lim or matrix_1[2, 3] > translate_lim:
        copyfile(bet_img_path, bet_img_path_cp)
        copyfile(xfm_path, xfm_path_cp)  # preserve the first matrix
        fsl_1.inputs.in_file = tmp_img_path
        fsl_1.run()  # overwrites xfm_path with the second matrix
        matrix_2 = _read_flirt_matrix(xfm_path)
        dist_2 = _diag_deviation(matrix_2)
        if dist_1 < dist_lim and dist_2 < dist_lim:
            # Both near-rigid: prefer the smaller z-translation.
            use_first = matrix_1[2, 3] < matrix_2[2, 3]
        else:
            use_first = dist_1 < dist_2
        if use_first:
            # Restore the first matrix so the kept .mat matches the applied transform.
            copyfile(xfm_path_cp, xfm_path)

    # 5) Apply the chosen transform to the reoriented whole-head image.
    fsl_2 = fsl.FLIRT()
    fsl_2.inputs.in_file = tmp_img_path
    fsl_2.inputs.reference = template_img
    fsl_2.inputs.out_file = tal_img_path
    fsl_2.inputs.output_type = 'NIFTI'
    fsl_2.inputs.in_matrix_file = xfm_path
    fsl_2.inputs.apply_xfm = True
    fsl_2.inputs.out_matrix_file = xfm_path2
    fsl_2.run()

    if remove_tmp_files:
        for tmp_path in (tmp_img_path, bet_img_path, xfm_path2, bet_img_path_cp, xfm_path_cp):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


def _read_flirt_matrix(path):
    """Read the 4x4 FLIRT matrix (whitespace-separated text) at ``path`` as a float array."""
    # np.loadtxt tolerates FLIRT's multi-space separators and trailing spaces.
    return np.loadtxt(path, ndmin=2)[:4, :4]


def _diag_deviation(matrix):
    """Return the sum of squared deviations of the matrix diagonal from 1 (0 for identity)."""
    return np.sum(np.square(np.diag(matrix) - 1))