# ProtoPatient Model for Multi-Label Classification
## Paper Reference
**van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.**
*This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.*
[arXiv:2210.08500](https://arxiv.org/abs/2210.08500)
ProtoPatient is a transformer-based architecture that uses prototypical networks and label-wise attention for multi-label classification on clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:

- **Highlighting relevant tokens:** it shows the most important words for each possible diagnosis.
- **Retrieving prototypical patients:** it finds training examples with similar textual patterns, providing intuitive justifications for clinicians: "This patient looks like that patient."
## Model Overview
### Prototype-Based Classification
- The model learns **prototypical vectors** \(u_c\), one for each diagnosis \(c\).
- A patient's admission note is encoded via a PubMedBERT encoder and a linear compression layer into a diagnosis-specific representation \(v_{p,c}\), produced by a label-wise attention mechanism.
- Classification scores are computed as the **negative Euclidean distance** between \(v_{p,c}\) and \(u_c\), a direct measure of the note's similarity to the learned prototype.
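As a minimal sketch of the scoring step above (plain Python, with made-up two-dimensional vectors standing in for the learned representations and prototypes):

```python
import math

def prototype_score(v_pc, u_c):
    # Negative Euclidean distance between a diagnosis-specific note
    # representation v_pc and the learned prototype vector u_c:
    # notes closer to the prototype get a higher (less negative) score.
    return -math.sqrt(sum((a - b) ** 2 for a, b in zip(v_pc, u_c)))

print(prototype_score([1.0, 2.0], [4.0, 6.0]))  # -5.0
```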
### Label-Wise Attention
- For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
- This mechanism provides interpretability by indicating which tokens are most influential in driving each prediction.
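A toy sketch of one label's attention step (pure Python with invented two-dimensional token vectors; the actual model applies learned query vectors over transformer token embeddings):

```python
import math

def label_attention(token_vecs, label_query):
    # Dot-product score of each token against this diagnosis's query vector.
    scores = [sum(t * q for t, q in zip(tok, label_query)) for tok in token_vecs]
    # Softmax over tokens gives per-token relevance weights.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    weights = [e / sum(exps) for e in exps]
    # Attention-weighted sum = diagnosis-specific note representation.
    dim = len(token_vecs[0])
    rep = [sum(w * tok[d] for w, tok in zip(weights, token_vecs)) for d in range(dim)]
    return weights, rep

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, rep = label_attention(tokens, [2.0, 0.0])
# Tokens aligned with the label's query receive the highest weights.
```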
### Interpretable Output
- **Token Highlights:** the top attended words (often correlating with symptoms, risk factors, or diagnostic descriptors).
- **Prototypical Patients:** examples from the training set that are closest to each prototype, representing typical presentations of a diagnosis.
## Key Features and Benefits
- **Improved Performance on Rare Diagnoses:** Prototype-based learning has strong few-shot capabilities, which is especially beneficial for diagnoses with very few training samples.
- **Faithful Interpretations:** Quantitative evaluations (see Section 5 of the paper) indicate that the attention-based highlights are more faithful to the model's decision process than post-hoc methods such as LIME, occlusion, and gradient-based approaches.
- **Clinical Utility:** Label-wise explanations help clinicians check whether predictions align with actual risk factors, and retrieved prototypical patients allow new cases to be compared with typical (or atypical) presentations.
## Performance Metrics
Evaluated on MIMIC-III (48,745 admission notes, 1,266 diagnosis labels):
Performance (approximate):

- **Macro ROC AUC:** ~87–88%
- **Micro ROC AUC:** ~97%
- **Macro PR AUC:** ~18–21%

The model shows particularly strong gains for rare diagnoses (fewer than 50 training samples) compared with baselines such as PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU).

The model also transfers well across clinical environments, as shown on **i2b2** data (1,118 admission notes).

*Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.*
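To see why micro and macro scores can differ so sharply on a long-tailed label set, here is a toy illustration with invented per-label AUCs and supports. (True micro AUC pools all label-instance decisions rather than averaging per-label scores, but a frequency-weighted mean shows the same effect.)

```python
# Hypothetical per-label ROC AUCs and label frequencies (supports).
aucs = {"common_dx": 0.98, "mid_dx": 0.90, "rare_dx": 0.60}
support = {"common_dx": 10000, "mid_dx": 500, "rare_dx": 20}

# Macro average: every label counts equally, so rare labels drag it down.
macro = sum(aucs.values()) / len(aucs)

# Frequency-weighted average: dominated by common labels, as micro scores are.
weighted = sum(aucs[l] * support[l] for l in aucs) / sum(support.values())

print(round(macro, 3))     # 0.827
print(round(weighted, 3))  # 0.975
```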
## Repository Structure
```
ProtoPatient/
├── proto_model/
│   ├── proto.py
│   ├── utils.py
│   ├── metrics.py
│   └── __init__.py
├── config.json
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
```
## How to Use the Model
### Install Dependencies
```bash
pip install transformers torch
```

Optionally, install `safetensors` if you want to load the `.safetensors` weights file.
### Load the Model via Hugging Face
```python
from transformers import AutoTokenizer, AutoModel

repo_id = "row56/ProtoPatient"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id)
model.eval()

sample_text = "This patient presents with severe headaches and nausea..."
inputs = tokenizer(sample_text, return_tensors="pt")
outputs = model(**inputs)
print("Output shape:", outputs.last_hidden_state.shape)
```
### Interpreting Outputs
For the full prototypical classification approach, use the custom modules in `proto_model/` (e.g., `ProtoForMultiLabelClassification`) to check which tokens are highly attended per label and which prototypical patients are most similar.

If you use the standard `AutoModel`, you still get the raw embeddings, but you will need the custom code for label-wise attention and prototype retrieval.
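As a sketch of what "highly attended tokens" means in practice, assuming you have already extracted per-token attention weights for one diagnosis (both the token list and the weights below are made up for illustration):

```python
# Hypothetical tokens and label-wise attention weights for one diagnosis.
tokens = ["severe", "headaches", "and", "nausea", "since", "yesterday"]
weights = [0.30, 0.35, 0.02, 0.25, 0.04, 0.04]

# Rank tokens by attention weight and keep the top three as highlights.
top = sorted(zip(tokens, weights), key=lambda tw: tw[1], reverse=True)[:3]
print([t for t, _ in top])  # ['headaches', 'severe', 'nausea']
```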
### (Optional) Hugging Face Pipelines
You can integrate the model into a pipeline (e.g., feature extraction) to simplify usage:

```python
from transformers import pipeline

repo_id = "row56/ProtoPatient"
extractor = pipeline("feature-extraction", model=repo_id, tokenizer=repo_id)
embeddings = extractor("Severe headaches and vomiting...")
print(len(embeddings), len(embeddings[0]))  # token-level features
```
## Intended Use, Limitations & Ethical Considerations
### Intended Use:

- ProtoPatient is primarily for research and education in clinical NLP.
- It demonstrates how to leverage prototype-based interpretability for multi-label classification on admission notes.

### Limitations:

- The model was trained on public ICU datasets (MIMIC-III, i2b2) and may not generalize to other patient populations.
- The currently released version considers only one prototype per diagnosis; some diagnoses have multiple typical presentations, which is an area for future research.
- It does not explicitly model inter-diagnosis relationships (e.g., conflicts or comorbidities).

### Ethical & Regulatory:

- This model is not intended for direct clinical use. Always consult healthcare professionals for medical decisions.
- Users must be aware of potential biases in the training data; rare conditions can still be misclassified despite the improvements.
- Patient privacy must be strictly maintained when applying the model to real hospital data.
## Example Interpretability Output
Based on the approach in the paper (Section 5 and Table 5):

- **Highlighted tokens:** terms that strongly indicate a certain diagnosis (e.g., "worst headache of her life," "vomiting," "fever," "infiltrate").
- **Prototypical sample:** a snippet from a training patient with very similar text segments (e.g., describing similar symptoms, risk factors, or diagnoses).

This provides clinicians with rationales: "The system thinks your patient has intracerebral hemorrhage because they exhibit text segments similar to a previous patient who had that diagnosis."
## Recommended Citation
If you use ProtoPatient in your research, please cite:

```bibtex
@misc{vanaken2022this,
  title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
  author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
  year={2022},
  eprint={2210.08500},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```