Update README.md
Browse files
README.md
CHANGED
|
@@ -119,47 +119,124 @@ ProtoPatient/
|
|
| 119 |
### 1. Install Dependencies
|
| 120 |
|
| 121 |
```bash
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
| 123 |
```
|
| 124 |
|
| 125 |
### 2. Load the Model via Hugging Face
|
| 126 |
|
| 127 |
```python
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
model.eval()
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
```
|
| 140 |
|
| 141 |
-
## 3.
|
| 142 |
|
| 143 |
-
|
| 144 |
-
- Which tokens receive high attention for each diagnosis.
|
| 145 |
-
- Which prototypical patients are retrieved as similar examples.
|
| 146 |
|
| 147 |
-
|
| 148 |
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
|
| 152 |
|
| 153 |
-
|
|
|
|
|
|
|
| 154 |
|
| 155 |
```python
|
| 156 |
-
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
| 161 |
```
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
# Intended Use, Limitations & Ethical Considerations
|
| 164 |
|
| 165 |
## Intended Use
|
|
|
|
| 119 |
### 1. Install Dependencies
|
| 120 |
|
| 121 |
```bash
|
| 122 |
+
git clone https://huggingface.co/row56/ProtoPatient
|
| 123 |
+
cd ProtoPatient
|
| 124 |
+
pip install -e . transformers torch safetensors
|
| 125 |
+
export TOKENIZERS_PARALLELISM=false
|
| 126 |
```
|
| 127 |
|
| 128 |
### 2. Load the Model via Hugging Face
|
| 129 |
|
| 130 |
```python
|
| 131 |
+
import torch
|
| 132 |
+
from transformers import AutoTokenizer
|
| 133 |
+
from proto_model.configuration_proto import ProtoConfig
|
| 134 |
+
from proto_model.modeling_proto import ProtoForMultiLabelClassification
|
| 135 |
+
|
| 136 |
+
# Load & configure
|
| 137 |
+
cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
|
| 138 |
+
cfg.pretrained_model_name_or_path = "bert-base-uncased"
|
| 139 |
+
cfg.use_cuda = False
|
| 140 |
+
|
| 141 |
+
tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
|
| 142 |
+
model = ProtoForMultiLabelClassification.from_pretrained(
|
| 143 |
+
"row56/ProtoPatient",
|
| 144 |
+
config=cfg,
|
| 145 |
+
ignore_mismatched_sizes=True
|
| 146 |
+
)
|
| 147 |
model.eval()
|
| 148 |
+
model.cpu()
|
| 149 |
+
|
| 150 |
+
# Helper
|
| 151 |
+
def get_proto_logits(texts):
|
| 152 |
+
enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
| 153 |
+
batch = {
|
| 154 |
+
"input_ids": enc["input_ids"],
|
| 155 |
+
"attention_masks": enc["attention_mask"],
|
| 156 |
+
"token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
|
| 157 |
+
"tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]]
|
| 158 |
+
}
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
logits, _ = model.proto_module(batch)
|
| 161 |
+
return logits
|
| 162 |
+
|
| 163 |
+
# Run
|
| 164 |
+
texts = [
|
| 165 |
+
"Patient shows elevated heart rate and low oxygen saturation.",
|
| 166 |
+
"No significant findings; patient is healthy."
|
| 167 |
+
]
|
| 168 |
+
logits = get_proto_logits(texts)
|
| 169 |
+
print("Logits shape:", logits.shape)
|
| 170 |
+
print("Logits:\n", logits)
|
| 171 |
+
|
| 172 |
+
probs = torch.sigmoid(logits)
|
| 173 |
+
print("Probabilities:\n", probs)
|
| 174 |
```
|
| 175 |
|
| 176 |
+
## 3. Training Data & Licenses
|
| 177 |
|
| 178 |
+
This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.
|
|
|
|
|
|
|
| 179 |
|
| 180 |
+
To obtain MIMIC-III:
|
| 181 |
|
| 182 |
1. Visit https://physionet.org/content/mimiciii/1.4/
2. Register for a free PhysioNet account and complete the CITI "Data or Specimens Only Research" training.
3. Sign the MIMIC-III Data Use Agreement (DUA).
4. Download the raw notes and run the preprocessing scripts from the paper's repository.

Note: We do not redistribute MIMIC-III itself; users must obtain it directly under its license.
|
| 187 |
+
|
| 188 |
+
## 4. Load Precomputed Training Data for Prototype Retrieval
|
| 189 |
|
| 190 |
+
After you have MIMIC-III and have applied the published preprocessing, you should produce:
|
| 191 |
|
| 192 |
+
- `data/train_embeds.npy` — NumPy array of shape (N, d) with per-example, per-class embeddings.
- `data/train_texts.json` — JSON array of length N containing the raw admission-note strings.

Place both files in `data/` and then:
|
| 195 |
|
| 196 |
```python
|
| 197 |
+
import numpy as np
|
| 198 |
+
import json
|
| 199 |
|
| 200 |
+
train_embeds = np.load("data/train_embeds.npy") # shape (N, d)
|
| 201 |
+
with open("data/train_texts.json", "r") as f:
|
| 202 |
+
train_texts = json.load(f) # list[str]
|
| 203 |
+
|
| 204 |
+
print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
|
| 205 |
```
|
| 206 |
|
| 207 |
+
## 5. Interpreting Outputs & Retrieving Prototypes
|
| 208 |
+
|
| 209 |
+
```python
|
| 210 |
+
from sklearn.neighbors import NearestNeighbors
|
| 211 |
+
|
| 212 |
+
text = "Patient has chest pain and shortness of breath."
|
| 213 |
+
enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
|
| 214 |
+
batch = {
|
| 215 |
+
"input_ids": enc["input_ids"],
|
| 216 |
+
"attention_masks": enc["attention_mask"],
|
| 217 |
+
"token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
|
| 218 |
+
"tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]]
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
with torch.no_grad():
|
| 222 |
+
logits, metadata = model.proto_module(batch)
|
| 223 |
+
|
| 224 |
+
attn_scores = metadata["attentions"][0] # [num_labels, seq_len]
|
| 225 |
+
for label_id, scores in enumerate(attn_scores):
|
| 226 |
+
topk = sorted(zip(batch["tokens"][0], scores.tolist()),
|
| 227 |
+
key=lambda x: -x[1])[:5]
|
| 228 |
+
print(f"Label {label_id} top tokens:", topk)
|
| 229 |
+
|
| 230 |
+
proto_vecs = model.proto_module.prototype_vectors.cpu().numpy() # [num_labels, d]
|
| 231 |
+
nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
|
| 232 |
+
|
| 233 |
+
for label_id, u_c in enumerate(proto_vecs):
|
| 234 |
+
dist, idx = nn.kneighbors(u_c.reshape(1, -1))
|
| 235 |
+
print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
|
| 236 |
+
print(train_texts[idx[0][0]])
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
|
| 240 |
# Intended Use, Limitations & Ethical Considerations
|
| 241 |
|
| 242 |
## Intended Use
|