import argparse
import logging
import os
import re
from typing import List, Optional

import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm

from pyment.models.sfcn import sfcn_factory
from pyment.preprocessing.conform import conform
# Root logging configuration for the script: every record is emitted at
# DEBUG verbosity with a timestamp, level and logger name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)

# Module logger used by the functions below.
logger = logging.getLogger(__name__)
| def _extract_run(filename: str) -> str: | |
| match = re.fullmatch( | |
| r'.*_run-([^_]+)(?:_.*)?(?:\.nii(?:\.gz)?|\.mgz)', | |
| filename | |
| ) | |
| if not match: | |
| logger.warning('Unable to extract run for filename %s', filename) | |
| return None | |
| return match.groups()[0] | |
| def _extract_modality(filename: str) -> str: | |
| match = re.fullmatch( | |
| r'.*_run-(?:[^_]+)(?:_(.*))?(?:\.nii(?:\.gz)?|\.mgz)', | |
| filename | |
| ) | |
| if not match: | |
| logger.warning('Unable to extract modality for filename %s', filename) | |
| return None | |
| return match.groups()[0] | |
# Default column names for the prediction heads, in the order in which the
# model output is indexed (head i -> targets[i]).
_DEFAULT_TARGETS = (
    'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
)

# Image files picked up from each anat folder. These are the extensions
# accepted by _extract_run and _extract_modality. Anything else (e.g. the
# JSON sidecars BIDS stores next to each image) is skipped.
_IMAGE_EXTENSIONS = ('.nii', '.nii.gz', '.mgz')


def _list_subdirectories(path: str) -> List[str]:
    """Return the sorted names of the directories directly inside ``path``."""
    return sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )


def predict_from_bids_folder(
    source: str,
    weights: str,
    model_name: str = 'sfcn-multi',
    targets: Optional[List[str]] = None,
    destination: Optional[str] = None,
    per_image_normalization: bool = False
) -> pd.DataFrame:
    """Run a multi-task SFCN model on every anatomical image in a BIDS folder.

    Images are expected at ``<source>/<subject>/<session>/anat/<image>``.
    Every ``.nii``, ``.nii.gz`` or ``.mgz`` file in each ``anat`` folder is
    loaded with nibabel, conformed, and passed through the model.
    Non-directory entries at the subject and session levels (e.g.
    ``dataset_description.json``, ``participants.tsv``) and non-image files
    inside ``anat`` are skipped. Subjects, sessions and files are processed
    in sorted order, so output rows are deterministic.

    Args:
        source: Root of the BIDS folder.
        weights: Forwarded as ``weights=`` to the model class returned by
            ``sfcn_factory(model_name)``.
        model_name: Name of the model to build via ``sfcn_factory``.
        targets: Column name for each prediction head, in output order.
            Defaults to age, sex, handedness, bmi, fluid_intelligence and
            neuroticism.
        destination: Optional path of a CSV file to write the results to.
            The path must not already exist.
        per_image_normalization: Forwarded to ``conform`` as
            ``relative_normalization``.

    Returns:
        One row per image with columns ``source`` (image path), ``subject``
        (``'sub-'`` removed), ``session`` (``'ses-'`` removed), ``run``,
        ``modality`` and one column per target.

    Raises:
        ValueError: If ``destination`` already exists.
    """
    if targets is None:
        targets = list(_DEFAULT_TARGETS)

    # Check before the slow model loading and inference, so an unwritable
    # destination (an existing file or directory) fails immediately.
    if destination is not None and os.path.exists(destination):
        raise ValueError(f'Destination {destination} already exists')

    logger.info('Loading multi-task model with weights %s', weights)
    model_class = sfcn_factory(model_name)
    model = model_class(weights=weights)

    rows = []

    for subject in tqdm(_list_subdirectories(source)):
        subject_folder = os.path.join(source, subject)

        for session in _list_subdirectories(subject_folder):
            anat_folder = os.path.join(subject_folder, session, 'anat')

            if not os.path.isdir(anat_folder):
                logger.warning(
                    'No anat-folder exists for subject %s and session %s',
                    subject, session
                )
                continue

            for filename in sorted(os.listdir(anat_folder)):
                if not filename.endswith(_IMAGE_EXTENSIONS):
                    logger.debug('Skipping non-image file %s', filename)
                    continue

                path = os.path.join(anat_folder, filename)
                run = _extract_run(filename)
                modality = _extract_modality(filename)

                logger.debug('Loading image %s', path)
                image = nib.load(path).get_fdata()

                logger.debug('Conforming image %s', path)
                image = conform(
                    image,
                    relative_normalization=per_image_normalization
                )

                # Add a batch axis of size 1. The trailing [0] selects the
                # entry for this single image. This assumes predict returns
                # an array with a leading batch axis; TODO confirm for
                # multi-output models.
                predictions = model.predict(np.expand_dims(image, axis=0))[0]
                logger.debug(
                    'Predictions for %s, %s: %s',
                    subject, session, predictions
                )

                row = {
                    'source': path,
                    'subject': subject.replace('sub-', ''),
                    'session': session.replace('ses-', ''),
                    'run': run,
                    'modality': modality,
                }
                # Head i is labelled targets[i]. Presumably an IndexError here
                # means more targets were given than the model has heads.
                row.update(
                    {target: predictions[i] for i, target in enumerate(targets)}
                )
                rows.append(row)

    results = pd.DataFrame(rows)

    if destination is not None:
        results.to_csv(destination, index=False)

    return results
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse arguments and run the predictions.

    When ``--destination`` is not given, the predictions are printed to
    stdout as CSV instead of being discarded.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` lets
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        # Must be a keyword argument: the first positional parameter of
        # ArgumentParser is ``prog``, which would put this sentence in the
        # usage line in place of the program name.
        description=(
            'Generates multi-task predictions for preprocessed images '
            'organized in a BIDS folder'
        )
    )
    parser.add_argument('bids', help='Path to BIDS folder')
    parser.add_argument(
        '-w', '--weights',
        required=True,
        help=(
            'Weights to use. Should either point to a local file path, or a '
            'known keyword. If a local file path <path> is used, there should '
            'exist files named <path>.index and <path>.data-00000-of-00001'
        )
    )
    parser.add_argument(
        '-m', '--model',
        required=False,
        default='sfcn-multi',
        help='Name of the model to use'
    )
    parser.add_argument(
        '-t', '--targets',
        required=False,
        nargs='+',
        default=[
            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
            'neuroticism'
        ],
        help='Name to use for each of the prediction heads in the output CSV'
    )
    parser.add_argument(
        '-d', '--destination',
        required=False,
        default=None,
        help='Path where CSV with predictions are written'
    )
    parser.add_argument(
        '-n', '--per_image_normalization',
        action='store_true',
        help=(
            'If set, the voxel values of each individual image is conformed '
            'by dividing by the image max instead of the constant 255'
        )
    )
    args = parser.parse_args(argv)

    results = predict_from_bids_folder(
        source=args.bids,
        weights=args.weights,
        model_name=args.model,
        targets=args.targets,
        destination=args.destination,
        per_image_normalization=args.per_image_normalization
    )

    # Without a destination the predictions would otherwise be lost, so
    # emit them as CSV on stdout. Logging and the tqdm progress bar go to
    # stderr by default.
    if args.destination is None:
        print(results.to_csv(index=False), end='')


if __name__ == '__main__':
    main()