---
tags:
- text-classification
- medical
- prototypical-networks
- transformers
library_name: transformers
language: en
license: mit
datasets:
- your_dataset_name_here
model-index:
- name: ProtoPatient
  results:
  - task:
      type: multi-label-classification
    dataset:
      name: your_dataset_name_here
      type: text
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.XX # Update with real value
    - name: F1-score
      type: f1
      value: 0.XX # Update with real value
---
# ProtoPatient Model for Multi-Label Classification
## Paper Reference
**van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.**
*This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.*
[arXiv:2210.08500](https://arxiv.org/abs/2210.08500)
ProtoPatient is a transformer-based architecture that uses prototypical networks and label-wise attention to perform multi-label diagnosis prediction from clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:
- **Highlighting Relevant Tokens:** Shows the most important words for each possible diagnosis.
- **Retrieving Prototypical Patients:** Finds training examples with similar textual patterns to give clinicians intuitive justifications, essentially answering, "This patient looks like that patient."
---
## Model Overview
### Prototype-Based Classification
- The model learns **prototypical vectors** (\(u_c\)) for each diagnosis \(c\).
- A patient's admission note is encoded with a PubMedBERT encoder followed by a linear compression layer; a label-wise attention mechanism then pools the token representations into a diagnosis-specific representation (\(v_{p,c}\)).
- Classification scores are computed as the **negative Euclidean distance** between \(v_{p,c}\) and \(u_c\), which directly measures the note’s similarity to the learned prototype.
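As a minimal sketch of this scoring rule (plain Python with toy 2-dimensional vectors; an illustration, not the repository's implementation), each diagnosis score is the negated Euclidean distance between the note's diagnosis-specific representation \(v_{p,c}\) and the prototype \(u_c\):

```python
import math


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def proto_scores(note_reprs, prototypes):
    """Score each diagnosis c as -||v_{p,c} - u_c||.

    note_reprs: C vectors, one diagnosis-specific representation v_{p,c} per label.
    prototypes: C prototype vectors u_c, in the same order.
    """
    return [-euclidean(v, u) for v, u in zip(note_reprs, prototypes)]


# Toy example with C = 2 diagnoses and d = 2 dimensions (made-up numbers).
prototypes = [[0.0, 0.0], [1.0, 1.0]]
note_reprs = [[0.0, 0.5], [4.0, 5.0]]
scores = proto_scores(note_reprs, prototypes)
print(scores)  # [-0.5, -5.0]
```

A note therefore scores highest (closest to 0) for the diagnosis whose prototype it lies nearest to. How scores are mapped to probabilities or thresholded into final predictions is not shown in this sketch.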
### Label-Wise Attention
- For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
- This mechanism provides interpretability by indicating which tokens are most influential in driving each prediction.
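The sketch below illustrates label-wise attention under a common formulation that is assumed here rather than taken from the repository: each label \(c\) has a learned vector \(a_c\), token scores are dot products \(h_i \cdot a_c\) normalized with a softmax over tokens, and \(v_{p,c}\) is the attention-weighted sum of the token vectors. The linear compression layer is omitted, and all vectors are toy values.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def label_wise_attention(token_embs, attn_vecs):
    """Return (weights, reprs): per-label attention weights over tokens and pooled vectors.

    token_embs: T token vectors h_i, each of length d.
    attn_vecs:  C label attention vectors a_c, each of length d.
    """
    d = len(token_embs[0])
    weights, reprs = [], []
    for a_c in attn_vecs:
        # One relevance score per token for this label
        logits = [sum(h_k * a_k for h_k, a_k in zip(h, a_c)) for h in token_embs]
        alpha = softmax(logits)
        # Attention-weighted sum of token vectors = label-specific representation
        v = [sum(alpha[i] * token_embs[i][k] for i in range(len(token_embs))) for k in range(d)]
        weights.append(alpha)
        reprs.append(v)
    return weights, reprs


tokens = ["chest", "pain", "the"]
token_embs = [[2.0, 0.0], [0.0, 2.0], [0.0, 0.0]]
attn_vecs = [[1.0, 0.0], [0.0, 1.0]]  # label 0 looks at dim 0, label 1 at dim 1
weights, reprs = label_wise_attention(token_embs, attn_vecs)
for c, alpha in enumerate(weights):
    top = max(range(len(tokens)), key=lambda i: alpha[i])
    print(c, tokens[top])
# 0 chest
# 1 pain
```

Because each label has its own attention distribution, the same note can highlight different tokens for different diagnoses.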
### Interpretable Output
- **Token Highlights:** The top attended words (often correlating with symptoms, risk factors, or diagnostic descriptors).
- **Prototypical Patients:** Examples from the training set that are closest to each prototype, representing typical presentations of a diagnosis.
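Attention is computed over subword tokens, so a single clinical term may be split into several pieces. A minimal helper for displaying whole-word highlights, assuming the BERT WordPiece convention (continuation pieces start with `##`) and taking the maximum score over a word's pieces, which is one of several reasonable aggregation choices. The tokenization in the example is illustrative, not actual tokenizer output.

```python
def merge_wordpieces(tokens, scores, special=("[CLS]", "[SEP]", "[PAD]")):
    """Merge WordPiece sub-tokens into words, keeping the max attention score per word."""
    words = []  # list of [word, score]
    for tok, s in zip(tokens, scores):
        if tok in special:
            continue
        if tok.startswith("##") and words:
            # Continuation piece: append to the previous word
            words[-1][0] += tok[2:]
            words[-1][1] = max(words[-1][1], s)
        else:
            words.append([tok, s])
    return [(w, s) for w, s in words]


def top_words(tokens, scores, k=3):
    """Return the k highest-scoring whole words as (word, score) pairs."""
    merged = merge_wordpieces(tokens, scores)
    return sorted(merged, key=lambda ws: -ws[1])[:k]


tokens = ["[CLS]", "patient", "has", "tachy", "##card", "##ia", "[SEP]"]
scores = [0.30, 0.05, 0.02, 0.10, 0.40, 0.08, 0.05]
print(top_words(tokens, scores, k=2))
# [('tachycardia', 0.4), ('patient', 0.05)]
```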
---
## Key Features and Benefits
- **Improved Performance on Rare Diagnoses:**
Prototype-based learning has strong few-shot capabilities, which is especially beneficial for diagnoses with very few samples.
- **Faithful Interpretations:**
Quantitative evaluations (see Section 5 in the paper) indicate that the attention-based highlights are more faithful to the model's decision process than post-hoc methods such as LIME, Occlusion, and gradient-based approaches.
- **Clinical Utility:**
  - Provides label-wise explanations to help clinicians assess whether the predictions align with actual risk factors.
  - Points to prototypical patients, allowing new cases to be compared with typical (or atypical) presentations.
---
## Performance Metrics
Evaluated on **MIMIC-III**:
- **Admission Notes:** 48,745
- **Diagnosis Labels:** 1,266
Performance (approximate):
- **Macro ROC AUC:** ~87–88%
- **Micro ROC AUC:** ~97%
- **Macro PR AUC:** ~18–21%
The model shows particularly strong gains for rare diagnoses (fewer than 50 training samples) compared with baselines such as PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU).
It also transfers well to **i2b2** data (1,118 admission notes) from a different clinical environment.
*Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.*
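The gap between micro and macro ROC AUC reflects how each averages over labels: macro AUC weights every diagnosis equally, while micro AUC pools all (note, label) pairs and is therefore dominated by frequent labels. A small pure-Python illustration with made-up scores, computing ROC AUC as the probability that a random positive outranks a random negative:

```python
def roc_auc(y_true, y_score):
    """ROC AUC via pairwise ranking; ties count as half a win."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined without both positives and negatives")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def macro_micro_auc(Y_true, Y_score):
    """Y_true, Y_score: lists of rows (notes) by columns (labels)."""
    n_labels = len(Y_true[0])
    per_label = [
        roc_auc([row[c] for row in Y_true], [row[c] for row in Y_score])
        for c in range(n_labels)
    ]
    macro = sum(per_label) / n_labels  # every label counts equally
    micro = roc_auc([y for row in Y_true for y in row],  # all pairs pooled
                    [s for row in Y_score for s in row])
    return macro, micro


# Label 0 is common and ranked perfectly; label 1 is rare and ranked badly.
Y_true = [[1, 0], [1, 0], [0, 1], [0, 0]]
Y_score = [[0.9, 0.2], [0.8, 0.1], [0.1, 0.05], [0.2, 0.3]]
macro, micro = macro_micro_auc(Y_true, Y_score)
print(round(macro, 3), round(micro, 3))  # 0.5 0.667
```

The single badly ranked rare label pulls macro AUC down much more than micro AUC, which is why rare-diagnosis performance shows up mainly in the macro scores.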
---
## Repository Structure
```plaintext
ProtoPatient/
├── proto_model/
│   ├── proto.py
│   ├── utils.py
│   ├── metrics.py
│   └── __init__.py
├── config.json
├── setup.py
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
```
---
## How to Use the Model
### 1. Install Dependencies
```bash
git clone https://huggingface.co/row56/ProtoPatient
cd ProtoPatient
pip install -e . transformers torch safetensors numpy scikit-learn
export TOKENIZERS_PARALLELISM=false
```
### 2. Load the Model via Hugging Face
```python
import os
import warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
import torch
from transformers import AutoTokenizer
from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification
cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if cfg.use_cuda else "cpu")
tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
model = ProtoForMultiLabelClassification.from_pretrained(
"row56/ProtoPatient",
config=cfg,
)
model.to(device)
model.eval()
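def get_proto_logits(texts):
    """Tokenize `texts` and return one score per diagnosis label for each text."""
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Move inputs to the same device as the model (required when running on GPU)
    batch = {
        "input_ids": enc["input_ids"].to(device),
        # Key name as used by proto_module (note the plural "masks")
        "attention_masks": enc["attention_mask"].to(device),
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
        "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
    }
    with torch.no_grad():
        logits, _ = model.proto_module(batch)
    return logits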
texts = [
"Patient shows elevated heart rate and low oxygen saturation.",
"No significant findings; patient is healthy."
]
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
```
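To read the output as a ranked list of diagnoses, a minimal sketch is to sort each row of scores. This assumes column `c` of `logits` corresponds to label index `c` and that higher (less negative) scores mean a closer prototype; mapping indices to diagnosis codes depends on the label order used at training time, which is not documented here.

```python
def top_k_labels(score_rows, k=3):
    """Return, for each note, the k label indices with the highest scores (best first)."""
    return [
        sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        for row in score_rows
    ]


# With the model loaded above, this could be called as: top_k_labels(logits.tolist(), k=5)
example_scores = [[-3.2, -0.4, -1.1], [-0.2, -2.5, -0.9]]
print(top_k_labels(example_scores, k=2))  # [[1, 2], [0, 2]]
```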
### 3. Training Data & Licenses
This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.
To obtain MIMIC-III:
1. Visit https://physionet.org/content/mimiciii/1.4/
2. Register for a free PhysioNet account and complete the CITI "Data or Specimens Only Research" training.
3. Sign the MIMIC-III Data Use Agreement (DUA).
4. Download the raw notes and run the preprocessing scripts from the paper's repository.

**Note:** We do not redistribute MIMIC-III itself; users must obtain it directly under its license.
### 4. Load Precomputed Training Data for Prototype Retrieval
After obtaining MIMIC-III and applying the published preprocessing, produce two files:
- `data/train_embeds.npy`: NumPy array of shape `(N, d)` with one embedding per training example, in the same d-dimensional space as the prototype vectors \(u_c\).
- `data/train_texts.json`: JSON array of length `N` containing the raw admission-note strings, in the same order as the embeddings.

Place both files in `data/` and load them:
```python
import numpy as np
import json
train_embeds = np.load("data/train_embeds.npy")
with open("data/train_texts.json", "r") as f:
train_texts = json.load(f)
print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
```
### 5. Interpreting Outputs & Retrieving Prototypes
```python
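import torch
from sklearn.neighbors import NearestNeighbors

# Assumes `tokenizer`, `model`, and `device` from step 2 and
# `train_embeds` / `train_texts` from step 4 are already defined.
text = "Patient has chest pain and shortness of breath."
enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
batch = {
    "input_ids": enc["input_ids"].to(device),
    "attention_masks": enc["attention_mask"].to(device),
    "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
    "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
}
with torch.no_grad():
    logits, metadata = model.proto_module(batch)

# Label-wise attention: print the five most-attended tokens for each label
attn_scores = metadata["attentions"][0]
for label_id, scores in enumerate(attn_scores):
    topk = sorted(zip(batch["tokens"][0], scores.tolist()),
                  key=lambda x: -x[1])[:5]
    print(f"Label {label_id} top tokens:", topk)

# Prototype retrieval: find the training note closest to each learned prototype u_c.
# .detach() is needed before .numpy() if prototype_vectors is a trainable parameter.
proto_vecs = model.proto_module.prototype_vectors.detach().cpu().numpy()
knn = NearestNeighbors(n_neighbors=1, metric="euclidean")
if train_embeds.ndim == 2:
    knn.fit(train_embeds)  # (N, d): fit the index once
for label_id, u_c in enumerate(proto_vecs):
    if train_embeds.ndim == 3:
        # (N, C, d): search only the embeddings stored for this label
        knn.fit(train_embeds[:, label_id, :])
    dist, idx = knn.kneighbors(u_c.reshape(1, -1))
    print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
    print(train_texts[idx[0][0]])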
```
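The `NearestNeighbors` lookup above is an exact Euclidean search, so for small `N` it can be reproduced without scikit-learn. The brute-force helper below is a sketch with toy vectors; the same helper could also retrieve the training notes nearest to any query vector, for example a new note's representation, assuming that vector lies in the same space as `train_embeds`.

```python
import heapq
import math


def k_nearest(query, embeds, k=3):
    """Return (distance, index) pairs for the k rows of `embeds` closest to `query`."""
    dists = ((math.dist(query, row), i) for i, row in enumerate(embeds))
    # nsmallest orders by distance, breaking ties by row index
    return heapq.nsmallest(k, dists)


embeds = [[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]]
query = [0.9, 0.1]
result = k_nearest(query, embeds, k=2)
print([i for _, i in result])  # [2, 0]
```

With real data, `k_nearest(u_c.tolist(), train_embeds.tolist(), k=3)` would return the three training notes closest to prototype `u_c`, whose texts are `train_texts[i]`.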
---
# Intended Use, Limitations & Ethical Considerations
## Intended Use
- **Research & Education:**
ProtoPatient is designed primarily for academic research and educational purposes in clinical NLP.
- **Interpretability Demonstration:**
The model demonstrates how prototype-based methods can provide interpretable multi-label classification on clinical admission notes.
---
## Limitations
- **Generalization:**
The model was trained on ICU notes from MIMIC-III (with transfer evaluated on i2b2) and may not generalize to other patient populations.
- **Prototype Scope:**
The current version learns a single prototype per diagnosis, although some diagnoses may have several typical presentations; supporting multiple prototypes is an area for future improvement.
- **Inter-diagnosis Relationships:**
The model does not explicitly model relationships (e.g., conflicts or comorbidities) between different diagnoses.
---
## Ethical & Regulatory Considerations
- **Not for Direct Clinical Use:**
This model is not intended for direct clinical decision-making. Always consult healthcare professionals.
- **Bias and Fairness:**
Users should be aware of potential biases in the training data; rare conditions might still be misclassified.
- **Patient Privacy:**
When applying the model to real clinical data, patient privacy must be strictly maintained.
---
# Example Interpretability Output
Based on the approach described in the paper (see Section 5 and Table 5):
- **Highlighted Tokens:**
Highlighted phrases such as "worst headache of her life," "vomiting," "fever," and "infiltrate" strongly indicate specific diagnoses.
- **Prototypical Sample:**
A snippet from a training patient with similar text segments provides a rationale for the prediction.
*This interpretability output aids clinicians in understanding the model's reasoning – for example: "The system suggests intracerebral hemorrhage because the patient's note closely resembles typical cases with that diagnosis."*
---
# Recommended Citation
If you use ProtoPatient in your research, please cite:
```bibtex
@misc{vanaken2022this,
title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
year={2022},
eprint={2210.08500},
archivePrefix={arXiv},
primaryClass={cs.CL}
}