File size: 5,130 Bytes
44facea
539bc34
 
 
 
 
 
55880f9
539bc34
 
44facea
55880f9
539bc34
 
4f9da36
 
 
 
539bc34
 
 
 
 
 
 
 
 
 
 
66269ec
4f9da36
 
55880f9
 
 
 
539bc34
 
 
 
66269ec
539bc34
55880f9
 
 
539bc34
 
 
4f9da36
 
 
 
 
 
 
 
 
 
 
539bc34
 
 
 
 
4f9da36
 
 
 
539bc34
44facea
539bc34
 
 
 
 
4f9da36
 
 
 
 
 
 
 
 
 
539bc34
66269ec
539bc34
 
 
 
 
 
 
b42b662
 
 
 
539bc34
66269ec
539bc34
55880f9
 
 
 
 
 
 
539bc34
 
 
 
 
 
 
 
44facea
 
 
539bc34
 
44facea
 
 
66269ec
539bc34
 
 
 
 
 
 
66269ec
 
539bc34
 
4f9da36
 
 
539bc34
 
55880f9
66269ec
55880f9
 
 
 
 
 
 
 
 
 
 
66269ec
55880f9
 
 
 
4f9da36
 
 
 
 
 
 
 
 
539bc34
 
 
 
 
44facea
 
 
 
 
539bc34
4f9da36
55880f9
539bc34
55880f9
4f9da36
44facea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import argparse
import logging
import os
import re
from typing import List, Optional, Sequence, Tuple

import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm

from pyment.models.sfcn import sfcn_factory
from pyment.preprocessing.conform import conform

_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s: %(message)s'

# Configure the root logger at import time so that DEBUG-level messages from
# this script are emitted. basicConfig is a no-op if the root logger already
# has handlers attached.
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)

# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)

def _parse_folder_name(name: str) -> Tuple[str, str, str]:
    match = re.fullmatch(r'sub-(.*)_ses-(.*)_run-([^_]).*', name)

    if not match:
        return None, None, None

    return match.groups()

# Default names for the prediction heads, assigned positionally. A tuple is
# used so the shared default can never be mutated between calls.
_DEFAULT_TARGETS = (
    'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
)


def _load_masked_image(folder_path: str) -> Optional[np.ndarray]:
    """Return ``orig.mgz`` multiplied voxel-wise by ``mask.mgz``.

    Both files are read from ``<folder_path>/mri``. If either file is missing
    or cannot be loaded, a message is logged and ``None`` is returned so the
    caller can skip the folder instead of aborting the whole batch.
    """
    orig_path = os.path.join(folder_path, 'mri', 'orig.mgz')
    mask_path = os.path.join(folder_path, 'mri', 'mask.mgz')

    # Check both files before reading either, so a folder without a mask
    # does not cost a full read of orig.mgz.
    if not os.path.isfile(orig_path):
        logger.warning('No orig.mgz file for folder %s', folder_path)
        return None

    if not os.path.isfile(mask_path):
        logger.warning('No mask.mgz file for folder %s', folder_path)
        return None

    # nib.load is lazy, so get_fdata (where voxel data is actually read) must
    # also be inside the try block.
    try:
        orig = nib.load(orig_path).get_fdata()
    except Exception as e:
        logger.error('Error loading orig for folder %s: %s', folder_path, e)
        return None

    try:
        brainmask = nib.load(mask_path).get_fdata()
    except Exception as e:
        logger.error(
            'Error loading brainmask for folder %s: %s', folder_path, e
        )
        return None

    return orig * brainmask


