File size: 8,671 Bytes
278bf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
# -*- coding: utf-8 -*-
"""Misc. functions for AVRA inference pipeline."""
from __future__ import division
import torch
import numpy as np
import os
import tempfile
from collections import OrderedDict
from model.model import AVRA_rnn, VGG_bl
import nibabel
import glob

# Optional: only needed for --use-fsl
def _get_fsl():
    import nipype.interfaces.fsl as fsl
    return fsl


def load_mri(file):
    """Load an image with nibabel and return its voxels as a float32 array.

    The image is first reoriented with ``nibabel.as_closest_canonical`` so the
    returned array is in the closest canonical (RAS+) voxel orientation.
    """
    canonical_img = nibabel.as_closest_canonical(nibabel.load(file))
    return np.array(canonical_img.dataobj, dtype=np.float32)


def load_settings_from_model(path, args):
    """Copy the input settings stored in a checkpoint onto ``args``.

    Always reads ``size_x``/``size_y``/``size_z``, ``arch`` and ``mse``.
    Checkpoints that contain ``offset_x`` also provide ``offset_y``,
    ``offset_z`` and ``nc``. Older checkpoints get ``nc = 1`` and fixed
    crop offsets chosen by ``args.vrs`` ('mta', 'gca-f' or 'pa'). For any
    other ``args.vrs`` the offsets on ``args`` are left untouched.

    Returns the same ``args`` object, mutated in place.
    """
    ckpt = torch.load(path, map_location=torch.device('cpu'))
    args.size_x, args.size_y, args.size_z = (
        ckpt[key] for key in ('size_x', 'size_y', 'size_z'))
    args.arch = ckpt['arch']
    args.mse = ckpt['mse']

    if 'offset_x' not in ckpt:
        # Legacy checkpoint: single channel and hard-coded per-scale offsets.
        args.nc = 1
        legacy_offsets = {
            'mta': (0, 0, 4),
            'gca-f': (0, 5, 14),
            'pa': (0, -25, 5),
        }.get(args.vrs)
        if legacy_offsets is not None:
            args.offset_x, args.offset_y, args.offset_z = legacy_offsets
        return args

    args.offset_x, args.offset_y, args.offset_z = (
        ckpt['offset_x'], ckpt['offset_y'], ckpt['offset_z'])
    args.nc = ckpt['nc']
    return args


def load_model(path, args):
    """Build an ``AVRA_rnn`` network and load its weights from a checkpoint.

    The in-plane input size comes from the checkpoint. If both ``depth`` and
    ``width`` are stored, ``width`` is used for x and y. Otherwise the size
    is read from ``size_x``/``size_y``. State-dict keys saved through
    ``torch.nn.DataParallel`` (prefixed with ``'module.'``) have that prefix
    removed before loading. Keys without the prefix are kept unchanged.

    Args:
        path: Checkpoint file produced by ``torch.save``; must contain
            ``'state_dict'``.
        args: Namespace providing ``device``, where the model is moved.

    Returns:
        The model with loaded weights, moved to ``args.device``.

    Raises:
        KeyError: if the checkpoint lacks ``'state_dict'``, or lacks both the
            ``depth``/``width`` pair and ``size_x``/``size_y``.
    """
    checkpoint = torch.load(path, map_location='cpu')

    if 'depth' in checkpoint and 'width' in checkpoint:
        # Square in-plane input; depth is not needed to build the network.
        x = y = checkpoint['width']
    else:
        x = checkpoint['size_x']
        y = checkpoint['size_y']

    model = AVRA_rnn([x, y, 1])

    state_dict = checkpoint['state_dict']
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        # Weights saved from a DataParallel wrapper: strip only the real
        # prefix instead of blindly dropping the first 7 characters.
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
    model.load_state_dict(state_dict)
    return model.to(args.device)


def _get_guid(path_to_img, guid):
    if guid:
        return guid
    native_img = os.path.basename(path_to_img)
    if 'nii.gz' in native_img:
        return os.path.splitext(os.path.splitext(native_img)[0])[0]
    return os.path.splitext(native_img)[0]


