#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Self-contained AVRA inference: only supply --input-file.

Uses relative path to weights (./weights).
Outputs: coronal image (no text) + JSON with MTA left, MTA right, PA, GCA-F.
"""
import argparse
import json
import os
import sys
import glob
import collections

import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')  # backend must be selected before pyplot is imported
import matplotlib.pyplot as plt

# Run from inference folder so imports resolve
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

from utils.load_transforms import load_transform
from utils.misc import (
    load_model,
    load_settings_from_model,
    load_mri,
    native_to_tal_python,
    native_to_tal_fsl,
)

# Weights path relative to this script
DEFAULT_MODEL_DIR = os.path.join(SCRIPT_DIR, 'weights')

# Accepted input extensions; the longer suffix comes first so that stripping
# 'x.nii.gz' yields 'x' rather than 'x.nii'.
_NIFTI_SUFFIXES = ('.nii.gz', '.nii')


def _strip_nifti_suffix(filename):
    """Return *filename* without a trailing '.nii.gz' or '.nii' (unchanged otherwise)."""
    for suffix in _NIFTI_SUFFIXES:
        if filename.endswith(suffix):
            return filename[:-len(suffix)]
    return filename


def _ensemble_mean(model_list, batch, device):
    """Return the average over *model_list* of each model's mean output on *batch*."""
    predictions = []
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for model in model_list:
            model.eval()
            model.to(device)
            prediction = model(batch).mean()
            predictions.append(prediction.detach().numpy().squeeze())
    return float(np.array(predictions).mean())


def _save_coronal(volume, path):
    """Save the slice ``volume[0, 10, 0]`` (rotated 90 degrees) as a borderless JPEG at *path*."""
    fig = plt.figure(figsize=(8, 8))
    try:
        plt.imshow(np.rot90(volume[0, 10, 0, :, :].numpy()), cmap='gray')
        plt.axis('off')
        fig.savefig(path, format='jpg', bbox_inches='tight', pad_inches=0)
    finally:
        # Close this specific figure so it is released even if saving fails.
        plt.close(fig)


