# pyment-public/scripts/predict_from_fastsurfer_folder.py
# Commit 4f9da36 by estenhl: "Working on preprocess and predict container"
import argparse
import logging
import os
import re
from typing import List, Optional, Tuple

import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm

from pyment.models.sfcn import sfcn_factory
from pyment.preprocessing.conform import conform
# Root-logger configuration applied when this module is imported: DEBUG and
# above, each record prefixed with timestamp, level and logger name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)

# Module-scoped logger used by the functions in this script.
logger = logging.getLogger(__name__)
def _parse_folder_name(name: str) -> Tuple[str, str, str]:
match = re.fullmatch(r'sub-(.*)_ses-(.*)_run-([^_]).*', name)
if not match:
return None, None, None
return match.groups()
def _list_fastsurfer_folders(source: str) -> List[str]:
    """Return the names of the immediate subdirectories of ``source``.

    The names are sorted so the row order of the resulting predictions is
    deterministic; :func:`os.listdir` returns entries in arbitrary order.
    """
    return sorted(
        entry for entry in os.listdir(source)
        if os.path.isdir(os.path.join(source, entry))
    )


def _load_masked_image(folder_path: str) -> Optional[np.ndarray]:
    """Load ``mri/orig.mgz`` from a FastSurfer folder and apply ``mri/mask.mgz``.

    The two volumes are multiplied voxel-wise.

    Args:
        folder_path: Path to a single FastSurfer subject folder.

    Returns:
        The masked volume, or ``None`` (after logging the reason) when either
        file is missing or cannot be read, so the caller can skip the folder.
    """
    volumes = []
    for filename in ('orig.mgz', 'mask.mgz'):
        path = os.path.join(folder_path, 'mri', filename)
        if not os.path.isfile(path):
            logger.warning('No %s file for folder %s', filename, folder_path)
            return None
        try:
            # nibabel reads voxel data lazily, so errors from a truncated or
            # corrupt file may only surface in get_fdata(); keep it inside the
            # try so one bad folder does not abort the whole batch.
            volumes.append(nib.load(path).get_fdata())
        except Exception as e:
            logger.error(
                'Error loading %s for folder %s: %s', filename, folder_path, e
            )
            return None
    orig, brainmask = volumes
    return orig * brainmask


