File size: 5,130 Bytes
44facea 539bc34 55880f9 539bc34 44facea 55880f9 539bc34 4f9da36 539bc34 66269ec 4f9da36 55880f9 539bc34 66269ec 539bc34 55880f9 539bc34 4f9da36 539bc34 4f9da36 539bc34 44facea 539bc34 4f9da36 539bc34 66269ec 539bc34 b42b662 539bc34 66269ec 539bc34 55880f9 539bc34 44facea 539bc34 44facea 66269ec 539bc34 66269ec 539bc34 4f9da36 539bc34 55880f9 66269ec 55880f9 66269ec 55880f9 4f9da36 539bc34 44facea 539bc34 4f9da36 55880f9 539bc34 55880f9 4f9da36 44facea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import argparse
import logging
import os
import re
from typing import List, Optional, Tuple

import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm

from pyment.models.sfcn import sfcn_factory
from pyment.preprocessing.conform import conform
# Log line layout: timestamp, level, logger name, then the message.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s: %(message)s'

# NOTE(review): this configures the root logger (at DEBUG) as soon as the
# module is imported, which also affects any program importing it; consider
# moving this call into the `__main__` block.
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)

# Logger used by every function in this module.
logger = logging.getLogger(__name__)
def _parse_folder_name(name: str) -> Tuple[str, str, str]:
match = re.fullmatch(r'sub-(.*)_ses-(.*)_run-([^_]).*', name)
if not match:
return None, None, None
return match.groups()
# Prediction-head names used when the caller does not supply ``targets``.
# Kept as a tuple so the default can never be mutated between calls.
_DEFAULT_TARGETS = (
    'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
)


def _list_subfolders(source: str) -> List[str]:
    """Return the names of the immediate subdirectories of ``source``, sorted.

    Sorting makes the processing (and output row) order deterministic;
    ``os.listdir`` returns entries in arbitrary order.
    """
    return sorted(
        entry for entry in os.listdir(source)
        if os.path.isdir(os.path.join(source, entry))
    )


def _load_masked_image(source: str, folder: str) -> Optional[np.ndarray]:
    """Load ``mri/orig.mgz`` and ``mri/mask.mgz`` of one folder and multiply them.

    Args:
        source: Root directory containing the FastSurfer subject folders.
        folder: Name of the subject folder inside ``source``.

    Returns:
        The voxel-wise product ``orig * mask`` as a float array, or ``None``
        (after logging the reason) if either file is missing or cannot be
        loaded, so the caller can skip the folder.
    """
    path = os.path.join(source, folder)
    orig_path = os.path.join(path, 'mri', 'orig.mgz')
    mask_path = os.path.join(path, 'mri', 'mask.mgz')

    if not os.path.isfile(orig_path):
        logger.warning('No orig.mgz file for folder %s', path)
        return None
    # Check for the mask before reading any image data, so a folder without
    # a mask does not cost a full read of orig.mgz.
    if not os.path.isfile(mask_path):
        logger.warning('No mask.mgz file for folder %s', path)
        return None

    # Each folder is processed best-effort: an unreadable file skips that
    # folder instead of aborting the whole run. get_fdata() is inside the
    # try because reading the voxel data can fail as well.
    try:
        orig = nib.load(orig_path).get_fdata()
    except Exception as e:
        logger.error('Error loading orig for folder %s: %s', folder, e)
        return None
    try:
        brainmask = nib.load(mask_path).get_fdata()
    except Exception as e:
        logger.error('Error loading brainmask for folder %s: %s', folder, e)
        return None

    return orig * brainmask


