salmasoma
Set up inference-only HyperClinical Streamlit app with runtime HF asset download
278bf2b
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Self-contained AVRA inference: only supply --input-file.
Uses relative path to weights (./weights). Outputs: coronal image (no text) + JSON with MTA left, MTA right, PA, GCA-F.
"""
import argparse
import json
import os
import sys
import glob
import collections
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Put this script's own directory first on sys.path so the `utils` imports
# below resolve no matter what the current working directory is.
# This must run before those imports.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
from utils.load_transforms import load_transform
from utils.misc import load_model, load_settings_from_model, load_mri, native_to_tal_python, native_to_tal_fsl
# Default checkpoint root used by run_one(): the "weights" folder that sits
# beside this script (independent of the current working directory).
DEFAULT_MODEL_DIR = os.path.join(SCRIPT_DIR, "weights")
def _strip_nii_suffix(filename):
    """Return *filename* without a trailing '.nii.gz' or '.nii' extension (only at the end)."""
    for suffix in ('.nii.gz', '.nii'):
        if filename.endswith(suffix):
            return filename[:-len(suffix)]
    return filename


def _ensemble_mean(models, x):
    """Run every model on *x*, reduce each output to its mean, and average over the ensemble."""
    scores = np.array([model(x).mean().detach().numpy().squeeze() for model in models])
    return float(scores.mean())


def _save_coronal(volume, path):
    """Save one 2-D slice of *volume* as a text-free grayscale JPEG at *path*.

    Takes index 10 along axis 1 of ``volume[0]`` (assumed to be a coronal slice -- TODO
    confirm against load_transform) and rotates it 90 degrees for display.
    """
    fig = plt.figure(figsize=(8, 8))
    try:
        plt.imshow(np.rot90(volume[0, 10, 0, :, :].numpy()), cmap='gray')
        plt.axis('off')
        fig.savefig(path, format='jpg', bbox_inches='tight', pad_inches=0)
    finally:
        # Close this specific figure even if saving fails, so repeated calls don't leak figures.
        plt.close(fig)


def run_one(input_file, output_dir, uid=None, model_dir=None, registration=True, use_fsl=False, save_json=True):
    """
    Run AVRA on one scan. Saves {uid}_coronal.jpg and optionally {uid}.json in output_dir.

    The scan is first registered with a 6-DOF transform (FSL or the Python-only
    pipeline). The registered volume is read from ``{uid}_mni_dof_6.nii`` in
    output_dir. Then the MTA, PA and GCA-F checkpoint ensembles are run on it.

    Args:
        input_file: path to the input T1 MRI (.nii or .nii.gz).
        output_dir: directory for the registration output, coronal image and JSON;
            created if it does not exist.
        uid: basename for the output files. Defaults to the input filename without its
            trailing .nii/.nii.gz extension.
        model_dir: directory holding one sub-folder of ``*.pth.tar`` checkpoints per
            rating ('mta', 'pa', 'gca-f'). Defaults to DEFAULT_MODEL_DIR.
        registration: forwarded as ``force_new_transform`` to the registration helper.
        use_fsl: register with FSL instead of the Python-only pipeline.
        save_json: also write the returned ratings to {uid}.json.

    Returns:
        dict with keys MTA_left, MTA_right, PA, GCA-F. Each value is a float: the mean
        model output over that rating's checkpoint ensemble.

    Raises:
        FileNotFoundError: if a rating sub-folder of model_dir contains no checkpoints.
    """
    model_dir = model_dir or DEFAULT_MODEL_DIR
    device = torch.device('cpu')
    if uid is None:
        uid = _strip_nii_suffix(os.path.basename(input_file))
    # exist_ok avoids the race between a separate exists() check and the creation.
    os.makedirs(output_dir, exist_ok=True)

    dof = 6
    register = native_to_tal_fsl if use_fsl else native_to_tal_python
    register(input_file, force_new_transform=registration, dof=dof,
             output_folder=output_dir, guid=uid)
    # Path the registration helper is expected to write the registered volume to.
    tal_path = os.path.join(output_dir, uid + '_mni_dof_' + str(dof) + '.nii')

    # Settings object consumed (and possibly replaced) by load_settings_from_model.
    args = argparse.Namespace(model_dir=model_dir, device=device, uid=uid, output_dir=output_dir)
    coronal_path = os.path.join(output_dir, uid + '_coronal.jpg')
    ratings = {}
    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for vrs in ('mta', 'pa', 'gca-f'):
            args.vrs = vrs
            weights_dir = os.path.join(model_dir, vrs)
            fnames = sorted(glob.glob(os.path.join(weights_dir, '*.pth.tar')))
            if not fnames:
                raise FileNotFoundError('No weights in %s.' % weights_dir)
            # The first checkpoint supplies the settings used to build the test transform.
            args = load_settings_from_model(fnames[0], args)
            _, transform_test, _ = load_transform(args)
            models = [load_model(fname, args) for fname in fnames]
            for model in models:
                model.to(device)  # inputs stay on CPU, so keep every model there too
                model.eval()
            # Reloaded per rating as before; hoisting it is only safe if transform_test
            # never mutates its input (unverified).
            img = transform_test(load_mri(tal_path))
            if vrs == 'mta':
                # The unflipped volume gives the right-side rating; the copy flipped
                # along axis 2 gives the left-side rating.
                img_left = torch.from_numpy(np.flip(img.numpy(), 2).copy()).unsqueeze(0)
                img = img.unsqueeze(0)
                ratings['mta_right_mean'] = _ensemble_mean(models, img)
                ratings['mta_left_mean'] = _ensemble_mean(models, img_left)
                _save_coronal(img_left, coronal_path)
            else:
                ratings[vrs + '_mean'] = _ensemble_mean(models, img.unsqueeze(0))

    out = {
        'MTA_left': ratings['mta_left_mean'],
        'MTA_right': ratings['mta_right_mean'],
        'PA': ratings['pa_mean'],
        'GCA-F': ratings['gca-f_mean'],
    }
    if save_json:
        # Use the caller's output_dir/uid, not fields of the checkpoint-reloaded args.
        json_path = os.path.join(output_dir, uid + '.json')
        with open(json_path, 'w') as f:
            json.dump(out, f, indent=2)
    return out
def main():
    """
    Command-line entry point: validate the input NIfTI, run AVRA on it, and print the outputs.

    Raises:
        FileNotFoundError: if --input-file does not exist or is not a regular file.
        ValueError: if --input-file does not end in .nii or .nii.gz.
    """
    parser = argparse.ArgumentParser(description='AVRA inference: supply input NIfTI only.')
    parser.add_argument('--input-file', required=True, help='Path to input T1 MRI (.nii or .nii.gz)')
    parser.add_argument('--output-dir', default='', help='Output directory (default: same as input file directory)')
    parser.add_argument('--model-dir', default=None,
                        help='Directory with mta/, pa/ and gca-f/ checkpoint folders '
                             '(default: the weights folder next to this script)')
    # store_false already defaults `registration` to True.
    parser.add_argument('--no-new-registration', dest='registration', action='store_false',
                        help='Reuse existing AC-PC alignment if present')
    parser.add_argument('--use-fsl', action='store_true',
                        help='Use FSL for registration (default: Python-only with nibabel + SimpleITK + nilearn)')
    args = parser.parse_args()

    if not os.path.isfile(args.input_file):
        raise FileNotFoundError('Input file not found: %s' % args.input_file)
    basename = os.path.basename(args.input_file)
    # Require, and strip, a *trailing* extension; a mere '.nii' substring is not enough.
    for suffix in ('.nii.gz', '.nii'):
        if basename.endswith(suffix):
            args.uid = basename[:-len(suffix)]
            break
    else:
        raise ValueError('Input must be .nii or .nii.gz')
    args.output_dir = args.output_dir or os.path.dirname(os.path.abspath(args.input_file))

    out = run_one(args.input_file, args.output_dir, uid=args.uid, model_dir=args.model_dir,
                  registration=args.registration, use_fsl=args.use_fsl, save_json=True)
    coronal_path = os.path.join(args.output_dir, args.uid + '_coronal.jpg')
    json_path = os.path.join(args.output_dir, args.uid + '.json')
    print('Saved:', coronal_path, json_path)
    print('Ratings:', out)
# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()