RamezCh committed on
Commit
75108fa
·
verified ·
1 Parent(s): 9b26dd3

Upload sproto model

Files changed (4)
  1. README.md +255 -263
  2. special_tokens_map.json +7 -0
  3. tokenizer.json +0 -0
  4. tokenizer_config.json +16 -0
README.md CHANGED
@@ -1,264 +1,256 @@
1
- ---
2
- language: en
3
- license: apache-2.0
4
- library_name: transformers
5
-
6
- pipeline_tag: text-classification
7
- task_categories:
8
- - text-classification
9
-
10
- model_type: sproto
11
- base_model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
12
-
13
- datasets:
14
- - mimic-iv
15
-
16
- metrics:
17
- - auroc
18
- - pr-auc
19
-
20
- tags:
21
- - text-classification
22
- - multi-label-classification
23
- - long-tail-learning
24
- - medical
25
- - clinical-nlp
26
- - interpretability
27
- - prototypical-networks
28
- - ehr
29
- ---
30
-
31
- # S-Proto: Sparse Prototypical Networks for Long-Tail Clinical Diagnosis Prediction
32
-
33
- **Published at ECML PKDD 2024 (CORE A)**
34
- *Boosting Long-Tail Data Classification with Sparse Prototypical Networks*
35
-
36
- Alexei Figueroa*, Jens-Michalis Papaioannou*, et al.
37
- DATEXIS, Berliner Hochschule für Technik, Feinstein Institutes, TU Munich, Leibniz University Hannover
38
- (* equal contribution)
39
-
40
- ![S-Proto](overview.png)
41
-
42
- This repository provides **S-Proto**, a sparse and interpretable prototypical network for extreme multi-label diagnosis prediction from clinical text. The model is designed to address the long-tail distribution of clinical diagnoses while preserving faithful, prototype-based explanations.
43
-
44
- ## Interactive Demo
45
-
46
- You can explore the model's predictions and interpretability features through our interactive web demo:
47
- **[https://s-proto.demo.datexis.com/](https://s-proto.demo.datexis.com/)**
48
-
49
- S-Proto was introduced in the paper:
50
-
51
- **[Boosting Long-Tail Data Classification with Sparse Prototypical Networks](https://ecmlpkdd-storage.s3.eu-central-1.amazonaws.com/preprints/2024/lncs14947/lncs14947435.pdf)**
52
- European Conference on Machine Learning and Principles and Practice of Knowledge Discovery in Databases (ECML PKDD 2024, CORE A)
53
- Alexei Figueroa*, Jens-Michalis Papaioannou*, et al.
54
- DATEXIS, Berliner Hochschule für Technik, Feinstein Institutes, TU Munich, Leibniz University Hannover
55
- (* equal contribution)
56
-
57
- ## Overview
58
-
59
- Clinical outcome prediction from Electronic Health Records is characterized by extreme label imbalance. A small number of diagnoses account for most patients, while the majority of diagnoses appear rarely. Standard transformer classifiers tend to perform well on frequent diagnoses but degrade sharply in the long tail.
60
-
61
- S-Proto addresses this problem by extending prototypical networks with:
62
-
63
- - Multiple prototypes per diagnosis
64
- - Sparse winner-takes-all activation
65
- - Prototype-level interpretability
66
- - Efficient training despite increased representational capacity
67
-
68
- The model achieves state-of-the-art performance on MIMIC-IV diagnosis prediction, with particularly strong gains in PR-AUC for rare diagnoses, and transfers successfully to unseen clinical datasets.
69
-
70
- ## Model Architecture
71
-
72
- S-Proto builds on **PubMedBERT** as the text encoder and introduces a sparse prototypical layer on top.
73
-
74
- For each diagnosis label, the model learns multiple sub-networks, each consisting of:
75
-
76
- - A label-specific attention vector
77
- - A prototype vector representing a prototypical patient
78
-
79
- Given an input clinical note:
80
-
81
- 1. The note is encoded using PubMedBERT
82
- 2. Token embeddings are projected into a latent space
83
- 3. Each diagnosis activates multiple candidate sub-networks
84
- 4. A winner-takes-all mechanism selects the single most relevant sub-network per diagnosis
85
- 5. Only the winning prototype contributes to the prediction and receives gradient updates
86
-
87
- This allows S-Proto to model heterogeneous disease phenotypes while remaining sparse and efficient.
88
-
89
- ## Intended Use
90
-
91
- This model is intended for:
92
-
93
- - Clinical diagnosis prediction from admission notes
94
- - Research on long-tail learning in healthcare NLP
95
- - Interpretable clinical decision support systems
96
- - Analysis of disease phenotypes via learned prototypes
97
-
98
- This model is **not intended for direct clinical deployment** without external validation, auditing, and regulatory approval.
99
-
100
- ## Inference Example
101
-
102
- ```python
103
- from transformers import AutoTokenizer, AutoModel
104
- import torch
105
-
106
- tokenizer = AutoTokenizer.from_pretrained(
107
- "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
108
- )
109
- model = AutoModel.from_pretrained(
110
- "DATEXIS/sproto",
111
- trust_remote_code=True
112
- )
113
- model.eval()
114
-
115
- text_input = [
116
- "CHIEF COMPLAINT: Right Carotid Artery Stenosis. "
117
- "PRESENT ILLNESS: Ms. ___ is a ___ year old woman with hyperlipidemia, "
118
- "cirrhosis with esophageal varices, alcoholism, COPD, left eye blindness, "
119
- "and right carotid stenosis status post right carotid endarterectomy."
120
- ]
121
-
122
- inputs = tokenizer(
123
- text_input,
124
- padding=True,
125
- truncation=True,
126
- max_length=512,
127
- return_tensors="pt"
128
- )
129
-
130
- tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in inputs["input_ids"]]
131
-
132
- with torch.no_grad():
133
- output = model(
134
- input_ids=inputs["input_ids"],
135
- attention_mask=inputs["attention_mask"],
136
- token_type_ids=inputs.get("token_type_ids"),
137
- tokens=tokens
138
- )
139
-
140
- logits = output["logits"]
141
- max_indices = output["max_indices"]
142
- metadata = output["metadata"]
143
-
144
- print("Inference successful")
145
- print("Logits shape:", logits.shape)
146
- print("Max indices:", max_indices)
147
- print("Metadata:", metadata)
148
- ```
149
-
150
- ## Outputs
151
-
152
- The model returns a dictionary with the following entries:
153
-
154
- - **logits**
155
- Prediction scores per diagnosis label.
156
-
157
- - **max_indices**
158
- Index of the winning prototype sub-network per diagnosis, corresponding to the selected prototype.
159
-
160
- - **metadata**
161
- Additional information useful for analysis and interpretability.
162
-
163
- ## Explainability
164
-
165
- S-Proto provides built-in faithful explanations through its prototypical structure:
166
-
167
- - Attention vectors highlight clinically relevant tokens
168
- - Prototype distances reflect similarity to prototypical patients
169
- - Multiple prototypes per diagnosis capture disease subtypes and cohorts
170
- - Faithfulness metrics remain comparable to ProtoPatient despite higher capacity
171
-
172
- Qualitative evaluation with medical professionals confirms that learned prototypes often correspond to clinically meaningful phenotypes.
173
-
174
- ## Training
175
-
176
- First, clone the repository:
177
-
178
- ```bash
179
- git clone https://github.com/DATEXIS/sproto.git
180
- cd sproto
181
- ```
182
-
183
- Set up the environment using Poetry:
184
-
185
- ```bash
186
- poetry install
187
- ```
188
-
189
- Activate the virtual environment:
190
-
191
- ```bash
192
- poetry env activate
193
- ```
194
-
195
- Once the environment is active, you can start training by running the train command with the desired arguments.
196
-
197
- Example:
198
-
199
- ```bash
200
- train \
201
- --batch_size 3 \
202
- --pretrained_model microsoft/biomednlp-pubmedbert-base-uncased-abstract-fulltext \
203
- --pretrained_model_path path_to_pretrained_model.ckpt \
204
- --model_type MULTI_PROTO \
205
- --train_file training_data.csv \
206
- --val_file validation_data.csv \
207
- --test_file test_data.csv \
208
- --save_dir ../experiments/ \
209
- --gpus 1 \
210
- --check_val_every_n_epoch 2 \
211
- --num_warmup_steps 0 \
212
- --num_training_steps 50 \
213
- --max_length 512 \
214
- --lr_features 0.000005 \
215
- --lr_prototypes 0.001 \
216
- --lr_others 0.001 \
217
- --num_val_samples None \
218
- --use_attention True \
219
- --reduce_hidden_size 256 \
220
- --all_labels_path all_labels.pcl \
221
- --seed 42 \
222
- --label_column labels \
223
- --metric_opt auroc_macro \
224
- --train_files [] \
225
- --val_files [] \
226
- --only_test True \
227
- --model_name 5p \
228
- --store_metadata False \
229
- --num_prototypes_per_class 5
230
- ```
231
-
232
- ## Citation
233
-
234
- ```bibtex
235
- @inproceedings{figueroa2024sproto,
236
- title={Boosting Long-Tail Data Classification with Sparse Prototypical Networks},
237
- author={Figueroa, Alexei and Papaioannou, Jens-Michalis and Fallon, Conor and Bekiaridou, Alexandra and Bressem, Keno and Zanos, Stavros and Gers, Felix and Nejdl, Wolfgang and Löser, Alexander},
238
- booktitle={Proceedings of the European Conference on Machine Learning and Principles and Practice of Knowledge Discovery in Databases (ECML PKDD)},
239
- year={2024}
240
- }
241
- ```
242
-
243
- ## License
244
-
245
- This model and its associated code are released under the Apache License 2.0.
246
-
247
- The model was trained on the MIMIC-IV dataset, which is subject to restricted access. No training data is included or redistributed with this repository.
248
- The data were accessed under a data use agreement. No patient-identifiable information is shared.
249
-
250
- Use of this model must comply with all applicable data governance and ethical guidelines.
251
-
252
- ### Limitations
253
-
254
- - Extremely rare diagnoses remain challenging
255
- - Clinical dataset biases may be reflected in predictions
256
- - Winner-takes-all selection is fixed and not learned dynamically
257
- - Not validated for real-world clinical deployment
258
-
259
- ### Ethical Considerations
260
-
261
- - The model processes sensitive clinical text
262
- - Predictions should always be reviewed by qualified professionals
263
- - Outputs should not be used as sole evidence for clinical decisions
264
  - Care must be taken to avoid reinforcing existing healthcare biases
 
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+
6
+ pipeline_tag: text-classification
7
+ task_categories:
8
+ - text-classification
9
+
10
+ model_type: sproto
11
+ base_model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
12
+
13
+ datasets:
14
+ - mimic-iv
15
+
16
+ metrics:
17
+ - auroc
18
+ - pr-auc
19
+
20
+ tags:
21
+ - text-classification
22
+ - multi-label-classification
23
+ - long-tail-learning
24
+ - medical
25
+ - clinical-nlp
26
+ - interpretability
27
+ - prototypical-networks
28
+ - ehr
29
+ ---
30
+
31
+ # S-Proto: Sparse Prototypical Networks for Long-Tail Clinical Diagnosis Prediction
32
+
33
+ ![S-Proto](overview.png)
34
+
35
+ This repository provides **S-Proto**, a sparse and interpretable prototypical network for extreme multi-label diagnosis prediction from clinical text. The model is designed to address the long-tail distribution of clinical diagnoses while preserving faithful, prototype-based explanations.
36
+
37
+ ## Interactive Demo
38
+
39
+ You can explore the model's predictions and interpretability features through our interactive web demo:
40
+ **[https://s-proto.demo.datexis.com/](https://s-proto.demo.datexis.com/)**
41
+
42
+ S-Proto was introduced in the paper:
43
+
44
+ **[Boosting Long-Tail Data Classification with Sparse Prototypical Networks](https://ecmlpkdd-storage.s3.eu-central-1.amazonaws.com/preprints/2024/lncs14947/lncs14947435.pdf)**
45
+ Alexei Figueroa*, Jens-Michalis Papaioannou*, et al.
46
+ DATEXIS, Berliner Hochschule für Technik, Feinstein Institutes, TU Munich, Leibniz University Hannover
47
+ (* equal contribution)
48
+
49
+ ## Overview
50
+
51
+ Clinical outcome prediction from Electronic Health Records is characterized by extreme label imbalance. A small number of diagnoses account for most patients, while the majority of diagnoses appear rarely. Standard transformer classifiers tend to perform well on frequent diagnoses but degrade sharply in the long tail.
52
+
53
+ S-Proto addresses this problem by extending prototypical networks with:
54
+
55
+ - Multiple prototypes per diagnosis
56
+ - Sparse winner-takes-all activation
57
+ - Prototype-level interpretability
58
+ - Efficient training despite increased representational capacity
59
+
60
+ The model achieves state-of-the-art performance on MIMIC-IV diagnosis prediction, with particularly strong gains in PR-AUC for rare diagnoses, and transfers successfully to unseen clinical datasets.
61
+
62
+ ## Model Architecture
63
+
64
+ S-Proto builds on **PubMedBERT** as the text encoder and introduces a sparse prototypical layer on top.
65
+
66
+ For each diagnosis label, the model learns multiple sub-networks, each consisting of:
67
+
68
+ - A label-specific attention vector
69
+ - A prototype vector representing a prototypical patient
70
+
71
+ Given an input clinical note:
72
+
73
+ 1. The note is encoded using PubMedBERT
74
+ 2. Token embeddings are projected into a latent space
75
+ 3. Each diagnosis activates multiple candidate sub-networks
76
+ 4. A winner-takes-all mechanism selects the single most relevant sub-network per diagnosis
77
+ 5. Only the winning prototype contributes to the prediction and receives gradient updates
78
+
79
+ This allows S-Proto to model heterogeneous disease phenotypes while remaining sparse and efficient.
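The five steps above can be sketched with plain tensor operations. This is a minimal illustration, not the repository's implementation: the dimensions are made up, and negative squared distance is assumed as the prototype similarity score.

```python
import torch

# Hypothetical dimensions for illustration:
# B notes, T tokens, H latent size, L labels, P prototypes per label
B, T, H, L, P = 2, 16, 32, 4, 5

token_emb = torch.randn(B, T, H)   # steps 1-2: encoded, projected token embeddings
attn_vecs = torch.randn(L, P, H)   # label-specific attention vector per sub-network
prototypes = torch.randn(L, P, H)  # prototype vector ("prototypical patient") per sub-network

# Step 3: each (label, prototype) sub-network attends over the tokens
attn = torch.einsum("bth,lph->blpt", token_emb, attn_vecs).softmax(dim=-1)
pooled = torch.einsum("blpt,bth->blph", attn, token_emb)  # one note vector per sub-network

# Similarity to each prototype (negative squared distance, an assumption here)
scores = -((pooled - prototypes) ** 2).sum(dim=-1)  # shape (B, L, P)

# Steps 4-5: winner-takes-all — only the closest prototype per label contributes,
# so only the winning sub-network receives gradient updates through the max
logits, max_indices = scores.max(dim=-1)  # shapes (B, L) and (B, L)
```

Because `max` routes gradients only through the winning entry, each diagnosis stays sparse at train time while still keeping multiple prototypes available for different phenotypes.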
80
+
81
+ ## Intended Use
82
+
83
+ This model is intended for:
84
+
85
+ - Clinical diagnosis prediction from admission notes
86
+ - Research on long-tail learning in healthcare NLP
87
+ - Interpretable clinical decision support systems
88
+ - Analysis of disease phenotypes via learned prototypes
89
+
90
+ This model is **not intended for direct clinical deployment** without external validation, auditing, and regulatory approval.
91
+
92
+ ## Inference Example
93
+
94
+ ```python
95
+ from transformers import AutoTokenizer, AutoModel
96
+ import torch
97
+
98
+ tokenizer = AutoTokenizer.from_pretrained(
99
+ "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
100
+ )
101
+ model = AutoModel.from_pretrained(
102
+ "DATEXIS/sproto",
103
+ trust_remote_code=True
104
+ )
105
+ model.eval()
106
+
107
+ text_input = [
108
+ "CHIEF COMPLAINT: Right Carotid Artery Stenosis. "
109
+ "PRESENT ILLNESS: Ms. ___ is a ___ year old woman with hyperlipidemia, "
110
+ "cirrhosis with esophageal varices, alcoholism, COPD, left eye blindness, "
111
+ "and right carotid stenosis status post right carotid endarterectomy."
112
+ ]
113
+
114
+ inputs = tokenizer(
115
+ text_input,
116
+ padding=True,
117
+ truncation=True,
118
+ max_length=512,
119
+ return_tensors="pt"
120
+ )
121
+
122
+ tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in inputs["input_ids"]]
123
+
124
+ with torch.no_grad():
125
+ output = model(
126
+ input_ids=inputs["input_ids"],
127
+ attention_mask=inputs["attention_mask"],
128
+ token_type_ids=inputs.get("token_type_ids"),
129
+ tokens=tokens
130
+ )
131
+
132
+ logits = output["logits"]
133
+ max_indices = output["max_indices"]
134
+ metadata = output["metadata"]
135
+
136
+ print("Inference successful")
137
+ print("Logits shape:", logits.shape)
138
+ print("Max indices:", max_indices)
139
+ print("Metadata:", metadata)
140
+ ```
141
+
142
+ ## Outputs
143
+
144
+ The model returns a dictionary with the following entries:
145
+
146
+ - **logits**
147
+ Prediction scores per diagnosis label.
148
+
149
+ - **max_indices**
150
+ Index of the winning prototype sub-network per diagnosis, corresponding to the selected prototype.
151
+
152
+ - **metadata**
153
+ Additional information useful for analysis and interpretability.
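Assuming the logits are independent per-label scores (the standard multi-label setup), they can be turned into probabilities with a per-label sigmoid. The values and threshold below are illustrative only:

```python
import torch

# Toy logits standing in for output["logits"]; shape (batch, num_labels)
logits = torch.tensor([[2.1, -0.3, 0.8]])
probs = torch.sigmoid(logits)  # multi-label: one independent sigmoid per diagnosis

threshold = 0.5  # illustrative cutoff, not tuned
predicted = (probs > threshold).nonzero(as_tuple=False)
for batch_idx, label_idx in predicted.tolist():
    print(f"note {batch_idx}: label {label_idx} (p={probs[batch_idx, label_idx]:.2f})")
```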
154
+
155
+ ## Explainability
156
+
157
+ S-Proto provides built-in faithful explanations through its prototypical structure:
158
+
159
+ - Attention vectors highlight clinically relevant tokens
160
+ - Prototype distances reflect similarity to prototypical patients
161
+ - Multiple prototypes per diagnosis capture disease subtypes and cohorts
162
+ - Faithfulness metrics remain comparable to ProtoPatient despite higher capacity
163
+
164
+ Qualitative evaluation with medical professionals confirms that learned prototypes often correspond to clinically meaningful phenotypes.
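As a sketch of how such explanations can be surfaced: once per-token attention weights for the winning prototype are available (the model exposes this kind of information via its metadata output), pairing them with the tokenized input yields a token-level highlight. The tokens and weights below are invented for illustration:

```python
# Illustrative only: tokens and attention weights are made up.
tokens = ["chief", "complaint", "right", "carotid", "artery", "stenosis"]
weights = [0.02, 0.03, 0.15, 0.35, 0.20, 0.25]

# Rank tokens by attention weight to show which words drove the prediction
top = sorted(zip(tokens, weights), key=lambda tw: tw[1], reverse=True)[:3]
print(top)  # highest-weighted tokens first
```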
165
+
166
+ ## Training
167
+
168
+ First, clone the repository:
169
+
170
+ ```bash
171
+ git clone https://github.com/DATEXIS/sproto.git
172
+ cd sproto
173
+ ```
174
+
175
+ Set up the environment using Poetry:
176
+
177
+ ```bash
178
+ poetry install
179
+ ```
180
+
181
+ Activate the virtual environment:
182
+
183
+ ```bash
184
+ eval $(poetry env activate)
185
+ ```
186
+
187
+ Once the environment is active, you can start training by running the train command with the desired arguments.
188
+
189
+ Example:
190
+
191
+ ```bash
192
+ train \
193
+ --batch_size 3 \
194
+ --pretrained_model microsoft/biomednlp-pubmedbert-base-uncased-abstract-fulltext \
195
+ --pretrained_model_path path_to_pretrained_model.ckpt \
196
+ --model_type MULTI_PROTO \
197
+ --train_file training_data.csv \
198
+ --val_file validation_data.csv \
199
+ --test_file test_data.csv \
200
+ --save_dir ../experiments/ \
201
+ --gpus 1 \
202
+ --check_val_every_n_epoch 2 \
203
+ --num_warmup_steps 0 \
204
+ --num_training_steps 50 \
205
+ --max_length 512 \
206
+ --lr_features 0.000005 \
207
+ --lr_prototypes 0.001 \
208
+ --lr_others 0.001 \
209
+ --num_val_samples None \
210
+ --use_attention True \
211
+ --reduce_hidden_size 256 \
212
+ --all_labels_path all_labels.pcl \
213
+ --seed 42 \
214
+ --label_column labels \
215
+ --metric_opt auroc_macro \
216
+ --train_files [] \
217
+ --val_files [] \
218
+ --only_test True \
219
+ --model_name 5p \
220
+ --store_metadata False \
221
+ --num_prototypes_per_class 5
222
+ ```
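The three learning rates above (`--lr_features`, `--lr_prototypes`, `--lr_others`) imply separate optimizer parameter groups: a small rate for the pretrained encoder and larger ones for the freshly initialized prototype and attention parameters. A minimal sketch, with stand-in modules rather than the repository's actual classes:

```python
import torch

# Stand-ins for the real modules (hypothetical shapes):
encoder = torch.nn.Linear(8, 8)                        # placeholder for the PubMedBERT encoder
prototypes = torch.nn.Parameter(torch.randn(4, 5, 8))  # L labels x P prototypes x H dims
attention = torch.nn.Parameter(torch.randn(4, 5, 8))   # label-specific attention vectors

# One parameter group per learning rate from the command line above
optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 5e-6},  # --lr_features
    {"params": [prototypes], "lr": 1e-3},          # --lr_prototypes
    {"params": [attention], "lr": 1e-3},           # --lr_others
])
```

Keeping the encoder at a much smaller rate avoids destroying the pretrained representations while the prototypes, which start from random initialization, move quickly.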
223
+
224
+ ## Citation
225
+
226
+ ```bibtex
227
+ @inproceedings{figueroa2024sproto,
228
+ title={Boosting Long-Tail Data Classification with Sparse Prototypical Networks},
229
+ author={Figueroa, Alexei and Papaioannou, Jens-Michalis and Fallon, Conor and Bekiaridou, Alexandra and Bressem, Keno and Zanos, Stavros and Gers, Felix and Nejdl, Wolfgang and Löser, Alexander},
230
+ booktitle={Proceedings of the European Conference on Machine Learning and Principles and Practice of Knowledge Discovery in Databases (ECML PKDD)},
231
+ year={2024}
232
+ }
233
+ ```
234
+
235
+ ## License
236
+
237
+ This model and its associated code are released under the Apache License 2.0.
238
+
239
+ The model was trained on the MIMIC-IV dataset, which is subject to restricted access. No training data is included or redistributed with this repository.
240
+ The data were accessed under a data use agreement. No patient-identifiable information is shared.
241
+
242
+ Use of this model must comply with all applicable data governance and ethical guidelines.
243
+
244
+ ### Limitations
245
+
246
+ - Extremely rare diagnoses remain challenging
247
+ - Clinical dataset biases may be reflected in predictions
248
+ - Winner-takes-all selection is fixed and not learned dynamically
249
+ - Not validated for real-world clinical deployment
250
+
251
+ ### Ethical Considerations
252
+
253
+ - The model processes sensitive clinical text
254
+ - Predictions should always be reviewed by qualified professionals
255
+ - Outputs should not be used as sole evidence for clinical decisions
256
  - Care must be taken to avoid reinforcing existing healthcare biases
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "do_basic_tokenize": true,
4
+ "do_lower_case": true,
5
+ "mask_token": "[MASK]",
6
+ "model_max_length": 1000000000000000019884624838656,
7
+ "name_or_path": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
8
+ "never_split": null,
9
+ "pad_token": "[PAD]",
10
+ "sep_token": "[SEP]",
11
+ "special_tokens_map_file": null,
12
+ "strip_accents": null,
13
+ "tokenize_chinese_chars": true,
14
+ "tokenizer_class": "BertTokenizer",
15
+ "unk_token": "[UNK]"
16
+ }