Update README.md
README.md
@@ -27,43 +27,131 @@ model-index:
Previous version of the section (lines 27–69):

````diff
 ---
 
 # ProtoPatient Model for Multi-Label Classification
 
-
 ProtoPatient/
-
 │ ├── proto.py
 │ ├── utils.py
 │ ├── metrics.py
 │ ├── __init__.py
-
-### **1. Install Dependencies**
-Ensure you have `transformers` and `torch` installed:
-```bash
 pip install transformers torch
-
-model = AutoModel.from_pretrained("row56/ProtoPatient")
-print("✅ Model with weights loaded successfully!")
-```
````
New version (lines 27–156):

---

# ProtoPatient Model for Multi-Label Classification

## Paper Reference

van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022. "This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text." (arXiv:2210.08500)

ProtoPatient is a transformer-based architecture that uses prototypical networks and label-wise attention for multi-label classification of clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability:

- It highlights the most relevant tokens for each possible diagnosis.
- It retrieves prototypical patients from the training set who exhibit similar textual patterns, providing intuitive justifications to clinicians: "This patient looks like that patient."
## Model Overview

### Prototype-Based Classification

#### Prototypical Vectors

- The model learns a prototypical vector `u_c` for each diagnosis `c`.

#### Diagnosis-Specific Representation

- A patient's admission note is mapped (via a PubMedBERT encoder plus a linear compression layer) to a diagnosis-specific representation, generated by a label-wise attention mechanism.

#### Classification Scores

- Classification scores are computed as the negative Euclidean distance between the diagnosis-specific representation and the prototype vector `u_c`, yielding a direct measure of "this note's similarity to the learned prototype."

### Label-Wise Attention

- For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
- This yields interpretability: the most attended-to tokens are presumably the evidence driving each diagnosis prediction.
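The scoring mechanics above can be sketched in a few lines of PyTorch. This is a minimal illustration with made-up dimensions and random tensors, not the repository's actual implementation (the real code lives in `proto_model/proto.py`):

```python
import torch

# Toy dimensions, chosen only for illustration
num_labels, hidden, seq_len = 4, 8, 10

token_embeddings = torch.randn(seq_len, hidden)      # encoder output for one note
attention_vectors = torch.randn(num_labels, hidden)  # one attention vector per diagnosis
prototypes = torch.randn(num_labels, hidden)         # learned prototype u_c per diagnosis

# Label-wise attention: each diagnosis attends to its own set of tokens
attn_weights = torch.softmax(attention_vectors @ token_embeddings.T, dim=-1)  # (labels, seq)

# Diagnosis-specific representation: attention-weighted average of token embeddings
label_repr = attn_weights @ token_embeddings  # (labels, hidden)

# Classification score: negative Euclidean distance to each label's prototype
scores = -torch.norm(label_repr - prototypes, dim=-1)  # (labels,), always <= 0
```

Higher (less negative) scores mean the note sits closer to that diagnosis's prototype, and the rows of `attn_weights` show which tokens drove each score.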
### Interpretable Output

#### Token Highlights

- The top attended words, which often correlate with symptoms, risk factors, or diagnostic descriptors.

#### Prototypical Patients

- The training examples closest to each prototype, exemplifying typical presentations of a diagnosis.
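As a toy illustration of the token-highlight idea (the tokens and attention weights below are hypothetical, not model output):

```python
import torch

# Hypothetical tokens and per-token attention scores for one diagnosis
tokens = ["worst", "headache", "of", "her", "life", "with", "vomiting"]
attn = torch.softmax(torch.tensor([2.5, 3.0, 0.1, 0.1, 0.2, 0.3, 2.0]), dim=-1)

# Keep the 3 most-attended tokens, in reading order
top = torch.topk(attn, k=3).indices.tolist()
highlighted = [tokens[i] for i in sorted(top)]
print(highlighted)  # ['worst', 'headache', 'vomiting']
```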
## Key Features and Benefits

#### Improved Performance on Rare Diagnoses

- ProtoPatient leverages prototype-based learning, which has shown strong few-shot behavior that is especially beneficial for diagnoses with very few samples.

#### Faithful Interpretations

- A quantitative study (see paper, Section 5) shows that ProtoPatient's attention-based highlights are more faithful to the model's true decision process than post-hoc explainers such as LIME, occlusion, or gradient-based methods.

#### Clinical Utility

- Offers label-wise explanations to help clinicians quickly assess whether the system's reasoning aligns with actual risk factors.
- Points out prototypical patients, allowing doctors to compare and contrast new admissions with typical (or atypical) presentations.
## Performance Metrics

Evaluated on MIMIC-III (48,745 admission notes, 1,266 diagnosis labels):

- Macro ROC AUC: ~87–88%
- Micro ROC AUC: ~97%
- Macro PR AUC: ~18–21%

Performance gains are particularly strong for rare diagnoses (fewer than 50 samples) compared to baselines such as PubMedBERT alone or hierarchical attention RNNs (HAN, HA-GRU).

The model was additionally tested on i2b2 data (1,118 admission notes), showing high transferability across different clinical environments.

(Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.)
## Repository Structure

```
ProtoPatient/
├── proto_model/
│   ├── proto.py
│   ├── utils.py
│   ├── metrics.py
│   └── __init__.py
├── config.json
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
```
## How to Use the Model

### Install Dependencies

```bash
pip install transformers torch
```

Optionally, install `safetensors` if you want to load the `.safetensors` weights file.
### Load the Model via Hugging Face

```python
from transformers import AutoTokenizer, AutoModel

repo_id = "row56/ProtoPatient"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id)
model.eval()

sample_text = "This patient presents with severe headaches and nausea..."
inputs = tokenizer(sample_text, return_tensors="pt")
outputs = model(**inputs)
print("Output shape:", outputs.last_hidden_state.shape)
```
### Interpreting Outputs

For the full prototypical classification approach, you would generally use the custom modules in `proto_model/` (e.g., `ProtoForMultiLabelClassification`) to check which tokens are most attended per label and which "prototype patients" are most similar.

If you use the standard `AutoModel`, you can still get the raw embeddings, but you will need the custom code for label-wise attention and prototype retrieval.
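Prototype retrieval ("this patient looks like that patient") can be sketched as a nearest-neighbor search over stored training representations. The tensors below are random stand-ins for the diagnosis-specific representations the custom `proto_model/` code would produce:

```python
import torch

# Random stand-ins: diagnosis-specific representations of 100 training patients
train_reprs = torch.randn(100, 8)
new_repr = torch.randn(8)  # representation of the new admission note

# Smallest Euclidean distance = most prototypical match
dists = torch.norm(train_reprs - new_repr, dim=-1)
nearest = torch.topk(-dists, k=3).indices.tolist()  # 3 most similar training patients
```

The notes of the retrieved training patients then serve as the human-readable justification for the prediction.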
### (Optional) Hugging Face Pipelines

You can integrate the model into a pipeline (e.g., feature extraction) to simplify usage:

```python
from transformers import pipeline

repo_id = "row56/ProtoPatient"
extractor = pipeline("feature-extraction", model=repo_id, tokenizer=repo_id)
embeddings = extractor("Severe headaches and vomiting...")
print(len(embeddings), len(embeddings[0]))  # token-level features
```
## Intended Use, Limitations & Ethical Considerations

### Intended Use

- ProtoPatient is intended primarily for research and education in clinical NLP.
- It demonstrates how to leverage prototype-based interpretability for multi-label classification on admission notes.

### Limitations

- The model was trained on public ICU datasets (MIMIC-III, i2b2) and may not generalize to other patient populations.
- The currently released version learns only one prototype per diagnosis; some diagnoses have multiple typical presentations, which is an area for future research.
- It does not explicitly model inter-diagnosis relationships (e.g., conflicts or comorbidities).

### Ethical & Regulatory

- This model is not intended for direct clinical use. Always consult healthcare professionals for medical decisions.
- Users must be aware of potential biases in the training data; rare conditions can still be misclassified despite the improvements.
- Patient privacy must be strictly maintained when applying the model to real hospital data.
## Example Interpretability Output

Based on the approach in the paper (Section 5 and Table 5):

- Highlighted tokens: terms that strongly indicate a certain diagnosis (e.g., "worst headache of her life," "vomiting," "fever," "infiltrate").
- Prototypical sample: a snippet from a training patient with very similar text segments (e.g., describing similar symptoms, risk factors, or diagnoses).

This gives clinicians a rationale of the form: "The system thinks your patient has intracerebral hemorrhage because they exhibit text segments similar to a previous patient who had that diagnosis."
## Recommended Citation

If you use ProtoPatient in your research, please cite:

```bibtex
@misc{vanaken2022this,
  title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
  author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
  year={2022},
  eprint={2210.08500},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```