---
tags:
- text-classification
- medical
- prototypical-networks
- transformers
library_name: transformers
language: en
license: mit
datasets:
- your_dataset_name_here
model-index:
- name: ProtoPatient
  results:
  - task:
      type: multi-label-classification
    dataset:
      name: your_dataset_name_here
      type: text
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.XX
    - name: F1-score
      type: f1
      value: 0.XX
---

# ProtoPatient Model for Multi-Label Classification

## Paper Reference

**van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.**

*This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.*

[arXiv:2210.08500](https://arxiv.org/abs/2210.08500)

ProtoPatient is a transformer-based architecture that combines prototypical networks with label-wise attention to perform multi-label diagnosis classification on clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:

- **Highlighting Relevant Tokens:** Shows the most important words for each possible diagnosis.
- **Retrieving Prototypical Patients:** Finds training examples with similar textual patterns, giving clinicians an intuitive justification: "This patient looks like that patient."

---

## Model Overview

### Prototype-Based Classification

- The model learns one **prototype vector** \(u_c\) for each diagnosis \(c\).
- A patient's admission note is encoded by a PubMedBERT encoder followed by a linear compression layer, and a label-wise attention mechanism turns it into a diagnosis-specific representation \(v_{p,c}\).
- The classification score for diagnosis \(c\) is the **negative Euclidean distance** between \(v_{p,c}\) and \(u_c\), so a note scores higher the closer it lies to the learned prototype of that diagnosis.
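
The scoring rule can be illustrated with a short PyTorch sketch. It assumes the label-specific representations \(v_{p,c}\) are already stacked into a tensor of shape `(batch, C, d)` and the prototypes \(u_c\) into `(C, d)`; the function name `prototype_scores` is hypothetical and not part of the `proto_model` package.

```python
import torch


def prototype_scores(v: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Negative Euclidean distance between label-specific note embeddings and prototypes.

    v: (batch, C, d) diagnosis-specific representations v_{p,c}
    u: (C, d) one prototype vector u_c per diagnosis
    returns: (batch, C) scores; higher means closer to the prototype
    """
    # Broadcast u over the batch dimension and take the L2 norm over d.
    return -torch.linalg.vector_norm(v - u.unsqueeze(0), dim=-1)


# Toy example: 2 notes, 3 diagnoses, 4-dimensional embeddings.
u = torch.zeros(3, 4)
v = torch.zeros(2, 3, 4)
v[0, 1] = torch.tensor([3.0, 4.0, 0.0, 0.0])  # note 0 lies 5 units from prototype 1
scores = prototype_scores(v, u)
print(scores)
```

Because the score is a negated distance, it is never positive; a score of 0 means the note's representation coincides with the prototype.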

### Label-Wise Attention

- For each diagnosis, a separate attention vector scores every token in the admission note to identify the tokens relevant to that diagnosis.
- Because these attention weights determine how \(v_{p,c}\) is built, they indicate which tokens most influence each prediction.
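
A minimal sketch of label-wise attention, assuming a generic dot-product form in which each diagnosis has one attention vector `w_c` that is scored against every token embedding. This is a common formulation of label-wise attention, not necessarily the paper's exact parameterization (for example, where the linear compression layer sits); `label_wise_attention` is a hypothetical helper.

```python
import torch


def label_wise_attention(h: torch.Tensor, w: torch.Tensor, mask: torch.Tensor):
    """Generic dot-product label-wise attention (assumed form, see lead-in).

    h:    (batch, T, d) token embeddings from the encoder
    w:    (C, d) one attention vector per diagnosis
    mask: (batch, T) 1 for real tokens, 0 for padding
    returns v: (batch, C, d) label-specific representations,
            a: (batch, C, T) attention weights per label
    """
    # Token-label relevance scores: (batch, C, T)
    logits = torch.einsum("btd,cd->bct", h, w)
    # Padding tokens get -inf so they receive zero weight after the softmax.
    logits = logits.masked_fill(mask.unsqueeze(1) == 0, float("-inf"))
    a = torch.softmax(logits, dim=-1)
    # Attention-weighted sum of token embeddings for each label: (batch, C, d)
    v = torch.einsum("bct,btd->bcd", a, h)
    return v, a


h = torch.randn(2, 5, 8)
w = torch.randn(3, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
v, a = label_wise_attention(h, w, mask)
print(v.shape, a.shape)
```

The rows of `a` are exactly the per-label token weights that an interface would highlight, which is why this design yields explanations without a separate post-hoc method.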

### Interpretable Output

- **Token Highlights:** The most-attended words for each diagnosis, which often correspond to symptoms, risk factors, or diagnostic descriptors.
- **Prototypical Patients:** The training examples closest to each prototype, representing typical presentations of a diagnosis.

---

## Key Features and Benefits

- **Improved Performance on Rare Diagnoses:**
  Prototype-based learning has strong few-shot capabilities, which especially benefits diagnoses with very few samples.
- **Faithful Interpretations:**
  Quantitative evaluations (see Section 5 of the paper) indicate that the attention-based highlights are more faithful to the model's decision process than post-hoc methods such as LIME, Occlusion, and gradient-based approaches.
- **Clinical Utility:**
  - Provides label-wise explanations that help clinicians assess whether predictions align with actual risk factors.
  - Points to prototypical patients, so new cases can be compared with typical (or atypical) presentations.

---

## Performance Metrics

Evaluated on **MIMIC-III**:

- **Admission Notes:** 48,745
- **Diagnosis Labels:** 1,266

Performance (approximate):

- **Macro ROC AUC:** ~87–88%
- **Micro ROC AUC:** ~97%
- **Macro PR AUC:** ~18–21%

The model shows particularly strong gains on rare diagnoses (fewer than 50 samples) compared with baselines such as PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU).

The model also transfers well to **i2b2** data (1,118 admission notes), which comes from a different clinical environment.

*Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.*
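
The metrics above are standard multi-label ranking metrics. The sketch below shows one way to compute them with scikit-learn on a binary label matrix and a matrix of model scores, plus a macro ROC AUC restricted to rare labels. It is an illustration, not the paper's evaluation code: it uses average precision as the PR AUC estimate, skips labels with no positives or no negatives, and `multilabel_aucs` and `train_counts` are hypothetical names.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score


def _macro(metric, y_true, y_score, labels):
    # Unweighted mean of the per-label metric (what average="macro" computes).
    return float(np.mean([metric(y_true[:, j], y_score[:, j]) for j in labels]))


def multilabel_aucs(y_true, y_score, train_counts=None, rare_threshold=50):
    """Macro/micro ROC AUC and macro PR AUC (average precision) for multi-label scores.

    y_true:  (N, C) binary label matrix
    y_score: (N, C) real-valued scores, e.g. the model's logits
    train_counts: optional (C,) number of training examples per label
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_true.sum(axis=0)
    # Per-label AUCs are undefined unless a label has both positives and negatives.
    valid = np.flatnonzero((pos > 0) & (pos < y_true.shape[0]))
    results = {
        "macro_roc_auc": _macro(roc_auc_score, y_true, y_score, valid),
        # Micro averaging pools every (note, label) pair into one binary problem.
        "micro_roc_auc": float(roc_auc_score(y_true.ravel(), y_score.ravel())),
        "macro_pr_auc": _macro(average_precision_score, y_true, y_score, valid),
    }
    if train_counts is not None:
        counts = np.asarray(train_counts)
        rare = [j for j in valid if counts[j] < rare_threshold]
        if rare:
            results["rare_macro_roc_auc"] = _macro(roc_auc_score, y_true, y_score, rare)
    return results


rng = np.random.default_rng(0)
y_true = np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]])
y_score = rng.normal(size=y_true.shape)
print(multilabel_aucs(y_true, y_score, train_counts=[120, 30, 10]))
```

Macro averages weight every diagnosis equally, which is why they are much more sensitive to rare labels than the micro average.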

---

## Repository Structure

```plaintext
ProtoPatient/
├── proto_model/
│   ├── proto.py
│   ├── utils.py
│   ├── metrics.py
│   └── __init__.py
├── config.json
├── setup.py
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
```

---

## How to Use the Model

### 1. Install Dependencies

```bash
git lfs install  # model.safetensors is stored with Git LFS
git clone https://huggingface.co/row56/ProtoPatient
cd ProtoPatient
pip install -e . transformers torch safetensors numpy scikit-learn
export TOKENIZERS_PARALLELISM=false
```

### 2. Load the Model via Hugging Face

```python
import os
import warnings

# Silence tokenizer fork warnings and non-critical library warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import logging as hf_logging

hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)

import torch
from transformers import AutoTokenizer

from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification

cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
# Base encoder name used to build the model skeleton; the trained weights
# are loaded from the row56/ProtoPatient checkpoint below.
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if cfg.use_cuda else "cpu")

# Use the tokenizer files shipped with the checkpoint (tokenizer.json, vocab.txt)
# so token IDs match the vocabulary the model was trained with.
tokenizer = AutoTokenizer.from_pretrained("row56/ProtoPatient")
model = ProtoForMultiLabelClassification.from_pretrained(
    "row56/ProtoPatient",
    config=cfg,
)
model.to(device)
model.eval()


def get_proto_logits(texts):
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    batch = {
        # Move tensors to the model's device to avoid CPU/GPU mismatches.
        "input_ids": enc["input_ids"].to(device),
        "attention_masks": enc["attention_mask"].to(device),
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
        "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
    }
    with torch.no_grad():
        logits, _ = model.proto_module(batch)
    return logits


texts = [
    "Patient shows elevated heart rate and low oxygen saturation.",
    "No significant findings; patient is healthy.",
]
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
```
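
Each row of `logits` holds one score per diagnosis label (a negated distance, so larger means more likely). A simple way to read them is to rank labels per note with `torch.topk`. The helper below is a hypothetical sketch: `id2label` is optional, and whether the checkpoint's config carries a meaningful index-to-diagnosis mapping (for example `cfg.id2label`) is an assumption; without one, label indices are returned.

```python
import torch


def top_k_labels(logits: torch.Tensor, k: int = 5, id2label=None):
    """Return the k highest-scoring (label, score) pairs for each note.

    logits: (num_notes, num_labels) scores, higher = closer to the prototype
    id2label: optional dict mapping label index to a diagnosis code or name
    """
    scores, indices = torch.topk(logits, k=min(k, logits.shape[-1]), dim=-1)
    results = []
    for row_scores, row_idx in zip(scores.tolist(), indices.tolist()):
        # Fall back to the raw index when no name is available.
        labels = [id2label.get(i, i) if id2label else i for i in row_idx]
        results.append(list(zip(labels, row_scores)))
    return results


example = torch.tensor([[-3.0, -0.5, -2.0], [-0.1, -4.0, -1.0]])
print(top_k_labels(example, k=2))
```

With the model loaded above, `top_k_labels(logits.cpu(), k=5)` lists the five closest prototypes for each input note.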

### 3. Training Data & Licenses

This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.

To obtain MIMIC-III:

1. Visit https://physionet.org/content/mimiciii/1.4/
2. Register for a free PhysioNet account and complete the CITI "Data or Specimens Only Research" training.
3. Sign the MIMIC-III Data Use Agreement (DUA).
4. Download the raw notes and run the preprocessing scripts from the paper's repository.

**Note:** We do not redistribute MIMIC-III itself; users must obtain it directly under its license.

### 4. Load Precomputed Training Data for Prototype Retrieval

After obtaining MIMIC-III and applying the published preprocessing, you should produce:

- `data/train_embeds.npy`: a NumPy array of shape `(N, d)` with one `d`-dimensional embedding per training example, in the same space as the prototype vectors.
- `data/train_texts.json`: a JSON array of length `N` containing the raw admission-note strings, in the same order as the embeddings.

Place both files in `data/` and then run:

```python
import json

import numpy as np

train_embeds = np.load("data/train_embeds.npy")
with open("data/train_texts.json", "r", encoding="utf-8") as f:
    train_texts = json.load(f)

print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
```
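
Because the retrieval step in section 5 uses row positions found in `train_embeds` to index `train_texts`, the two files must line up row for row, and the embedding dimension must match the prototype vectors. A small sanity check can catch mismatches early; `check_retrieval_data` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def check_retrieval_data(train_embeds: np.ndarray, train_texts: list, proto_dim=None) -> None:
    """Raise ValueError if the precomputed retrieval files are inconsistent."""
    if train_embeds.ndim != 2:
        raise ValueError(f"expected a 2-D (N, d) array, got shape {train_embeds.shape}")
    n, d = train_embeds.shape
    # Row i of the embeddings must describe note i of the texts.
    if len(train_texts) != n:
        raise ValueError(f"{n} embeddings but {len(train_texts)} texts")
    # Nearest-neighbour search against prototypes needs matching dimensions.
    if proto_dim is not None and d != proto_dim:
        raise ValueError(f"embedding dim {d} != prototype dim {proto_dim}")
    if not np.isfinite(train_embeds).all():
        raise ValueError("embeddings contain NaN or inf values")


check_retrieval_data(np.zeros((3, 4)), ["a", "b", "c"], proto_dim=4)
```

Assuming the prototypes are stored as a `(C, d)` tensor, you could call it as `check_retrieval_data(train_embeds, train_texts, proto_dim=model.proto_module.prototype_vectors.shape[-1])`.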

### 5. Interpreting Outputs & Retrieving Prototypes

```python
from sklearn.neighbors import NearestNeighbors

text = "Patient has chest pain and shortness of breath."
enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
batch = {
    "input_ids": enc["input_ids"].to(device),
    "attention_masks": enc["attention_mask"].to(device),
    "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
    "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
}

with torch.no_grad():
    logits, metadata = model.proto_module(batch)

# Inspect only the 5 highest-scoring labels instead of all 1,266.
top_labels = torch.topk(logits[0], k=5).indices.tolist()

# Token highlights: per-label attention over the tokens of the first note.
attn_scores = metadata["attentions"][0]
for label_id in top_labels:
    scores = attn_scores[label_id]
    topk = sorted(zip(batch["tokens"][0], scores.tolist()), key=lambda x: -x[1])[:5]
    print(f"Label {label_id} top tokens:", topk)

# Prototypical patients: nearest training example to each prototype vector.
# detach() lets NumPy conversion work even if the prototypes are trainable parameters.
proto_vecs = model.proto_module.prototype_vectors.detach().cpu().numpy()
knn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)

for label_id in top_labels:
    dist, idx = knn.kneighbors(proto_vecs[label_id].reshape(1, -1))
    print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
    print(train_texts[idx[0][0]])
```
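
The snippet above searches one `(N, d)` embedding matrix for every label. Since the model scores each diagnosis with a label-specific representation \(v_{p,c}\), one could instead store one embedding per training note and per label, an array of shape `(N, C, d)`, and search only the slice for label \(c\). The NumPy sketch below assumes such an array already exists (producing it is not covered here); `nearest_per_label` is a hypothetical helper.

```python
import numpy as np


def nearest_per_label(label_embeds: np.ndarray, prototypes: np.ndarray, label_id: int, k: int = 3):
    """Indices and distances of the k training notes closest to prototype u_c.

    label_embeds: (N, C, d) label-specific embeddings v_{n,c} of N training notes
    prototypes:   (C, d) prototype vectors u_c
    """
    # Euclidean distance between v_{n,c} and u_c for every training note n.
    dists = np.linalg.norm(label_embeds[:, label_id, :] - prototypes[label_id], axis=-1)
    k = min(k, dists.shape[0])
    order = np.argsort(dists)[:k]
    return order, dists[order]


rng = np.random.default_rng(0)
label_embeds = rng.normal(size=(10, 4, 8))
prototypes = rng.normal(size=(4, 8))
label_embeds[7, 2] = prototypes[2]  # make note 7 coincide with prototype 2
idx, dist = nearest_per_label(label_embeds, prototypes, label_id=2)
print(idx, dist)
```

`train_texts[i]` for each returned index `i` then gives the prototypical patients for that diagnosis.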

---

## Intended Use, Limitations & Ethical Considerations

### Intended Use

- **Research & Education:**
  ProtoPatient is designed primarily for academic research and educational purposes in clinical NLP.
- **Interpretability Demonstration:**
  The model demonstrates how prototype-based methods can provide interpretable multi-label classification on clinical admission notes.

---

### Limitations

- **Generalization:**
  The model was trained on ICU admission notes from MIMIC-III (with transfer evaluated on i2b2) and may not generalize to other patient populations or clinical settings.
- **Prototype Scope:**
  The current version uses a single prototype per diagnosis, although some diagnoses may have several typical presentations; this is an area for future improvement.
- **Inter-diagnosis Relationships:**
  The model does not explicitly model relationships (e.g., conflicts or comorbidities) between diagnoses.

---

### Ethical & Regulatory Considerations

- **Not for Direct Clinical Use:**
  This model is not intended for direct clinical decision-making. Always consult healthcare professionals.
- **Bias and Fairness:**
  Users should be aware of potential biases in the training data; rare conditions may still be misclassified.
- **Patient Privacy:**
  When applying the model to real clinical data, patient privacy must be strictly maintained.

---

## Example Interpretability Output

Based on the approach described in the paper (see Section 5 and Table 5):

- **Highlighted Tokens:**
  Phrases such as "worst headache of her life," "vomiting," "fever," and "infiltrate" strongly indicate specific diagnoses.
- **Prototypical Sample:**
  A snippet from a training patient with similar text segments provides a rationale for the prediction.

*This interpretability output helps clinicians understand the model's reasoning, for example: "The system suggests intracerebral hemorrhage because the patient's note closely resembles typical cases with that diagnosis."*
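
The attention scores printed in step 5 are attached to WordPiece subword tokens (e.g. `head`, `##ache`), whereas highlights like those above read as whole words. A minimal sketch for merging subwords back into words before ranking, assuming BERT-style `##` continuation markers and keeping the maximum score per word (a design choice, not necessarily what the paper does); `merge_wordpieces` is a hypothetical helper.

```python
def merge_wordpieces(tokens, scores, skip=("[CLS]", "[SEP]", "[PAD]")):
    """Merge BERT WordPiece tokens into words, keeping the max score per word."""
    words = []
    for tok, score in zip(tokens, scores):
        if tok in skip:
            continue
        if tok.startswith("##") and words:
            # Continuation piece: glue onto the previous word and keep the larger score.
            prev_word, prev_score = words[-1]
            words[-1] = (prev_word + tok[2:], max(prev_score, score))
        else:
            words.append((tok, score))
    return words


tokens = ["[CLS]", "worst", "head", "##ache", "of", "her", "life", "[SEP]"]
scores = [0.01, 0.20, 0.10, 0.35, 0.02, 0.03, 0.15, 0.01]
words = merge_wordpieces(tokens, scores)
top = sorted(words, key=lambda w: -w[1])[:3]
print(top)
```

With the variables from step 5, `merge_wordpieces(batch["tokens"][0], attn_scores[label_id].tolist())` would give word-level highlights for one label.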

---

## Recommended Citation

If you use ProtoPatient in your research, please cite:

```bibtex
@misc{vanaken2022this,
  title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
  author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
  year={2022},
  eprint={2210.08500},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```