Yasmine97 commited on
Commit
95153a7
·
verified ·
1 Parent(s): 7012b23

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +24 -0
inference.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ### **📄 inference.py (optional CLI tool)**
3
+
4
+ # inference.py
5
+ import argparse
6
+ import torch
7
+ from model import BiomedClipClassifier, predict_from_paths
8
+
9
+ def main():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--weights", type=str, default=".")
12
+ parser.add_argument("--mri", type=str, required=True, help="Path to NIfTI MRI file")
13
+ parser.add_argument("--text", type=str, required=True, help="Clinical text")
14
+ args = parser.parse_args()
15
+
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model = BiomedClipClassifier.from_pretrained(args.weights, device=device)
18
+
19
+ pred, probs = predict_from_paths(model, args.mri, args.text, device=device)
20
+ print("Prediction:", pred)
21
+ print("Probabilities [CN, MCI, Dementia]:", [round(p, 4) for p in probs])
22
+
23
+ if __name__ == "__main__":
24
+ main()