import argparse
import logging
import os
import re
from typing import List, Optional, Tuple

import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm

from pyment.models.sfcn import sfcn_factory
from pyment.preprocessing.conform import conform
# Root-logger configuration applied as soon as this module is loaded, plus the
# module-level logger used by the functions below.
# NOTE(review): basicConfig at import time also configures logging for any
# program that imports this module as a library; consider moving it under the
# ``__main__`` guard.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
| def _parse_folder_name(name: str) -> Tuple[str, str, str]: | |
| match = re.fullmatch(r'sub-(.*)_ses-(.*)_run-([^_]).*', name) | |
| if not match: | |
| return None, None, None | |
| return match.groups() | |
# Prediction-head names used when the caller does not supply ``targets``.
# A tuple is used so the default can never be mutated between calls.
_DEFAULT_TARGETS = (
    'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
)

# Metadata columns written before the per-target prediction columns.
_METADATA_COLUMNS = ['source', 'subject', 'session', 'run']


def _load_masked_image(folder_path: str) -> Optional[np.ndarray]:
    """Return ``mri/orig.mgz`` multiplied by ``mri/mask.mgz`` for one folder.

    Returns ``None`` (after logging the reason) when either file is missing
    or cannot be read, so the caller can skip the folder and keep going.
    """
    orig_path = os.path.join(folder_path, 'mri', 'orig.mgz')
    mask_path = os.path.join(folder_path, 'mri', 'mask.mgz')

    # Check that both inputs exist before reading either, so a folder without
    # a mask does not cost a full volume load.
    if not os.path.isfile(orig_path):
        logger.warning('No orig.mgz file for folder %s', folder_path)
        return None
    if not os.path.isfile(mask_path):
        logger.warning('No mask.mgz file for folder %s', folder_path)
        return None

    # Best-effort batch processing: an unreadable volume skips this folder
    # instead of aborting the whole run. The same policy applies to both
    # files.
    try:
        orig = nib.load(orig_path).get_fdata()
    except Exception as e:
        logger.error('Error loading orig for folder %s: %s', folder_path, e)
        return None
    try:
        brainmask = nib.load(mask_path).get_fdata()
    except Exception as e:
        logger.error(
            'Error loading brainmask for folder %s: %s', folder_path, e
        )
        return None

    return orig * brainmask


def predict_from_fastsurfer_folder(
    source: str,
    folders: Optional[List[str]] = None,
    weights: Optional[str] = None,
    model_name: str = 'sfcn-multi',
    targets: Optional[List[str]] = None,
    destination: Optional[str] = None
) -> pd.DataFrame:
    """Run a multi-task SFCN model on every FastSurfer folder under ``source``.

    For each folder, ``mri/orig.mgz`` is multiplied by ``mri/mask.mgz``,
    passed through ``conform`` and fed to the model. Folders with missing or
    unreadable inputs are logged and skipped.

    Args:
        source: Directory containing one FastSurfer output folder per scan.
        folders: Names of the folders (relative to ``source``) to process.
            Defaults to every sub-directory of ``source``, in sorted order.
        weights: Passed as ``weights=`` to the model class.
        model_name: Identifier passed to ``sfcn_factory``.
        targets: Column name for each prediction head, in head order.
            Defaults to age, sex, handedness, bmi, fluid_intelligence and
            neuroticism.
        destination: Optional CSV path for the results. It must not already
            exist as a file.

    Returns:
        A DataFrame with one row per processed folder and the columns
        ``source``, ``subject``, ``session``, ``run`` followed by one column
        per target. If no folder could be processed, it is empty but still
        has these columns.

    Raises:
        ValueError: If ``destination`` already exists as a file. This is
            checked before the model is loaded.
    """
    if destination is not None and os.path.isfile(destination):
        raise ValueError(f'Destination {destination} already exists')

    targets = list(_DEFAULT_TARGETS if targets is None else targets)

    logger.info('Loading multi-task model with weights %s', weights)
    model_class = sfcn_factory(model_name)
    model = model_class(weights=weights)

    logger.info('Reading fastsurfer folders from %s', source)
    if folders is None:
        # os.listdir order is arbitrary; sorting makes the output reproducible.
        folders = sorted(
            folder for folder in os.listdir(source)
            if os.path.isdir(os.path.join(source, folder))
        )

    rows = []
    for folder in tqdm(folders):
        folder_path = os.path.join(source, folder)
        image = _load_masked_image(folder_path)
        if image is None:
            continue

        logger.debug('Conforming image from %s', folder_path)
        image = conform(image)
        predictions = model.predict(
            np.expand_dims(image, axis=0),
            verbose=0
        )[0]
        logger.debug('Predictions for %s: %s', folder, predictions)

        # zip() pairs names with heads. Warn instead of raising IndexError
        # (too many targets) or silently dropping heads (too few).
        if len(predictions) != len(targets):
            logger.warning(
                'Model produced %d predictions for %s but %d targets were '
                'given; only the first %d are kept',
                len(predictions), folder, len(targets),
                min(len(predictions), len(targets))
            )

        subject, session, run = _parse_folder_name(folder)
        rows.append({
            'source': folder_path,
            'subject': subject,
            'session': session,
            'run': run,
            **dict(zip(targets, predictions)),
        })

    # An empty list of dicts gives a DataFrame with no columns; spell them
    # out so an empty result still writes a CSV header.
    results = (
        pd.DataFrame(rows) if rows
        else pd.DataFrame(columns=_METADATA_COLUMNS + targets)
    )

    if destination is not None:
        results.to_csv(destination, index=False)

    return results
if __name__ == '__main__':
    # Command-line entry point: parse arguments and forward them to
    # predict_from_fastsurfer_folder.
    #
    # The first positional parameter of ArgumentParser is ``prog`` (the
    # program name in usage lines). The summary sentence must therefore be
    # passed as ``description`` so it appears under --help.
    parser = argparse.ArgumentParser(
        description=(
            'Generates multi-task predictions for preprocessed images organized '
            'in a FastSurfer folder'
        )
    )
    parser.add_argument(
        'root',
        help=(
            'Path to FastSurfer folder. Should contain subfolders that have '
            'an \'mri\' subfolder that contains files orig.mgz and mask.mgz'
        )
    )
    parser.add_argument(
        '-w', '--weights',
        required=False,
        default='multi-2025',
        help=(
            'Weights to use. Should either point to a local file path, or a '
            'known identifier. If a local file path <path> is used, there '
            'should exist files named <path>.index and '
            '<path>.data-00000-of-00001'
        )
    )
    parser.add_argument(
        '-m', '--model',
        required=False,
        default='sfcn-multi',
        help=(
            'Name of the model to use'
        )
    )
    parser.add_argument(
        '-t', '--targets',
        required=False,
        nargs='+',
        default=[
            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
            'neuroticism'
        ],
        help='Name to use for each of the prediction heads in the output CSV'
    )
    parser.add_argument(
        '-f', '--folders',
        default=None,
        nargs='+',
        help=(
            'List of folders to process. If not provided, all folders in '
            'the source folder will be processed.'
        )
    )
    parser.add_argument(
        '-d', '--destination',
        required=False,
        default=None,
        help='Path where CSV with predictions are written'
    )

    args = parser.parse_args()

    predict_from_fastsurfer_folder(
        source=args.root,
        folders=args.folders,
        model_name=args.model,
        weights=args.weights,
        targets=args.targets,
        destination=args.destination,
    )