def predict_from_fastsurfer_folder(
    source: str,
    folders: Optional[List[str]] = None,
    weights: Optional[str] = None,
    model_name: str = 'sfcn-multi',
    targets: Optional[List[str]] = None,
    destination: Optional[str] = None
) -> pd.DataFrame:
    """Run a multi-task SFCN model on every subject folder of a FastSurfer tree.

    For each folder, ``mri/orig.mgz`` is multiplied by ``mri/mask.mgz``,
    passed through ``conform`` and fed to the model. Folders missing either
    file, or whose files cannot be loaded, are logged and skipped.

    Args:
        source: Root directory containing the FastSurfer subject folders.
        folders: Names of the folders (inside ``source``) to process. When
            ``None``, every subdirectory of ``source`` is processed in
            sorted order.
        weights: Weights identifier or path, forwarded to the model class.
        model_name: Name passed to ``sfcn_factory`` to select the model.
        targets: Column name for each prediction head, in head order.
            Defaults to ``_DEFAULT_TARGETS``. May be shorter than the
            number of heads, in which case the extra heads are dropped.
        destination: Optional path of a CSV file to write the results to.

    Returns:
        One row per successfully processed folder, with columns ``source``,
        ``subject``, ``session``, ``run`` followed by one column per target.
        An empty result still carries these columns.

    Raises:
        ValueError: If ``destination`` already exists, or if more target
            names are given than the model produces predictions.
    """
    # Fail before loading the model. `exists` (rather than `isfile`) also
    # rejects an existing directory, which would otherwise only fail at the
    # final write after all the work has been done.
    if destination is not None and os.path.exists(destination):
        raise ValueError(f'Destination {destination} already exists')

    targets = list(_DEFAULT_TARGETS) if targets is None else list(targets)

    logger.info('Loading multi-task model with weights %s', weights)
    model_class = sfcn_factory(model_name)
    model = model_class(weights=weights)

    logger.info('Reading fastsurfer folders from %s', source)
    if folders is None:
        folders = _list_subfolders(source)

    # dict.fromkeys de-duplicates while preserving order, so the columns
    # match the keys of each row dict even if a target name repeats (or
    # shadows one of the metadata columns).
    columns = list(dict.fromkeys(['source', 'subject', 'session', 'run', *targets]))

    results = []
    for folder in tqdm(folders):
        image = _load_masked_image(source, folder)
        if image is None:
            continue

        path = os.path.join(source, folder)
        subject, session, run = _parse_folder_name(folder)

        logger.debug('Conforming image from %s', path)
        image = conform(image)
        predictions = model.predict(
            np.expand_dims(image, axis=0),
            verbose=0
        )[0]
        logger.debug('Predictions for %s: %s', folder, str(predictions))

        if len(predictions) < len(targets):
            raise ValueError(
                f'Got {len(targets)} target names but the model produced '
                f'only {len(predictions)} predictions'
            )

        row = {
            'source': path,
            'subject': subject,
            'session': session,
            'run': run,
        }
        # zip stops at the shorter sequence: heads without a target name
        # are dropped, as before.
        row.update(zip(targets, predictions))
        results.append(row)

    results = pd.DataFrame(results, columns=columns)

    if destination is not None:
        results.to_csv(destination, index=False)

    return results
if __name__ == '__main__':
    # `description` must be passed by keyword: the first positional
    # parameter of ArgumentParser is `prog`, so passing this sentence
    # positionally made it replace the program name in the usage line.
    parser = argparse.ArgumentParser(
        description=(
            'Generates multi-task predictions for preprocessed images '
            'organized in a FastSurfer folder'
        )
    )
    parser.add_argument(
        'root',
        help=(
            'Path to FastSurfer folder. Should contain subfolders that have '
            'an \'mri\' subfolder that contains files orig.mgz and mask.mgz'
        )
    )
    parser.add_argument(
        '-w', '--weights',
        required=False,
        default='multi-2025',
        help=(
            'Weights to use. Should either point to a local file path, or a '
            'known identifier. If a local file path <path> is used, there '
            'should exist files named <path>.index and '
            '<path>.data-00000-of-00001'
        )
    )
    parser.add_argument(
        '-m', '--model',
        required=False,
        default='sfcn-multi',
        help=(
            'Name of the model to use'
        )
    )
    parser.add_argument(
        '-t', '--targets',
        required=False,
        nargs='+',
        default=[
            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
            'neuroticism'
        ],
        help='Name to use for each of the prediction heads in the output CSV'
    )
    parser.add_argument(
        '-f', '--folders',
        default=None,
        nargs='+',
        help=(
            'List of folders to process. If not provided, all folders in '
            'the source folder will be processed.'
        )
    )
    parser.add_argument(
        '-d', '--destination',
        required=False,
        default=None,
        help=(
            'Path where CSV with predictions are written. If omitted, the '
            'CSV is printed to standard output instead'
        )
    )
    args = parser.parse_args()
    predictions = predict_from_fastsurfer_folder(
        source=args.root,
        folders=args.folders,
        model_name=args.model,
        weights=args.weights,
        targets=args.targets,
        destination=args.destination,
    )
    # Without a destination the results were previously computed and
    # discarded; emit them on stdout so the run is not wasted.
    if args.destination is None:
        print(predictions.to_csv(index=False), end='')
|