File size: 856 Bytes
95153a7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
### **📄 inference.py (optional CLI tool)**
# inference.py
import argparse
from pathlib import Path

import torch

from model import BiomedClipClassifier, predict_from_paths
def main(argv=None):
    """Run a single BiomedClipClassifier prediction from the command line.

    Loads a ``BiomedClipClassifier`` from ``--weights`` onto CUDA when
    available (CPU otherwise), runs ``predict_from_paths`` on the given MRI
    file and clinical text, and prints the predicted label together with the
    per-class probabilities rounded to 4 decimals.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes ``argparse`` read
            ``sys.argv[1:]``, so ``main()`` behaves exactly as a script entry
            point; passing a list lets callers/tests drive the CLI directly.

    Raises:
        SystemExit: With status 2 (raised by ``argparse``) when a required
            argument is missing or ``--mri`` is not an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Classify an MRI scan plus clinical text with BiomedClipClassifier."
    )
    parser.add_argument(
        "--weights",
        type=str,
        default=".",
        help="Location passed to BiomedClipClassifier.from_pretrained (default: current directory)",
    )
    parser.add_argument("--mri", type=str, required=True, help="Path to NIfTI MRI file")
    parser.add_argument("--text", type=str, required=True, help="Clinical text")
    args = parser.parse_args(argv)

    # Fail fast on a bad input path, before paying for model loading.
    if not Path(args.mri).is_file():
        parser.error(f"--mri file not found: {args.mri}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BiomedClipClassifier.from_pretrained(args.weights, device=device)

    # Inference only: no autograd graph is needed for a print-only CLI.
    # Harmless if predict_from_paths already disables gradients internally.
    with torch.no_grad():
        pred, probs = predict_from_paths(model, args.mri, args.text, device=device)

    print("Prediction:", pred)
    # float() turns NumPy/torch scalars into plain Python floats so the list
    # prints bare numbers (NumPy 2 scalars would repr as "np.float64(...)").
    # NOTE(review): assumes probs is an iterable of scalars in the label
    # order shown below -- confirm against predict_from_paths.
    print("Probabilities [CN, MCI, Dementia]:", [round(float(p), 4) for p in probs])


if __name__ == "__main__":
    main()
|