def run_one(input_file, output_dir, uid=None, model_dir=None, registration=True, use_fsl=False, save_json=True):
    """
    Run AVRA on one scan. Saves {uid}_coronal.jpg and optionally {uid}.json.

    Args:
        input_file: Path to the input NIfTI scan.
        output_dir: Directory for all outputs; created if missing. An empty
            value falls back to the directory containing ``input_file``.
        uid: Identifier used in output file names. Defaults to the input file
            name without its '.nii' / '.nii.gz' suffix.
        model_dir: Directory holding one sub-folder of ``*.pth.tar`` weights
            per rating scale ('mta', 'pa', 'gca-f'). Defaults to
            ``DEFAULT_MODEL_DIR``.
        registration: Passed to the registration helper as
            ``force_new_transform``.
        use_fsl: Register with ``native_to_tal_fsl`` instead of
            ``native_to_tal_python``.
        save_json: Also write the ratings to ``{uid}.json`` in ``output_dir``.

    Returns:
        dict with keys: MTA_left, MTA_right, PA, GCA-F.

    Raises:
        FileNotFoundError: If a rating scale has no ``*.pth.tar`` weights.
    """
    model_dir = model_dir or DEFAULT_MODEL_DIR
    device = torch.device('cpu')
    if uid is None:
        uid = _strip_nifti_suffix(os.path.basename(input_file))
    output_dir = output_dir or os.path.dirname(os.path.abspath(input_file))
    os.makedirs(output_dir, exist_ok=True)

    dof = 6
    register = native_to_tal_fsl if use_fsl else native_to_tal_python
    register(input_file, force_new_transform=registration, dof=dof, output_folder=output_dir, guid=uid)
    # NOTE(review): assumes the registration helpers write their result under
    # this name -- confirm against utils.misc.
    tal_path = os.path.join(output_dir, uid + '_mni_dof_' + str(dof) + '.nii')

    # Minimal attribute bag standing in for the argparse namespace that the
    # utils helpers take as their ``args`` parameter.
    class _Args:
        pass

    args = _Args()
    args.model_dir = model_dir
    args.device = device
    args.uid = uid
    args.output_dir = output_dir

    rating_dict = collections.OrderedDict()
    for vrs in ('mta', 'pa', 'gca-f'):
        args.vrs = vrs
        img = load_mri(tal_path)
        weights_dir = os.path.join(args.model_dir, vrs)
        fnames = sorted(glob.glob(os.path.join(weights_dir, '*.pth.tar')))
        if not fnames:
            raise FileNotFoundError('No weights in %s.' % weights_dir)
        # The helper may rebind/alter ``args``; everything below deliberately
        # reads from the returned object.
        args = load_settings_from_model(fnames[0], args)
        _, transform_test, _ = load_transform(args)
        model_list = [load_model(ch, args) for ch in fnames]

        img = transform_test(img)
        if args.vrs == 'mta':
            # The second input is the same volume mirrored along axis 2; the
            # unflipped volume is rated as 'right', the mirrored one as 'left'.
            img_left = torch.from_numpy(np.flip(img.numpy().copy(), 2).copy())
            img = img.unsqueeze(0)
            img_left = img_left.unsqueeze(0)
            for side, batch in (('right', img), ('left', img_left)):
                rating_dict[args.vrs + '_' + side + '_mean'] = _ensemble_mean(model_list, batch, args.device)
            coronal_path = os.path.join(args.output_dir, args.uid + '_coronal.jpg')
            _save_coronal(img_left, coronal_path)
        else:
            img = img.unsqueeze(0)
            rating_dict[args.vrs + '_mean'] = _ensemble_mean(model_list, img, args.device)

    out = {
        'MTA_left': rating_dict['mta_left_mean'],
        'MTA_right': rating_dict['mta_right_mean'],
        'PA': rating_dict['pa_mean'],
        'GCA-F': rating_dict['gca-f_mean'],
    }
    if save_json:
        json_path = os.path.join(args.output_dir, args.uid + '.json')
        with open(json_path, 'w') as f:
            json.dump(out, f, indent=2)
    return out


def main():
    """Parse the command line, validate the input path, run AVRA and print the results.

    Raises:
        FileNotFoundError: If ``--input-file`` does not exist.
        ValueError: If the input file name does not end in '.nii' or '.nii.gz'.
    """
    parser = argparse.ArgumentParser(description='AVRA inference: supply input NIfTI only.')
    parser.add_argument('--input-file', required=True, help='Path to input T1 MRI (.nii or .nii.gz)')
    parser.add_argument('--output-dir', default='', help='Output directory (default: same as input file directory)')
    parser.add_argument('--no-new-registration', dest='registration', action='store_false',
                        help='Reuse existing AC-PC alignment if present')
    parser.add_argument('--use-fsl', action='store_true',
                        help='Use FSL for registration (default: Python-only with nibabel + SimpleITK + nilearn)')
    parser.set_defaults(registration=True)
    args = parser.parse_args()

    if not os.path.exists(args.input_file):
        raise FileNotFoundError('Input file not found: %s' % args.input_file)
    basename = os.path.basename(args.input_file)
    # Check the actual extension; a substring test would accept 'scan.nii.bak'.
    if not basename.endswith(_NIFTI_SUFFIXES):
        raise ValueError('Input must be .nii or .nii.gz')

    args.uid = _strip_nifti_suffix(basename)
    args.output_dir = args.output_dir or os.path.dirname(os.path.abspath(args.input_file))
    out = run_one(args.input_file, args.output_dir, uid=args.uid, registration=args.registration,
                  use_fsl=args.use_fsl, save_json=True)

    coronal_path = os.path.join(args.output_dir, args.uid + '_coronal.jpg')
    json_path = os.path.join(args.output_dir, args.uid + '.json')
    print('Saved:', coronal_path, json_path)
    print('Ratings:', out)


if __name__ == '__main__':
    main()