# BiomedCLIP MRI + Clinical Text Classifier

This model fine-tunes [BiomedCLIP (PubMedBERT ViT-B/16)](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) for **Alzheimer’s disease classification** from **MRI (3D volumes)** and **synthetic clinical text**.

---

## 🧩 Model Details
- **Backbone**: BiomedCLIP (image + text encoders)
- **Input MRI**: 3D NIfTI → reduced to 3 mid-slices (axial, coronal, sagittal) → stacked into RGB
- **Input Text**: Synthetic patient note (tokenized with PubMedBERT)
- **Fusion**: Concatenate image & text embeddings
- **Head**: MLP (Linear → ReLU → Dropout → Linear) → 3-way classification
- **Labels**:
  - `CN` – Cognitively Normal
  - `MCI` – Mild Cognitive Impairment
  - `Dementia`
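
The mid-slice reduction above can be sketched roughly as follows. This is a minimal illustration, not the repo's actual preprocessing: the per-slice min-max normalisation and zero-padding to a common canvas are assumptions, and resizing to the ViT input resolution (224×224) is omitted.

```python
import numpy as np

def volume_to_rgb_slices(vol: np.ndarray) -> np.ndarray:
    """Reduce a 3D MRI volume to three orthogonal mid-slices stacked channel-wise.

    `vol` is an (X, Y, Z) array, e.g. from `nib.load(path).get_fdata()`.
    Returns a (3, H, W) float32 array with one mid-slice per channel.
    """
    x, y, z = vol.shape
    sagittal = vol[x // 2, :, :]
    coronal = vol[:, y // 2, :]
    axial = vol[:, :, z // 2]

    # The three slices have different shapes; pad each onto a shared canvas.
    h = max(s.shape[0] for s in (sagittal, coronal, axial))
    w = max(s.shape[1] for s in (sagittal, coronal, axial))
    canvas = np.zeros((3, h, w), dtype=np.float32)
    for i, s in enumerate((sagittal, coronal, axial)):
        # Min-max normalise each slice to [0, 1] (an assumed choice).
        s = (s - s.min()) / (s.max() - s.min() + 1e-8)
        canvas[i, : s.shape[0], : s.shape[1]] = s
    return canvas
```

In practice the resulting array would still be resized and normalised to match the BiomedCLIP image preprocessor before being fed to the ViT encoder.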

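The concat-fusion classifier head can be illustrated with a short sketch. The embedding and hidden sizes here are assumptions (BiomedCLIP's projection dimension is 512; the hidden width and dropout rate are placeholders), not the repo's exact configuration:

```python
import torch
import torch.nn as nn

class FusionHead(nn.Module):
    """Illustrative head: concat image/text embeddings -> Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, embed_dim: int = 512, hidden: int = 256,
                 num_classes: int = 3, dropout: float = 0.3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * embed_dim, hidden),  # concatenated embeddings in
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),    # logits for CN / MCI / Dementia
        )

    def forward(self, img_emb: torch.Tensor, txt_emb: torch.Tensor) -> torch.Tensor:
        fused = torch.cat([img_emb, txt_emb], dim=-1)
        return self.mlp(fused)
```

At inference time the logits would be passed through a softmax to obtain the `[CN, MCI, Dementia]` probabilities returned by `predict_from_paths`.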
---

## 🚀 Usage

### Install
```bash
pip install open_clip_torch nibabel torch torchvision
```

### Load Pretrained Model

```python
import torch
from model import BiomedClipClassifier, predict_from_paths

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load from repo (assuming you pushed pytorch_model.bin + config.json here)
model = BiomedClipClassifier.from_pretrained(".", device=device)

# Example inference
pred, probs = predict_from_paths(
    model,
    "/path/to/sample_brain.nii.gz",
    "Patient shows mild memory impairment and hippocampal atrophy.",
    device=device
)

print("Prediction:", pred)
print("Probabilities:", probs)  # [CN, MCI, Dementia]

```
### Run Inference

```bash
python inference.py --weights . --mri /path/to/sample.nii.gz --text "Patient shows memory issues"
```