row56 committed
Commit 265e2ec · verified · 1 Parent(s): 242206b

Update README.md

Files changed (1): README.md +100 -23
README.md CHANGED
@@ -119,47 +119,124 @@ ProtoPatient/
  ### 1. Install Dependencies

  ```bash
- pip install transformers torch
  ```

  ### 2. Load the Model via Hugging Face

  ```python
- from transformers import AutoTokenizer, AutoModel
-
- repo_id = "row56/ProtoPatient"
- tokenizer = AutoTokenizer.from_pretrained(repo_id)
- model = AutoModel.from_pretrained(repo_id)
  model.eval()
-
- sample_text = "This patient presents with severe headaches and nausea..."
- inputs = tokenizer(sample_text, return_tensors="pt")
- outputs = model(**inputs)
- print("Output shape:", outputs.last_hidden_state.shape)
  ```

- ## 3. Interpreting Outputs

- For a full prototypical classification workflow, use the custom modules in `proto_model/` (e.g., `ProtoForMultiLabelClassification`) to inspect:
- - Which tokens receive high attention for each diagnosis.
- - Which prototypical patients are retrieved as similar examples.

- Using the standard `AutoModel` returns raw embeddings; the custom code is required for full label-wise attention and prototype retrieval.

- ---

- ## 4. (Optional) Hugging Face Pipelines

- Integrate the model into a pipeline for feature extraction:

  ```python
- from transformers import pipeline

- extractor = pipeline("feature-extraction", model=repo_id, tokenizer=repo_id)
- embeddings = extractor("Severe headaches and vomiting...")
- print(len(embeddings), len(embeddings[0]))  # Token-level feature vectors
  ```

  # Intended Use, Limitations & Ethical Considerations

  ## Intended Use
 
  ### 1. Install Dependencies

  ```bash
+ git clone https://huggingface.co/row56/ProtoPatient
+ cd ProtoPatient
+ pip install -e . transformers torch safetensors
+ export TOKENIZERS_PARALLELISM=false
  ```

  ### 2. Load the Model via Hugging Face

  ```python
+ import torch
+ from transformers import AutoTokenizer
+ from proto_model.configuration_proto import ProtoConfig
+ from proto_model.modeling_proto import ProtoForMultiLabelClassification
+
+ # Load & configure
+ cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
+ cfg.pretrained_model_name_or_path = "bert-base-uncased"
+ cfg.use_cuda = False
+
+ tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
+ model = ProtoForMultiLabelClassification.from_pretrained(
+     "row56/ProtoPatient",
+     config=cfg,
+     ignore_mismatched_sizes=True
+ )
  model.eval()
+ model.cpu()
+
+ # Helper
+ def get_proto_logits(texts):
+     enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+     batch = {
+         "input_ids": enc["input_ids"],
+         "attention_masks": enc["attention_mask"],
+         "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
+         "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]]
+     }
+     with torch.no_grad():
+         logits, _ = model.proto_module(batch)
+     return logits
+
+ # Run
+ texts = [
+     "Patient shows elevated heart rate and low oxygen saturation.",
+     "No significant findings; patient is healthy."
+ ]
+ logits = get_proto_logits(texts)
+ print("Logits shape:", logits.shape)
+ print("Logits:\n", logits)
+
+ probs = torch.sigmoid(logits)
+ print("Probabilities:\n", probs)
  ```
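Because this is a multi-label classifier, each sigmoid probability above is an independent per-label score rather than a distribution over labels. A minimal, self-contained sketch of turning such probabilities into predicted label sets, assuming a 0.5 decision threshold (the threshold and the dummy logits below are illustrative, not fixed by the repo):

```python
import torch

# Dummy logits standing in for the model output above:
# 2 documents x 3 labels (hypothetical label count).
logits = torch.tensor([[2.0, -1.0, 0.3],
                       [-2.5, 0.8, -0.1]])

probs = torch.sigmoid(logits)  # per-label probabilities in (0, 1)
preds = probs > 0.5            # independent threshold per label

# Collect the indices of predicted labels for each document.
pred_label_ids = [row.nonzero().flatten().tolist() for row in preds]
print(pred_label_ids)  # → [[0, 2], [1]]
```

With a real checkpoint, the threshold would typically be tuned per label on a validation split rather than fixed at 0.5.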
 
+ ## 3. Training Data & Licenses

+ This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.

+ To obtain MIMIC-III:

+ 1. Visit https://physionet.org/content/mimiciii/1.4/
+ 2. Register for a free PhysioNet account and complete the CITI “Data or Specimens Only Research” training.
+ 3. Sign the MIMIC-III Data Use Agreement (DUA).
+ 4. Download the raw notes and run the preprocessing scripts from the paper’s repository.

+ Note: We do not redistribute MIMIC-III itself; users must obtain it directly under its license.
+
+ ## 4. Load Precomputed Training Data for Prototype Retrieval

+ After you have MIMIC-III and have applied the published preprocessing, you should produce:

+ - `data/train_embeds.npy` — NumPy array of shape (N, d) with per-example, per-class embeddings.
+ - `data/train_texts.json` — JSON array of length N of the raw admission-note strings.

+ Place those in `data/` and then:

  ```python
+ import numpy as np
+ import json

+ train_embeds = np.load("data/train_embeds.npy")  # shape (N, d)
+ with open("data/train_texts.json", "r") as f:
+     train_texts = json.load(f)  # list[str]
+
+ print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
  ```
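Since prototype retrieval indexes `train_texts` by nearest-neighbour positions in `train_embeds`, the two files must be aligned row-for-row. A small sanity-check sketch (the toy arrays below stand in for the loaded files):

```python
import numpy as np

# Toy stand-ins for the arrays loaded from data/ above.
train_embeds = np.zeros((4, 8), dtype=np.float32)       # (N, d)
train_texts = ["note a", "note b", "note c", "note d"]  # length N

# Row i of train_embeds must embed train_texts[i].
assert train_embeds.ndim == 2
assert train_embeds.shape[0] == len(train_texts), "embedding/text count mismatch"
print("aligned:", train_embeds.shape[0], "examples, dim", train_embeds.shape[1])
```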
 
+ ## 5. Interpreting Outputs & Retrieving Prototypes
+
+ ```python
+ from sklearn.neighbors import NearestNeighbors
+
+ text = "Patient has chest pain and shortness of breath."
+ enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
+ batch = {
+     "input_ids": enc["input_ids"],
+     "attention_masks": enc["attention_mask"],
+     "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
+     "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]]
+ }
+
+ with torch.no_grad():
+     logits, metadata = model.proto_module(batch)
+
+ attn_scores = metadata["attentions"][0]  # [num_labels, seq_len]
+ for label_id, scores in enumerate(attn_scores):
+     topk = sorted(zip(batch["tokens"][0], scores.tolist()),
+                   key=lambda x: -x[1])[:5]
+     print(f"Label {label_id} top tokens:", topk)
+
+ proto_vecs = model.proto_module.prototype_vectors.cpu().numpy()  # [num_labels, d]
+ nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
+
+ for label_id, u_c in enumerate(proto_vecs):
+     dist, idx = nn.kneighbors(u_c.reshape(1, -1))
+     print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
+     print(train_texts[idx[0][0]])
+ ```
+
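The `NearestNeighbors` lookup above carries no state beyond the training matrix, so if scikit-learn is unavailable the same one-nearest-neighbour retrieval can be done with broadcasting in plain NumPy (the toy arrays below stand in for `proto_vecs` and `train_embeds`):

```python
import numpy as np

# Toy stand-ins: 2 label prototypes and 5 training embeddings, dim 3.
proto_vecs = np.array([[0.0, 0.0, 1.0],
                       [1.0, 0.0, 0.0]])
train_embeds = np.array([[0.9, 0.1, 0.0],
                         [0.0, 0.1, 0.9],
                         [0.5, 0.5, 0.5],
                         [0.0, 1.0, 0.0],
                         [1.0, 0.0, 0.1]])

# Pairwise Euclidean distances via broadcasting: [num_labels, N]
dists = np.linalg.norm(proto_vecs[:, None, :] - train_embeds[None, :, :], axis=-1)
nearest = dists.argmin(axis=1)  # index of the closest training example per label
print(nearest)  # → [1 4]
```

`train_texts[nearest[label_id]]` then gives the same prototype text the sklearn version retrieves.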
240
  # Intended Use, Limitations & Ethical Considerations
241
 
242
  ## Intended Use