def native_to_tal_python(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True):
    """
    AC-PC alignment using Python only: nibabel (reorient) + SimpleITK (registration).
    No FSL required. Uses nilearn's ICBM152 2009 1mm T1 as reference.

    The input is reoriented to the closest canonical (RAS+) orientation. It
    is then registered to the template with Mattes mutual information,
    resampled onto the template grid, and written to
    ``<output_folder>/<guid>_mni_dof_<dof>.nii``. If that file already
    exists and ``force_new_transform`` is false, nothing is done.

    Args:
        path_to_img: Native-space image; for 4D input only the first volume
            is used.
        force_new_transform: Recompute even if the output file exists.
        dof: Registration degrees of freedom, following FLIRT's ``-dof``:
            6 (rigid), 7 (rigid + isotropic scale) or 12 (affine).
        output_folder: Directory the aligned image is written to.
        guid: Output file stem; derived from the input filename when empty.
        remove_tmp_files: Delete the intermediate reoriented image.

    Raises:
        ValueError: if ``dof`` is not 6, 7 or 12.
    """
    # Validate before the slow optional imports. Previously `dof` was ignored
    # and a rigid result was silently written under a `_dof_<dof>` name.
    dof_key = int(dof)
    if dof_key not in (6, 7, 12):
        raise ValueError(
            'native_to_tal_python supports dof in (6, 7, 12), got %r' % (dof,))

    import SimpleITK as sitk
    from nilearn import datasets as nilearn_datasets

    guid = _get_guid(path_to_img, guid)
    tal_img = guid + '_mni_dof_' + str(dof) + '.nii'
    tal_img_path = os.path.join(output_folder, tal_img)
    if os.path.exists(tal_img_path) and not force_new_transform:
        # Reuse the result of an earlier run.
        return

    # 1) Reorient to canonical (RAS) with nibabel; 4D input -> first volume.
    img_nib = nibabel.load(path_to_img)
    if img_nib.ndim == 4:
        data = np.asarray(img_nib.dataobj, dtype=np.float32)[..., 0]
        img_nib = nibabel.Nifti1Image(data, img_nib.affine)
    canonical = nibabel.as_closest_canonical(img_nib)
    # SimpleITK reads from disk here, so round-trip through a temp NIfTI file.
    fd, tmp_reorient = tempfile.mkstemp(suffix='.nii.gz', prefix='avra_reorient_')
    os.close(fd)
    try:
        nibabel.save(canonical, tmp_reorient)
        moving_sitk = sitk.ReadImage(tmp_reorient, sitk.sitkFloat32)
    finally:
        if remove_tmp_files and os.path.exists(tmp_reorient):
            os.remove(tmp_reorient)

    # 2) Load 1mm reference template (nilearn ICBM152 2009)
    template_bunch = nilearn_datasets.fetch_icbm152_2009(verbose=0)
    template_path = template_bunch['t1']
    fixed_sitk = sitk.ReadImage(template_path, sitk.sitkFloat32)

    # 3) Registration with the transform family matching `dof`.
    transform_factories = {
        6: sitk.Euler3DTransform,             # 3 rotations + 3 translations
        7: sitk.Similarity3DTransform,        # rigid + one isotropic scale
        12: lambda: sitk.AffineTransform(3),  # full 3D affine
    }
    initial_tx = sitk.CenteredTransformInitializer(
        fixed_sitk, moving_sitk,
        transform_factories[dof_key](),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    R.SetOptimizerAsRegularStepGradientDescent(
        learningRate=1.0,
        minStep=1e-4,
        numberOfIterations=500,
    )
    # Balance rotation vs. translation (vs. scale/shear) parameter steps.
    R.SetOptimizerScalesFromPhysicalShift()
    R.SetInitialTransform(initial_tx, inPlace=False)
    R.SetInterpolator(sitk.sitkLinear)
    tx = R.Execute(fixed_sitk, moving_sitk)

    # 4) Resample moving to fixed grid and save as NIfTI
    resampled = sitk.Resample(moving_sitk, fixed_sitk, tx, sitk.sitkLinear, 0.0)
    sitk.WriteImage(resampled, tal_img_path)


def native_to_tal_fsl(path_to_img, force_new_transform=False, dof=6, output_folder='/tmp', guid='', remove_tmp_files=True):
    """AC-PC alignment using FSL via nipype (requires FSL installed).

    Steps, all writing into ``output_folder``:
      1. ``Reorient2Std`` the input to ``<guid>_tmp.nii``.
      2. ``BET`` (frac=0.7, robust) the reoriented image to ``<guid>_bet.nii``.
      3. ``FLIRT`` the BET image to ``MNI152_T1_1mm`` with ``dof`` degrees of
         freedom, saving the matrix to ``<guid>_mni_dof_<dof>.mat``.
      4. If that matrix looks implausible (diagonal far from ones, or a
         large positive z-translation), re-run FLIRT on the non-BET image
         and keep whichever matrix the heuristic below prefers.
      5. Apply the chosen matrix to the reoriented (non-BET) image, writing
         ``<guid>_mni_dof_<dof>.nii``.

    The steps run only when that output image is missing or
    ``force_new_transform`` is true. Intermediate files are deleted when
    ``remove_tmp_files`` is true. If ``FSLDIR`` is unset, ``/usr/local/fsl``
    is assumed for the template location.
    """
    from shutil import copyfile
    fsl = _get_fsl()

    # NOTE(review): native_img is not used anywhere in the visible body.
    native_img = os.path.basename(path_to_img)
    guid = _get_guid(path_to_img, guid)
    tal_img = guid + '_mni_dof_' + str(dof) + '.nii'
    bet_img = guid + '_bet.nii'
    bet_img_cp = guid + '_bet_cp.nii'
    tmp_img = guid + '_tmp.nii'
    tmp_img_path = os.path.join(output_folder, tmp_img)
    tal_img_path = os.path.join(output_folder, tal_img)
    bet_img_path = os.path.join(output_folder, bet_img)
    bet_img_path_cp = os.path.join(output_folder, bet_img_cp)
    # Primary matrix, a backup of the first FLIRT result, and the (discarded)
    # matrix FLIRT writes while applying the transform.
    xfm_path = os.path.join(output_folder, guid + '_mni_dof_' + str(dof) + '.mat')
    xfm_path_cp = os.path.join(output_folder, guid + '_mni_dof_' + str(dof) + '_cp.mat')
    xfm_path2 = os.path.join(output_folder, guid + '_mni_dof_' + str(dof) + '_2.mat')
    try:
        fsl_path = os.environ['FSLDIR']
    except KeyError:
        fsl_path = '/usr/local/fsl'
        print('FSLDIR not set. Using: ' + fsl_path)
    template_img = os.path.join(fsl_path, 'data', 'standard', 'MNI152_T1_1mm.nii.gz')
    tal_img_exist = os.path.exists(tal_img_path)
    # NOTE(review): xfm_exist is computed but not used in the visible body.
    xfm_exist = os.path.exists(xfm_path)
    fsl_1 = fsl.FLIRT()  # estimates the registration matrix
    fsl_2 = fsl.FLIRT()  # applies the chosen matrix
    fsl_pre = fsl.Reorient2Std()

    if not tal_img_exist or force_new_transform:
        # 1) Reorient to the standard (MNI-like) axis order.
        fsl_pre.inputs.in_file = path_to_img
        fsl_pre.inputs.out_file = tmp_img_path
        fsl_pre.inputs.output_type = 'NIFTI'
        fsl_pre.run()
        # 2) Skull-strip the reoriented image.
        btr = fsl.BET()
        btr.inputs.in_file = tmp_img_path
        btr.inputs.frac = 0.7
        btr.inputs.out_file = bet_img_path
        btr.inputs.output_type = 'NIFTI'
        btr.inputs.robust = True
        btr.run()
        # 3) Register the BET image to the template; out_file equals in_file,
        # so the registered image overwrites the BET output in place.
        fsl_1.inputs.in_file = bet_img_path
        fsl_1.inputs.reference = template_img
        fsl_1.inputs.out_file = bet_img_path
        fsl_1.inputs.output_type = 'NIFTI'
        fsl_1.inputs.dof = dof
        fsl_1.inputs.out_matrix_file = xfm_path
        fsl_1.run()
        # Parse the 4x4 matrix. This assumes FLIRT's text layout of values
        # separated by two spaces; np.loadtxt would be more tolerant -- TODO(review).
        with open(xfm_path, 'r') as f:
            l = [[num for num in line.split('  ')] for line in f]
        matrix_1 = np.zeros((4, 4))
        for m in range(4):
            for n in range(4):
                matrix_1[m, n] = float(l[m][n])
        # Squared distance of the matrix diagonal from all-ones; large values
        # mean strong scaling/rotation, taken here as a suspicious result.
        dist_1 = np.sum(np.square(np.diag(matrix_1) - 1))
        dist_lim = .01
        translate_lim = 30
        # Only a large *positive* z-translation (matrix_1[2, 3]) triggers a retry.
        if dist_1 > dist_lim or matrix_1[2, 3] > translate_lim:
            # Keep the first attempt's image and matrix before they are overwritten.
            copyfile(bet_img_path, bet_img_path_cp)
            copyfile(xfm_path, xfm_path_cp)
            # 4) Retry on the non-skull-stripped image. out_file/out_matrix_file
            # are unchanged, so bet_img_path and xfm_path are overwritten.
            fsl_1.inputs.in_file = tmp_img_path
            fsl_1.run()
            with open(xfm_path, 'r') as f:
                l = [[num for num in line.split('  ')] for line in f]
            matrix_2 = np.zeros((4, 4))
            for m in range(4):
                for n in range(4):
                    matrix_2[m, n] = float(l[m][n])
            dist_2 = np.sum(np.square(np.diag(matrix_2) - 1))
            # Choose a matrix: xfm_path_cp holds matrix_1, xfm_path holds matrix_2.
            if (dist_1 < dist_lim and dist_2 < dist_lim):
                # Both diagonals plausible: prefer the smaller z-translation.
                if matrix_1[2, 3] < matrix_2[2, 3]:
                    xfm_path = xfm_path_cp
            elif dist_1 < dist_2:
                # Otherwise prefer the diagonal closer to ones (ties -> matrix_2).
                xfm_path = xfm_path_cp
        # 5) Apply the chosen matrix to the reoriented, non-skull-stripped image.
        fsl_2.inputs.in_file = tmp_img_path
        fsl_2.inputs.reference = template_img
        fsl_2.inputs.out_file = tal_img_path
        fsl_2.inputs.output_type = 'NIFTI'
        fsl_2.inputs.in_matrix_file = xfm_path
        fsl_2.inputs.apply_xfm = True
        fsl_2.inputs.out_matrix_file = xfm_path2
        fsl_2.run()
        # NOTE(review): if matrix_1 was chosen above, its _cp.mat copy is deleted
        # here, and the <guid>_mni_dof_<dof>.mat file left behind holds matrix_2,
        # not the transform actually applied to the output image.
        if remove_tmp_files:
            for img in [tmp_img_path, bet_img_path, xfm_path2, bet_img_path_cp, xfm_path_cp]:
                if os.path.exists(img):
                    os.remove(img)