def predict_from_fastsurfer_folder(
    source: str,
    folders: Optional[List[str]] = None,
    weights: Optional[str] = None,
    model_name: str = 'sfcn-multi',
    targets: Optional[List[str]] = None,
    destination: Optional[str] = None
) -> pd.DataFrame:
    """Predict targets for every FastSurfer subject folder under ``source``.

    For each folder, ``mri/orig.mgz`` is multiplied by ``mri/mask.mgz``,
    passed through ``conform`` and fed to the model. Folders with missing or
    unreadable files are logged and skipped.

    Args:
        source: Directory containing FastSurfer subject folders.
        folders: Names of the subfolders of ``source`` to process. Defaults
            to every subdirectory of ``source``, in sorted order.
        weights: Weights passed to the model class returned by
            ``sfcn_factory``.
        model_name: Model name passed to ``sfcn_factory``.
        targets: Column name for each prediction output, in output order.
            Defaults to age, sex, handedness, bmi, fluid_intelligence and
            neuroticism. Model outputs beyond ``len(targets)`` are dropped.
        destination: Optional CSV path to write the results to. Must not
            already exist.

    Returns:
        One row per processed folder with columns ``source``, ``subject``,
        ``session``, ``run`` and one column per target.

    Raises:
        ValueError: If ``destination`` already exists, or if the model
            returns fewer outputs than there are ``targets``.
    """
    # None sentinel instead of a shared mutable list default.
    if targets is None:
        targets = [
            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
            'neuroticism'
        ]

    # Fail fast, before the expensive model load and inference loop. Checking
    # existence (not just isfile) also catches a directory at this path, which
    # would otherwise make to_csv fail only after all predictions are done.
    if destination is not None and os.path.exists(destination):
        raise ValueError(f'Destination {destination} already exists')

    logger.info('Loading multi-task model with weights %s', weights)
    model_class = sfcn_factory(model_name)
    model = model_class(weights=weights)

    logger.info('Reading fastsurfer folders from %s', source)
    if folders is None:
        folders = _list_fastsurfer_folders(source)

    results = []
    for folder in tqdm(folders):
        folder_path = os.path.join(source, folder)
        image = _load_masked_image(folder_path)
        if image is None:
            continue

        logger.debug('Conforming image from %s', folder_path)
        image = conform(image)
        predictions = model.predict(
            np.expand_dims(image, axis=0),
            verbose=0
        )[0]
        logger.debug('Predictions for %s: %s', folder, str(predictions))

        if len(predictions) < len(targets):
            raise ValueError(
                f'Model returned {len(predictions)} outputs for {folder_path} '
                f'but {len(targets)} targets were given: {list(targets)}'
            )

        subject, session, run = _parse_folder_name(folder)
        results.append({
            'source': folder_path,
            'subject': subject,
            'session': session,
            'run': run,
            # zip stops at the shorter sequence: outputs beyond the named
            # targets are dropped, as before.
            **dict(zip(targets, predictions)),
        })

    # Explicit columns keep the header (and column order) even when no folder
    # could be processed; dict.fromkeys mirrors the per-row dict semantics if
    # a target name collides with one of the metadata columns.
    columns = list(dict.fromkeys(['source', 'subject', 'session', 'run',
                                  *targets]))
    results = pd.DataFrame(results, columns=columns)
    if destination is not None:
        results.to_csv(destination, index=False)
    return results
if __name__ == '__main__':
    import sys

    # The text must go to ``description``: the first positional argument of
    # ArgumentParser is ``prog``, which would put this sentence in the usage
    # line in place of the program name.
    parser = argparse.ArgumentParser(
        description=(
            'Generates multi-task predictions for preprocessed images '
            'organized in a FastSurfer folder'
        )
    )
    parser.add_argument(
        'root',
        help=(
            'Path to FastSurfer folder. Should contain subfolders that have '
            'an \'mri\' subfolder that contains files orig.mgz and mask.mgz'
        )
    )
    parser.add_argument(
        '-w', '--weights',
        required=False,
        default='multi-2025',
        help=(
            'Weights to use. Should either point to a local file path, or a '
            'known identifier. If a local file path <path> is used, there '
            'should exist files named <path>.index and '
            '<path>.data-00000-of-00001'
        )
    )
    parser.add_argument(
        '-m', '--model',
        required=False,
        default='sfcn-multi',
        help=(
            'Name of the model to use'
        )
    )
    parser.add_argument(
        '-t', '--targets',
        required=False,
        nargs='+',
        default=[
            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
            'neuroticism'
        ],
        help='Name to use for each of the prediction heads in the output CSV'
    )
    parser.add_argument(
        '-f', '--folders',
        default=None,
        nargs='+',
        help=(
            'List of folders to process. If not provided, all folders in '
            'the source folder will be processed.'
        )
    )
    parser.add_argument(
        '-d', '--destination',
        required=False,
        default=None,
        help=(
            'Path where the CSV with predictions is written. If omitted, '
            'the CSV is printed to stdout'
        )
    )
    args = parser.parse_args()
    predictions = predict_from_fastsurfer_folder(
        source=args.root,
        folders=args.folders,
        model_name=args.model,
        weights=args.weights,
        targets=args.targets,
        destination=args.destination,
    )
    # Without --destination the predictions would otherwise be computed and
    # thrown away. Emit them on stdout; logging (basicConfig default) and
    # tqdm progress both write to stderr, so the CSV stays pipe-friendly.
    if args.destination is None:
        predictions.to_csv(sys.stdout, index=False)