Upload sproto model
Browse files- LICENSE +105 -0
- README.md +256 -0
- config.json +33 -0
- configuration_sproto.py +61 -0
- model.safetensors +3 -0
- modeling_sproto.py +77 -0
- overview.png +0 -0
LICENSE
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity.
|
| 18 |
+
|
| 19 |
+
"You" shall mean an individual or Legal Entity exercising permissions
|
| 20 |
+
granted by this License.
|
| 21 |
+
|
| 22 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 23 |
+
including but not limited to software source code, documentation
|
| 24 |
+
source, and configuration files.
|
| 25 |
+
|
| 26 |
+
"Object" form shall mean any form resulting from mechanical
|
| 27 |
+
transformation or translation of a Source form, including but
|
| 28 |
+
not limited to compiled object code, generated documentation,
|
| 29 |
+
and conversions to other media types.
|
| 30 |
+
|
| 31 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 32 |
+
Object form, made available under the License, as indicated by a
|
| 33 |
+
copyright notice that is included in or attached to the work
|
| 34 |
+
(an example is provided in the Appendix below).
|
| 35 |
+
|
| 36 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 37 |
+
form, that is based on or derived from the Work and for which the
|
| 38 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 39 |
+
represent, as a whole, an original work of authorship.
|
| 40 |
+
|
| 41 |
+
"Contribution" shall mean any work of authorship, including
|
| 42 |
+
the original version of the Work and any modifications or additions
|
| 43 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 44 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 45 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 46 |
+
the copyright owner.
|
| 47 |
+
|
| 48 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 49 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 50 |
+
subsequently incorporated within the Work.
|
| 51 |
+
|
| 52 |
+
2. Grant of Copyright License.
|
| 53 |
+
|
| 54 |
+
Subject to the terms and conditions of this License, each Contributor
|
| 55 |
+
hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
|
| 56 |
+
royalty-free, irrevocable copyright license to reproduce, prepare
|
| 57 |
+
Derivative Works of, publicly display, publicly perform, sublicense,
|
| 58 |
+
and distribute the Work and such Derivative Works in Source or Object
|
| 59 |
+
form.
|
| 60 |
+
|
| 61 |
+
3. Grant of Patent License.
|
| 62 |
+
|
| 63 |
+
Subject to the terms and conditions of this License, each Contributor
|
| 64 |
+
hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
|
| 65 |
+
royalty-free, irrevocable patent license to make, have made, use,
|
| 66 |
+
offer to sell, sell, import, and otherwise transfer the Work.
|
| 67 |
+
|
| 68 |
+
4. Redistribution.
|
| 69 |
+
|
| 70 |
+
You may reproduce and distribute copies of the Work or Derivative
|
| 71 |
+
Works thereof in any medium, with or without modifications, and in
|
| 72 |
+
Source or Object form, provided that You meet the following conditions:
|
| 73 |
+
|
| 74 |
+
You must give any other recipients of the Work a copy of this License
|
| 75 |
+
and you must cause any modified files to carry prominent notices
|
| 76 |
+
stating that You changed the files.
|
| 77 |
+
|
| 78 |
+
5. Submission of Contributions.
|
| 79 |
+
|
| 80 |
+
Unless You explicitly state otherwise, any Contribution intentionally
|
| 81 |
+
submitted for inclusion in the Work shall be under the terms of this
|
| 82 |
+
License.
|
| 83 |
+
|
| 84 |
+
6. Trademarks.
|
| 85 |
+
|
| 86 |
+
This License does not grant permission to use the trade names,
|
| 87 |
+
trademarks, service marks, or product names of the Licensor.
|
| 88 |
+
|
| 89 |
+
7. Disclaimer of Warranty.
|
| 90 |
+
|
| 91 |
+
The Work is provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
| 92 |
+
CONDITIONS OF ANY KIND.
|
| 93 |
+
|
| 94 |
+
8. Limitation of Liability.
|
| 95 |
+
|
| 96 |
+
In no event and under no legal theory shall any Contributor be liable
|
| 97 |
+
for damages arising from the use of the Work.
|
| 98 |
+
|
| 99 |
+
9. Accepting Warranty or Additional Liability.
|
| 100 |
+
|
| 101 |
+
While redistributing the Work, You may choose to offer support or
|
| 102 |
+
warranty obligations, but You may not impose such obligations on
|
| 103 |
+
Contributors.
|
| 104 |
+
|
| 105 |
+
END OF TERMS AND CONDITIONS
|
README.md
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
library_name: transformers
|
| 5 |
+
|
| 6 |
+
pipeline_tag: text-classification
|
| 7 |
+
task_categories:
|
| 8 |
+
- text-classification
|
| 9 |
+
|
| 10 |
+
model_type: sproto
|
| 11 |
+
base_model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
|
| 12 |
+
|
| 13 |
+
datasets:
|
| 14 |
+
- mimic-iv
|
| 15 |
+
|
| 16 |
+
metrics:
|
| 17 |
+
- auroc
|
| 18 |
+
- pr-auc
|
| 19 |
+
|
| 20 |
+
tags:
|
| 21 |
+
- text-classification
|
| 22 |
+
- multi-label-classification
|
| 23 |
+
- long-tail-learning
|
| 24 |
+
- medical
|
| 25 |
+
- clinical-nlp
|
| 26 |
+
- interpretability
|
| 27 |
+
- prototypical-networks
|
| 28 |
+
- ehr
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
# S-Proto: Sparse Prototypical Networks for Long-Tail Clinical Diagnosis Prediction
|
| 32 |
+
|
| 33 |
+