def predict_from_fastsurfer_folder(
    source: str,
    folders: List[str] = None,
    weights: str = None,
    model_name: str = 'sfcn-multi',
    targets: Sequence[str] = _DEFAULT_TARGETS,
    destination: str = None
) -> pd.DataFrame:
    """Run a multi-task SFCN model on every FastSurfer folder under ``source``.

    For each folder, ``mri/orig.mgz`` is multiplied by ``mri/mask.mgz``,
    conformed, and passed through the model. Folders with missing or
    unreadable files are logged and skipped.

    Args:
        source: Root directory containing one subfolder per image.
        folders: Subfolder names to process. If ``None``, every directory
            directly under ``source`` is processed, in sorted order.
        weights: Weights identifier or local path passed to the model class.
        model_name: Name resolved through ``sfcn_factory``.
        targets: Column names assigned positionally to the model's
            predictions; only the first ``len(targets)`` predictions are kept.
        destination: Optional CSV path for the results. It must not exist.

    Returns:
        A DataFrame with columns ``source``, ``subject``, ``session``, ``run``
        and one column per target.

    Raises:
        ValueError: If ``destination`` already exists, or if ``targets`` names
            more outputs than the model produces.
    """
    # Fail fast, before the model load and the (slow) inference loop. Use
    # exists() rather than isfile() so a directory is also rejected up front
    # instead of making to_csv fail after all the work is done.
    if destination is not None and os.path.exists(destination):
        raise ValueError(f'Destination {destination} already exists')

    logger.info('Loading multi-task model with weights %s', weights)

    model_class = sfcn_factory(model_name)
    model = model_class(weights=weights)

    logger.info('Reading fastsurfer folders from %s', source)

    if folders is None:
        # os.listdir order is arbitrary; sort for reproducible output.
        folders = sorted(
            entry for entry in os.listdir(source)
            if os.path.isdir(os.path.join(source, entry))
        )

    results = []

    for folder in tqdm(folders):
        folder_path = os.path.join(source, folder)

        image = _load_masked_image(folder_path)

        if image is None:
            continue

        logger.debug('Conforming image from %s', folder_path)
        image = conform(image)

        # Add a batch axis of size 1 and take the single result back out.
        # NOTE(review): assumes predict(...)[0] yields one value per head.
        predictions = model.predict(
            np.expand_dims(image, axis=0),
            verbose=0
        )[0]
        logger.debug('Predictions for %s: %s', folder, predictions)

        if len(targets) > len(predictions):
            raise ValueError(
                f'Got {len(targets)} target names but the model produced '
                f'only {len(predictions)} predictions'
            )

        subject, session, run = _parse_folder_name(folder)

        row = {
            'source': folder_path,
            'subject': subject,
            'session': session,
            'run': run,
        }
        # zip stops at len(targets): extra prediction heads are dropped,
        # matching the original positional behaviour.
        row.update(zip(targets, predictions))
        results.append(row)

    results = pd.DataFrame(results)

    if destination is not None:
        results.to_csv(destination, index=False)

    return results

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'Generates multi-task predictions for preprocessed images organized '
        'in a FastSurfer folder'
    )

    parser.add_argument(
        'root',
        help=(
            'Path to FastSurfer folder. Should contain subfolders that have '
            'an \'mri\' subfolder that contains files orig.mgz and mask.mgz'
        )
    )
    parser.add_argument(
        '-w', '--weights',
        required=False,
        default='multi-2025',
        help=(
            'Weights to use. Should either point to a local file path, or a '
            'known identifier. If a local file path <path> is used, there '
            'should exist files named <path>.index and '
            '<path>.data-00000-of-00001'
        )
    )
    parser.add_argument(
        '-m', '--model',
        required=False,
        default='sfcn-multi',
        help=(
            'Name of the model to use'
        )
    )
    parser.add_argument(
        '-t', '--targets',
        required=False,
        nargs='+',
        default=[
            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
            'neuroticism'
        ],
        help='Name to use for each of the prediction heads in the output CSV'
    )
    parser.add_argument(
        '-f', '--folders',
        default=None,
        nargs='+',
        help=(
            'List of folders to process. If not provided, all folders in '
            'the source folder will be processed.'
        )
    )
    parser.add_argument(
        '-d', '--destination',
        required=False,
        default=None,
        help='Path where CSV with predictions are written'
    )

    args = parser.parse_args()

    predict_from_fastsurfer_folder(
        source=args.root,
        folders=args.folders,
        model_name=args.model,
        weights=args.weights,
        targets=args.targets,
        destination=args.destination,
    )