Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
Self-contained AVRA inference: only --input-file is required.
Model weights are read from the ``weights`` folder next to this script (not from the current working directory). Outputs: a coronal image (no text overlay) and a JSON file with the MTA left, MTA right, PA and GCA-F ratings.
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import glob | |
| import collections | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
# Absolute directory of this script. Prepending it to sys.path lets the local
# ``utils`` package be imported regardless of the current working directory.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, SCRIPT_DIR)
| from utils.load_transforms import load_transform | |
| from utils.misc import load_model, load_settings_from_model, load_mri, native_to_tal_python, native_to_tal_fsl | |
# Default weights root: the ``weights`` folder next to this script. It is
# resolved from SCRIPT_DIR, so it does not depend on the current working
# directory.
DEFAULT_MODEL_DIR = os.path.join(SCRIPT_DIR, 'weights')
def _ensemble_mean(model_list, x):
    """Return the mean over *model_list* of each model's ``model(x).mean()``.

    Runs under ``torch.no_grad()`` so no autograd graph is recorded during
    inference. Per-model values are collected in a NumPy array (dtype taken
    from the model output) and the average is returned as a Python float.
    """
    with torch.no_grad():
        per_model = np.array([model(x).mean().cpu().numpy() for model in model_list])
    return float(per_model.mean())