|
| 34 |
+
|
| 35 |
+
This repository provides **S-Proto**, a sparse and interpretable prototypical network for extreme multi-label diagnosis prediction from clinical text. The model is designed to address the long-tail distribution of clinical diagnoses while preserving faithful, prototype-based explanations.
|
| 36 |
+
|
| 37 |
+
## Interactive Demo
|
| 38 |
+
|
| 39 |
+
You can explore the model's predictions and interpretability features through our interactive web demo:
|
| 40 |
+
**[https://s-proto.demo.datexis.com/](https://s-proto.demo.datexis.com/)**
|
| 41 |
+
|
| 42 |
+
S-Proto was introduced in the paper:
|
| 43 |
+
|
| 44 |
+
**[Boosting Long-Tail Data Classification with Sparse Prototypical Networks](https://ecmlpkdd-storage.s3.eu-central-1.amazonaws.com/preprints/2024/lncs14947/lncs14947435.pdf)**
|
| 45 |
+
Alexei Figueroa*, Jens-Michalis Papaioannou*, et al.
|
| 46 |
+
DATEXIS, Berliner Hochschule für Technik, Feinstein Institutes, TU Munich, Leibniz University Hannover
|
| 47 |
+
(* equal contribution)
|
| 48 |
+
|
| 49 |
+
## Overview
|
| 50 |
+
|
| 51 |
+
Clinical outcome prediction from Electronic Health Records is characterized by extreme label imbalance. A small number of diagnoses account for most patients, while the majority of diagnoses appear rarely. Standard transformer classifiers tend to perform well on frequent diagnoses but degrade sharply in the long tail.
|
| 52 |
+
|
| 53 |
+
S-Proto addresses this problem by extending prototypical networks with:
|
| 54 |
+
|
| 55 |
+
- Multiple prototypes per diagnosis
|
| 56 |
+
- Sparse winner-takes-all activation
|
| 57 |
+
- Prototype-level interpretability
|
| 58 |
+
- Efficient training despite increased representational capacity
|
| 59 |
+
|
| 60 |
+
The model achieves state-of-the-art performance on MIMIC-IV diagnosis prediction, with particularly strong gains in PR-AUC for rare diagnoses, and transfers successfully to unseen clinical datasets.
|
| 61 |
+
|
| 62 |
+
## Model Architecture
|
| 63 |
+
|
| 64 |
+
S-Proto builds on **PubMedBERT** as the text encoder and introduces a sparse prototypical layer on top.
|
| 65 |
+
|
| 66 |
+
For each diagnosis label, the model learns multiple sub-networks, each consisting of:
|
| 67 |
+
|
| 68 |
+
- A label-specific attention vector
|
| 69 |
+
- A prototype vector representing a prototypical patient
|
| 70 |
+
|
| 71 |
+
Given an input clinical note:
|
| 72 |
+
|
| 73 |
+
1. The note is encoded using PubMedBERT
|
| 74 |
+
2. Token embeddings are projected into a latent space
|
| 75 |
+
3. Each diagnosis activates multiple candidate sub-networks
|
| 76 |
+
4. A winner-takes-all mechanism selects the single most relevant sub-network per diagnosis
|
| 77 |
+
5. Only the winning prototype contributes to the prediction and receives gradient updates
|
| 78 |
+
|
| 79 |
+
This allows S-Proto to model heterogeneous disease phenotypes while remaining sparse and efficient.
|
| 80 |
+
|
| 81 |
+
## Intended Use
|
| 82 |
+
|
| 83 |
+
This model is intended for:
|
| 84 |
+
|
| 85 |
+
- Clinical diagnosis prediction from admission notes
|
| 86 |
+
- Research on long-tail learning in healthcare NLP
|
| 87 |
+
- Interpretable clinical decision support systems
|
| 88 |
+
- Analysis of disease phenotypes via learned prototypes
|
| 89 |
+
|
| 90 |
+
This model is **not intended for direct clinical deployment** without external validation, auditing, and regulatory approval.
|
| 91 |
+
|
| 92 |
+
## Inference Example
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
from transformers import AutoTokenizer, AutoModel
|
| 96 |
+
import torch
|
| 97 |
+
|
| 98 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 99 |
+
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
|
| 100 |
+
)
|
| 101 |
+
model = AutoModel.from_pretrained(
|
| 102 |
+
"datexis/sproto",
|
| 103 |
+
trust_remote_code=True
|
| 104 |
+
)
|
| 105 |
+
model.eval()
|
| 106 |
+
|
| 107 |
+
text_input = [
|
| 108 |
+
"CHIEF COMPLAINT: Right Carotid Artery Stenosis. "
|
| 109 |
+
"PRESENT ILLNESS: Ms. ___ is a ___ year old woman with hyperlipidemia, "
|
| 110 |
+
"cirrhosis with esophageal varices, alcoholism, COPD, left eye blindness, "
|
| 111 |
+
"and right carotid stenosis status post right carotid endarterectomy."
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
inputs = tokenizer(
|
| 115 |
+
text_input,
|
| 116 |
+
padding=True,
|
| 117 |
+
truncation=True,
|
| 118 |
+
max_length=512,
|
| 119 |
+
return_tensors="pt"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in inputs["input_ids"]]
|
| 123 |
+
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
output = model(
|
| 126 |
+
input_ids=inputs["input_ids"],
|
| 127 |
+
attention_mask=inputs["attention_mask"],
|
| 128 |
+
token_type_ids=inputs.get("token_type_ids"),
|
| 129 |
+
tokens=tokens
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
logits = output["logits"]
|
| 133 |
+
max_indices = output["max_indices"]
|
| 134 |
+
metadata = output["metadata"]
|
| 135 |
+
|
| 136 |
+
print("Inference successful")
|
| 137 |
+
print("Logits shape:", logits.shape)
|
| 138 |
+
print("Max indices:", max_indices)
|
| 139 |
+
print("Metadata:", metadata)
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
## Outputs
|
| 143 |
+
|
| 144 |
+
The model returns a dictionary with the following entries:
|
| 145 |
+
|
| 146 |
+
- **logits**
|
| 147 |
+
Prediction scores per diagnosis label.
|
| 148 |
+
|
| 149 |
+
- **max_indices**
|
| 150 |
+
Index of the winning prototype sub-network per diagnosis, corresponding to the selected prototype.
|
| 151 |
+
|
| 152 |
+
- **metadata**
|
| 153 |
+
Additional information useful for analysis and interpretability.
|
| 154 |
+
|
| 155 |
+
## Explainability
|
| 156 |
+
|
| 157 |
+
S-Proto provides built-in faithful explanations through its prototypical structure:
|
| 158 |
+
|
| 159 |
+
- Attention vectors highlight clinically relevant tokens
|
| 160 |
+
- Prototype distances reflect similarity to prototypical patients
|
| 161 |
+
- Multiple prototypes per diagnosis capture disease subtypes and cohorts
|
| 162 |
+
- Faithfulness metrics remain comparable to ProtoPatient despite higher capacity
|
| 163 |
+
|
| 164 |
+
Qualitative evaluation with medical professionals confirms that learned prototypes often correspond to clinically meaningful phenotypes.
|
| 165 |
+
|
| 166 |
+
## Training
|
| 167 |
+
|
| 168 |
+
First, clone the repository:
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
git clone https://github.com/DATEXIS/sproto.git
|
| 172 |
+
cd sproto
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
Set up the environment using Poetry:
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
poetry install
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
Activate the virtual environment:
|
| 182 |
+
|
| 183 |
+
```bash
|
| 184 |
+
poetry env activate
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
Once the environment is active, you can start training by running the train command with the desired arguments.
|
| 188 |
+
|
| 189 |
+
Example:
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
train \
|
| 193 |
+
--batch_size 3 \
|
| 194 |
+
--pretrained_model microsoft/biomednlp-pubmedbert-base-uncased-abstract-fulltext \
|
| 195 |
+
--pretrained_model_path path_to_pretrained_model.ckpt \
|
| 196 |
+
--model_type MULTI_PROTO \
|
| 197 |
+
--train_file training_data.csv \
|
| 198 |
+
--val_file validation_data.csv \
|
| 199 |
+
--test_file test_data.csv \
|
| 200 |
+
--save_dir ../experiments/ \
|
| 201 |
+
--gpus 1 \
|
| 202 |
+
--check_val_every_n_epoch 2 \
|
| 203 |
+
--num_warmup_steps 0 \
|
| 204 |
+
--num_training_steps 50 \
|
| 205 |
+
--max_length 512 \
|
| 206 |
+
--lr_features 0.000005 \
|
| 207 |
+
--lr_prototypes 0.001 \
|
| 208 |
+
--lr_others 0.001 \
|
| 209 |
+
--num_val_samples None \
|
| 210 |
+
--use_attention True \
|
| 211 |
+
--reduce_hidden_size 256 \
|
| 212 |
+
--all_labels_path all_labels.pcl \
|
| 213 |
+
--seed 42 \
|
| 214 |
+
--label_column labels \
|
| 215 |
+
--metric_opt auroc_macro \
|
| 216 |
+
--train_files [] \
|
| 217 |
+
--val_files [] \
|
| 218 |
+
--only_test True \
|
| 219 |
+
--model_name 5p \
|
| 220 |
+
--store_metadata False \
|
| 221 |
+
--num_prototypes_per_class 5
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
## Citation
|
| 225 |
+
|
| 226 |
+
```bibtex
|
| 227 |
+
@inproceedings{figueroa2024sproto,
|
| 228 |
+
title={Boosting Long-Tail Data Classification with Sparse Prototypical Networks},
|
| 229 |
+
author={Figueroa, Alexei and Papaioannou, Jens-Michalis and Fallon, Conor and Bekiaridou, Alexandra and Bressem, Keno and Zanos, Stavros and Gers, Felix and Nejdl, Wolfgang and Löser, Alexander},
|
| 230 |
+
booktitle={Proceedings of the Conference on Empirical Methods in Natural Language Processing},
|
| 231 |
+
year={2024}
|
| 232 |
+
}
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
## License
|
| 236 |
+
|
| 237 |
+
This model and its associated code are released under the Apache License 2.0.
|
| 238 |
+
|
| 239 |
+
The model was trained on the MIMIC-IV dataset, which is subject to restricted access. No training data is included or redistributed with this repository.
|
| 240 |
+
The data were accessed under a data use agreement. No patient-identifiable information is shared.
|
| 241 |
+
|
| 242 |
+
Use of this model must comply with all applicable data governance and ethical guidelines.
|
| 243 |
+
|
| 244 |
+
### Limitations
|
| 245 |
+
|
| 246 |
+
- Extremely rare diagnoses remain challenging
|
| 247 |
+
- Clinical dataset biases may be reflected in predictions
|
| 248 |
+
- Winner-takes-all selection is fixed and not learned dynamically
|
| 249 |
+
- Not validated for real-world clinical deployment
|
| 250 |
+
|
| 251 |
+
### Ethical Considerations
|
| 252 |
+
|
| 253 |
+
- The model processes sensitive clinical text
|
| 254 |
+
- Predictions should always be reviewed by qualified professionals
|
| 255 |
+
- Outputs should not be used as sole evidence for clinical decisions
|
| 256 |
+
- Care must be taken to avoid reinforcing existing healthcare biases
|
config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_vector_path": null,
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoConfig": "configuration_sproto.SprotoConfig",
|
| 5 |
+
"AutoModel": "modeling_sproto.SprotoModel"
|
| 6 |
+
},
|
| 7 |
+
"batch_size": 21,
|
| 8 |
+
"dot_product": false,
|
| 9 |
+
"eval_buckets": null,
|
| 10 |
+
"final_layer": false,
|
| 11 |
+
"label_order_path": "/pvc/shared/continual/data/icd_10_all_labels_admission_mimiciv_dia.pcl",
|
| 12 |
+
"loss": "BCE",
|
| 13 |
+
"lr_features": 5e-06,
|
| 14 |
+
"lr_others": 0.001,
|
| 15 |
+
"lr_prototypes": 0.001,
|
| 16 |
+
"model_type": "sproto",
|
| 17 |
+
"normalize": null,
|
| 18 |
+
"num_classes": 1643,
|
| 19 |
+
"num_prototypes_per_class": 5,
|
| 20 |
+
"num_training_steps": 5000,
|
| 21 |
+
"num_warmup_steps": 5000,
|
| 22 |
+
"pretrained_model": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
|
| 23 |
+
"prototype_vector_path": null,
|
| 24 |
+
"reduce_hidden_size": 256,
|
| 25 |
+
"save_dir": "/pvc/shared/continual/experiments/mimiciv/icd10_clinical-continual-5p-test",
|
| 26 |
+
"seed": 28,
|
| 27 |
+
"transformers_version": "4.25.1",
|
| 28 |
+
"use_attention": true,
|
| 29 |
+
"use_cuda": true,
|
| 30 |
+
"use_global_attention": false,
|
| 31 |
+
"use_prototype_loss": false,
|
| 32 |
+
"use_sigmoid": false
|
| 33 |
+
}
|
configuration_sproto.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class SprotoConfig(PretrainedConfig):
|
| 4 |
+
model_type = "sproto"
|
| 5 |
+
|
| 6 |
+
def __init__(
|
| 7 |
+
self,
|
| 8 |
+
pretrained_model=None,
|
| 9 |
+
num_classes=None,
|
| 10 |
+
label_order_path=None,
|
| 11 |
+
use_sigmoid=False,
|
| 12 |
+
use_cuda=True,
|
| 13 |
+
lr_prototypes=5e-2,
|
| 14 |
+
lr_features=2e-6,
|
| 15 |
+
lr_others=2e-2,
|
| 16 |
+
num_training_steps=5000,
|
| 17 |
+
num_warmup_steps=1000,
|
| 18 |
+
loss="BCE",
|
| 19 |
+
save_dir="output",
|
| 20 |
+
use_attention=True,
|
| 21 |
+
use_global_attention=False,
|
| 22 |
+
dot_product=False,
|
| 23 |
+
normalize=None,
|
| 24 |
+
final_layer=False,
|
| 25 |
+
reduce_hidden_size=None,
|
| 26 |
+
use_prototype_loss=False,
|
| 27 |
+
prototype_vector_path=None,
|
| 28 |
+
attention_vector_path=None,
|
| 29 |
+
eval_buckets=None,
|
| 30 |
+
seed=7,
|
| 31 |
+
num_prototypes_per_class=1,
|
| 32 |
+
batch_size=10,
|
| 33 |
+
**kwargs,
|
| 34 |
+
):
|
| 35 |
+
super().__init__(**kwargs)
|
| 36 |
+
|
| 37 |
+
self.pretrained_model = pretrained_model
|
| 38 |
+
self.num_classes = num_classes
|
| 39 |
+
self.label_order_path = label_order_path
|
| 40 |
+
self.use_sigmoid = use_sigmoid
|
| 41 |
+
self.use_cuda = use_cuda
|
| 42 |
+
self.lr_prototypes = lr_prototypes
|
| 43 |
+
self.lr_features = lr_features
|
| 44 |
+
self.lr_others = lr_others
|
| 45 |
+
self.num_training_steps = num_training_steps
|
| 46 |
+
self.num_warmup_steps = num_warmup_steps
|
| 47 |
+
self.loss = loss
|
| 48 |
+
self.save_dir = save_dir
|
| 49 |
+
self.use_attention = use_attention
|
| 50 |
+
self.use_global_attention = use_global_attention
|
| 51 |
+
self.dot_product = dot_product
|
| 52 |
+
self.normalize = normalize
|
| 53 |
+
self.final_layer = final_layer
|
| 54 |
+
self.reduce_hidden_size = reduce_hidden_size
|
| 55 |
+
self.use_prototype_loss = use_prototype_loss
|
| 56 |
+
self.prototype_vector_path = prototype_vector_path
|
| 57 |
+
self.attention_vector_path = attention_vector_path
|
| 58 |
+
self.eval_buckets = eval_buckets
|
| 59 |
+
self.seed = seed
|
| 60 |
+
self.num_prototypes_per_class = num_prototypes_per_class
|
| 61 |
+
self.batch_size = batch_size
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ef1e86215368bfcbb723cb3a28d2c343927f154decc09527cce2093326a07fd2
|
| 3 |
+
size 455575332
|
modeling_sproto.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PreTrainedModel
|
| 2 |
+
from sproto.model.multi_proto import MultiProtoModule
|
| 3 |
+
from .configuration_sproto import SprotoConfig
|
| 4 |
+
|
| 5 |
+
class SprotoModel(PreTrainedModel):
|
| 6 |
+
config_class = SprotoConfig
|
| 7 |
+
base_model_prefix = "sproto"
|
| 8 |
+
|
| 9 |
+
def __init__(self, config: SprotoConfig):
|
| 10 |
+
super().__init__(config)
|
| 11 |
+
|
| 12 |
+
self.module = MultiProtoModule(
|
| 13 |
+
pretrained_model=config.pretrained_model,
|
| 14 |
+
num_classes=config.num_classes,
|
| 15 |
+
label_order_path=config.label_order_path,
|
| 16 |
+
use_sigmoid=config.use_sigmoid,
|
| 17 |
+
use_cuda=config.use_cuda,
|
| 18 |
+
lr_prototypes=config.lr_prototypes,
|
| 19 |
+
lr_features=config.lr_features,
|
| 20 |
+
lr_others=config.lr_others,
|
| 21 |
+
num_training_steps=config.num_training_steps,
|
| 22 |
+
num_warmup_steps=config.num_warmup_steps,
|
| 23 |
+
loss=config.loss,
|
| 24 |
+
save_dir=config.save_dir,
|
| 25 |
+
use_attention=config.use_attention,
|
| 26 |
+
use_global_attention=config.use_global_attention,
|
| 27 |
+
dot_product=config.dot_product,
|
| 28 |
+
normalize=config.normalize,
|
| 29 |
+
final_layer=config.final_layer,
|
| 30 |
+
reduce_hidden_size=config.reduce_hidden_size,
|
| 31 |
+
use_prototype_loss=config.use_prototype_loss,
|
| 32 |
+
prototype_vector_path=config.prototype_vector_path,
|
| 33 |
+
attention_vector_path=config.attention_vector_path,
|
| 34 |
+
eval_buckets=config.eval_buckets,
|
| 35 |
+
seed=config.seed,
|
| 36 |
+
num_prototypes_per_class=config.num_prototypes_per_class,
|
| 37 |
+
batch_size=config.batch_size,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Initialize weights and apply final processing
|
| 41 |
+
self.post_init()
|
| 42 |
+
|
| 43 |
+
def _init_weights(self, module):
|
| 44 |
+
"""Initialize the weights"""
|
| 45 |
+
if isinstance(module, (MultiProtoModule)):
|
| 46 |
+
# MultiProtoModule handles its own initialization or is loaded from checkpoint
|
| 47 |
+
return
|
| 48 |
+
# Add other initializations if standard layers are used directly in SprotoModel
|
| 49 |
+
pass
|
| 50 |
+
|
| 51 |
+
def forward(
|
| 52 |
+
self,
|
| 53 |
+
input_ids=None,
|
| 54 |
+
attention_mask=None,
|
| 55 |
+
token_type_ids=None,
|
| 56 |
+
targets=None,
|
| 57 |
+
tokens=None,
|
| 58 |
+
sample_ids=None,
|
| 59 |
+
**kwargs,
|
| 60 |
+
):
|
| 61 |
+
|
| 62 |
+
batch = {
|
| 63 |
+
"input_ids": input_ids,
|
| 64 |
+
"attention_masks": attention_mask,
|
| 65 |
+
"token_type_ids": token_type_ids,
|
| 66 |
+
"targets": targets,
|
| 67 |
+
"tokens": tokens,
|
| 68 |
+
"sample_ids": sample_ids,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
logits, max_indices, metadata = self.module(batch)
|
| 72 |
+
|
| 73 |
+
return {
|
| 74 |
+
"logits": logits,
|
| 75 |
+
"max_indices": max_indices,
|
| 76 |
+
"metadata": metadata,
|
| 77 |
+
}
|
overview.png
ADDED
|