|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
import torch
|
|
|
from model import BiomedClipClassifier, predict_from_paths
|
|
|
|
|
|
def main(argv=None):
    """Run single-case inference with ``BiomedClipClassifier`` from the command line.

    Parses the CLI arguments and loads the model weights. The model goes on CUDA
    when ``torch.cuda.is_available()`` is true, otherwise on the CPU. It then runs
    ``predict_from_paths`` on the given MRI path and clinical text. Finally it
    prints the prediction and the per-class probabilities, labelled
    ``[CN, MCI, Dementia]``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Lets callers and tests invoke ``main`` without
            touching the real command line. ``None`` (the default) reads the
            actual command line, as before.
    """
    import os  # local import keeps this script entry point self-contained

    parser = argparse.ArgumentParser(
        description="Predict CN / MCI / Dementia from an MRI volume and clinical text."
    )
    parser.add_argument(
        "--weights",
        type=str,
        default=".",
        help="Weights location passed to BiomedClipClassifier.from_pretrained "
        "(default: %(default)s)",
    )
    parser.add_argument("--mri", type=str, required=True, help="Path to NIfTI MRI file")
    parser.add_argument("--text", type=str, required=True, help="Clinical text")
    args = parser.parse_args(argv)

    # Fail fast with a clean CLI error (exit status 2) rather than after the
    # model has been loaded, when the loader would raise an opaque traceback.
    if not os.path.exists(args.mri):
        parser.error(f"MRI path does not exist: {args.mri}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BiomedClipClassifier.from_pretrained(args.weights, device=device)

    pred, probs = predict_from_paths(model, args.mri, args.text, device=device)

    print("Prediction:", pred)
    # NOTE(review): the element type of `probs` is not visible here (could be
    # Python floats, numpy scalars or 0-d tensors). float() normalizes all of
    # these. This keeps the printed list free of reprs such as np.float64(...)
    # and is a no-op for plain floats.
    print("Probabilities [CN, MCI, Dementia]:", [round(float(p), 4) for p in probs])
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, so importing this module
    # (e.g. to reuse `main`) does not trigger argument parsing or model loading.
    main()
|
|
|
|