def run_one(input_file, output_dir, uid=None, model_dir=None, registration=True,
            use_fsl=False, save_json=True):
    """Run the AVRA rating ensembles (MTA, PA, GCA-F) on one scan.

    The scan is passed to a registration helper (``native_to_tal_fsl`` or
    ``native_to_tal_python``) with ``dof=6``. The volume read back from
    ``{output_dir}/{uid}_mni_dof_6.nii`` is then scored by every checkpoint
    in ``{model_dir}/{rating}/*.pth.tar`` for the ratings ``mta``, ``pa``
    and ``gca-f``.

    Files written to *output_dir*:
      * ``{uid}_coronal.jpg``: one slice of the transformed MTA input.
      * ``{uid}.json``: the returned ratings (only when *save_json* is true).

    Args:
        input_file: Path to the input NIfTI (``.nii`` / ``.nii.gz``).
        output_dir: Output directory; created if it does not exist.
        uid: Basename for output files. Defaults to the input file name with
            a trailing ``.nii.gz`` or ``.nii`` removed.
        model_dir: Root folder of the weights; defaults to DEFAULT_MODEL_DIR.
        registration: Forwarded as ``force_new_transform`` to the
            registration helper.
        use_fsl: Use ``native_to_tal_fsl`` instead of ``native_to_tal_python``.
        save_json: Whether to write ``{uid}.json``.

    Returns:
        dict with float values for ``MTA_left``, ``MTA_right``, ``PA`` and
        ``GCA-F``, each the mean over that rating's ensemble members.

    Raises:
        FileNotFoundError: if a rating's weights folder contains no
            ``*.pth.tar`` checkpoint.
    """
    model_dir = model_dir or DEFAULT_MODEL_DIR
    device = torch.device('cpu')
    if uid is None:
        uid = os.path.basename(input_file)
        # Strip only a trailing extension; chained str.replace() would also
        # remove '.nii' occurring elsewhere in the file name.
        for ext in ('.nii.gz', '.nii'):
            if uid.endswith(ext):
                uid = uid[:-len(ext)]
                break
    # exist_ok avoids the race between a separate exists() check and makedirs().
    os.makedirs(output_dir, exist_ok=True)

    dof = 6
    register = native_to_tal_fsl if use_fsl else native_to_tal_python
    register(input_file, force_new_transform=registration, dof=dof,
             output_folder=output_dir, guid=uid)
    # Presumably written by the registration helper above; read back per rating.
    tal_path = os.path.join(output_dir, uid + '_mni_dof_' + str(dof) + '.nii')

    # Settings container for the project helpers. load_settings_from_model
    # returns the object to use afterwards, hence the re-binding in the loop.
    args = argparse.Namespace(model_dir=model_dir, device=device, uid=uid,
                              output_dir=output_dir)
    rating_dict = {}
    for vrs in ('mta', 'pa', 'gca-f'):
        args.vrs = vrs
        weights_dir = os.path.join(args.model_dir, vrs)
        fnames = sorted(glob.glob(os.path.join(weights_dir, '*.pth.tar')))
        if not fnames:
            raise FileNotFoundError('No weights in %s.' % weights_dir)
        # Settings (and thus the test transform) come from the first
        # checkpoint; every checkpoint is loaded as an ensemble member.
        args = load_settings_from_model(fnames[0], args)
        _, transform_test, _ = load_transform(args)
        model_list = [load_model(ch, args) for ch in fnames]
        for model in model_list:
            model.to(args.device)
            model.eval()
        img = transform_test(load_mri(tal_path))

        if args.vrs == 'mta':
            # The same models score both sides: the volume as-is is reported
            # as "right", a copy mirrored along axis 2 as "left".
            img_left = torch.flip(img, dims=[2])
            batches = {'right': img.unsqueeze(0), 'left': img_left.unsqueeze(0)}
            for side in ('right', 'left'):
                rating_dict[args.vrs + '_' + side + '_mean'] = _ensemble_mean(
                    model_list, batches[side].to(args.device))
            # The preview does not depend on the side, so render it only once.
            fig = plt.figure(figsize=(8, 8))
            plt.imshow(np.rot90(batches['left'][0, 10, 0, :, :].numpy()), cmap='gray')
            plt.axis('off')
            coronal_path = os.path.join(args.output_dir, args.uid + '_coronal.jpg')
            fig.savefig(coronal_path, format='jpg', bbox_inches='tight', pad_inches=0)
            plt.close(fig)
        else:
            rating_dict[args.vrs + '_mean'] = _ensemble_mean(
                model_list, img.unsqueeze(0).to(args.device))

    out = {
        'MTA_left': rating_dict['mta_left_mean'],
        'MTA_right': rating_dict['mta_right_mean'],
        'PA': rating_dict['pa_mean'],
        'GCA-F': rating_dict['gca-f_mean'],
    }
    if save_json:
        json_path = os.path.join(args.output_dir, args.uid + '.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(out, f, indent=2)
    return out
def main():
    """Command-line entry point: rate a single NIfTI scan with AVRA.

    Parses ``--input-file`` (required), ``--output-dir`` (defaults to the
    input file's directory), ``--no-new-registration`` and ``--use-fsl``.
    It then validates the input path, calls run_one and prints the output
    paths and ratings.

    Raises:
        FileNotFoundError: if ``--input-file`` is not an existing file.
        ValueError: if the input file name does not end in ``.nii`` or
            ``.nii.gz``.
    """
    parser = argparse.ArgumentParser(description='AVRA inference: supply input NIfTI only.')
    parser.add_argument('--input-file', required=True, help='Path to input T1 MRI (.nii or .nii.gz)')
    parser.add_argument('--output-dir', default='', help='Output directory (default: same as input file directory)')
    # store_false defaults to True: a new registration is computed unless
    # --no-new-registration is passed.
    parser.add_argument('--no-new-registration', dest='registration', action='store_false',
                        help='Reuse existing AC-PC alignment if present')
    parser.add_argument('--use-fsl', action='store_true',
                        help='Use FSL for registration (default: Python-only with nibabel + SimpleITK + nilearn)')
    args = parser.parse_args()

    if not os.path.isfile(args.input_file):
        raise FileNotFoundError('Input file not found: %s' % args.input_file)
    basename = os.path.basename(args.input_file)
    # Require a real NIfTI suffix (a substring test accepted e.g. 'x.nii.bak')
    # and strip only that suffix to form the output basename.
    for ext in ('.nii.gz', '.nii'):
        if basename.endswith(ext):
            args.uid = basename[:-len(ext)]
            break
    else:
        raise ValueError('Input must be .nii or .nii.gz')
    args.output_dir = args.output_dir or os.path.dirname(os.path.abspath(args.input_file))

    out = run_one(args.input_file, args.output_dir, uid=args.uid,
                  registration=args.registration, use_fsl=args.use_fsl, save_json=True)
    coronal_path = os.path.join(args.output_dir, args.uid + '_coronal.jpg')
    json_path = os.path.join(args.output_dir, args.uid + '.json')
    print('Saved:', coronal_path, json_path)
    print('Ratings:', out)
# Run the CLI only when this file is executed as a script. Importing the
# module (e.g. to call run_one directly) does not parse sys.argv or run
# inference.
if __name__ == '__main__':
    main()