monajm36 committed on
Commit
bf252cc
·
unverified ·
1 Parent(s): 5af9dc3

Create apply_to_external_dataset.py

Browse files
Files changed (1) hide show
  1. examples/apply_to_external_dataset.py +337 -0
examples/apply_to_external_dataset.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Applying OHCA Classifier to CLIF Datasets
3
+
4
+ This example demonstrates how to apply a MIMIC-trained OHCA model to CLIF datasets
5
+ from other institutions. CLIF (Common Longitudinal ICU data Format) standardizes
6
+ healthcare data, making cross-institutional model deployment much easier.
7
+
8
+ Example use case: Apply MIMIC-IV trained model → University of Chicago CLIF dataset
9
+ """
10
+
11
+ import pandas as pd
12
+ import numpy as np
13
+ import sys
14
+ import os
15
+ from pathlib import Path
16
+
17
+ # Import OHCA inference functions
18
+ sys.path.append('../src')
19
+ from ohca_inference import (
20
+ load_ohca_model,
21
+ run_inference,
22
+ analyze_predictions,
23
+ get_high_confidence_cases
24
+ )
25
+
26
# Columns the OHCA inference code expects, each with the CLIF column names that
# may supply it, in priority order. Exactly one source is renamed per target:
# renaming several sources onto the same target would leave the DataFrame with
# duplicate column labels (e.g. two 'hadm_id' columns).
# NOTE(review): 'hospitalization_id' is preferred over 'patient_id' on the
# assumption that 'hadm_id' identifies one admission rather than one patient —
# confirm against the model's training data.
_CLIF_COLUMN_CANDIDATES = {
    'hadm_id': ('hospitalization_id', 'encounter_id', 'patient_id'),
    'clean_text': ('discharge_summary', 'clinical_notes', 'progress_notes', 'discharge_notes'),
}


def _resolve_clif_column_mapping(columns):
    """Return a one-to-one ``{clif_column: model_column}`` rename map for *columns*."""
    present = set(columns)
    mapping = {}
    for target, candidates in _CLIF_COLUMN_CANDIDATES.items():
        if target in present:
            # The dataset already uses the model's column name; nothing to rename.
            continue
        source = next((name for name in candidates if name in present), None)
        if source is not None:
            mapping[source] = target
    return mapping


def apply_ohca_model_to_clif_dataset():
    """
    Apply a MIMIC-trained OHCA model to a CLIF dataset from another institution.

    Steps performed:
    1. Load the trained model from ``./trained_ohca_model``
    2. Load the CLIF CSV (a sample dataset is generated when the configured
       path does not exist)
    3. Rename CLIF columns to the ``hadm_id`` / ``clean_text`` names used for
       inference and drop rows missing either value
    4. Run inference and write ``clif_dataset_ohca_predictions.csv``
    5. Print summary statistics and save high-confidence cases
    6. Print interpretation notes and next steps
    7. Write ``clif_dataset_analysis_summary.json``

    Returns:
        The object returned by ``run_inference`` on success, or ``None`` when
        the model is missing, the required columns cannot be resolved, or no
        rows remain after cleaning.
    """
    print("🏥 Applying MIMIC-trained OHCA Model to CLIF Dataset")
    print("=" * 55)

    # ==========================================================================
    # STEP 1: Load your trained OHCA model
    # ==========================================================================
    print("\n📂 Step 1: Loading trained OHCA model...")

    # Path to your trained model (adjust to your actual path)
    model_path = "./trained_ohca_model"

    if not os.path.exists(model_path):
        print(f"❌ Model not found at: {model_path}")
        print("Please ensure you have a trained model or update the path.")
        return None

    model, tokenizer = load_ohca_model(model_path)
    print("✅ Model loaded successfully")

    # ==========================================================================
    # STEP 2: Load CLIF dataset from external institution
    # ==========================================================================
    print("\n📊 Step 2: Loading CLIF dataset...")

    # CLIF datasets follow a standardized format across institutions
    clif_data_path = "path/to/clif/dataset.csv"

    # For demonstration, fall back to generated sample data
    if not os.path.exists(clif_data_path):
        print("Creating sample CLIF dataset for demonstration...")
        clif_data_path = create_sample_clif_data()

    clif_df = pd.read_csv(clif_data_path)
    print(f"Loaded {len(clif_df):,} cases from CLIF dataset")
    print(f"Available columns: {list(clif_df.columns)}")

    # ==========================================================================
    # STEP 3: Prepare CLIF data for OHCA inference
    # ==========================================================================
    print("\n🔧 Step 3: Preparing CLIF data for inference...")
    print("🔄 Mapping CLIF columns to OHCA model format...")

    required_columns = ['hadm_id', 'clean_text']
    column_mapping = _resolve_clif_column_mapping(clif_df.columns)

    if column_mapping:
        clif_df = clif_df.rename(columns=column_mapping)
        print(f"✅ Mapped CLIF columns: {list(column_mapping.keys())}")

    missing_columns = [name for name in required_columns if name not in clif_df.columns]
    if missing_columns:
        if not column_mapping:
            print("⚠️ Standard CLIF columns not found. Manual mapping required.")
            print(f"Available columns: {list(clif_df.columns)}")
            print("Please update _CLIF_COLUMN_CANDIDATES to match your CLIF dataset")
        else:
            print(f"❌ Required columns {missing_columns} not found after mapping")
            print("Please update _CLIF_COLUMN_CANDIDATES above")
        return None

    # Drop rows that cannot be scored; copy so the assignment below acts on an
    # independent frame.
    clif_df = clif_df.dropna(subset=required_columns).copy()
    clif_df['clean_text'] = clif_df['clean_text'].astype(str)

    if clif_df.empty:
        print("❌ No cases with both 'hadm_id' and 'clean_text' remain after cleaning")
        return None

    print(f"✅ CLIF data prepared: {len(clif_df):,} cases ready for inference")

    # ==========================================================================
    # STEP 4: Run OHCA inference on CLIF data
    # ==========================================================================
    print("\n🔍 Step 4: Running OHCA inference on CLIF dataset...")

    predictions_path = "clif_dataset_ohca_predictions.csv"
    results = run_inference(
        model=model,
        tokenizer=tokenizer,
        inference_df=clif_df,
        batch_size=16,
        output_path=predictions_path,
    )

    # ==========================================================================
    # STEP 5: Analyze results
    # ==========================================================================
    print("\n📈 Step 5: Analyzing results...")

    total_cases = len(results)
    probabilities = results['ohca_probability']
    predicted_ohca_05 = int((probabilities >= 0.5).sum())
    predicted_ohca_08 = int((probabilities >= 0.8).sum())
    predicted_ohca_09 = int((probabilities >= 0.9).sum())

    def share(count):
        # Guard against an empty result set so the summary never holds NaN.
        return count / total_cases if total_cases else 0.0

    print("\n📊 OHCA Predictions on CLIF Dataset:")
    print(f" Total CLIF cases analyzed: {total_cases:,}")
    print(f" Predicted OHCA (≥0.5): {predicted_ohca_05:,} ({share(predicted_ohca_05):.1%})")
    print(f" High confidence (≥0.8): {predicted_ohca_08:,} ({share(predicted_ohca_08):.1%})")
    print(f" Very high confidence (≥0.9): {predicted_ohca_09:,} ({share(predicted_ohca_09):.1%})")

    print("\n🎯 CLIF Standardization Benefits:")
    print(" ✅ Consistent data format across institutions")
    print(" ✅ Minimal preprocessing required")
    print(" ✅ Improved model generalizability")
    print(" ✅ Easier cross-institutional validation")

    # Detailed analysis; the return value is not used by this example.
    analyze_predictions(results)

    # Get high-confidence cases for manual review
    files_created = [predictions_path]
    high_confidence_path = "clif_dataset_high_confidence_ohca.csv"
    high_confidence_cases = get_high_confidence_cases(results, threshold=0.8)
    high_confidence_saved = len(high_confidence_cases) > 0

    if high_confidence_saved:
        print("\n🎯 High Confidence OHCA Cases (for manual review):")
        print(f" Found {len(high_confidence_cases)} cases with probability ≥ 0.8")

        high_confidence_cases.to_csv(high_confidence_path, index=False)
        files_created.append(high_confidence_path)
        print(f" 💾 Saved to: {high_confidence_path}")

    # ==========================================================================
    # STEP 6: Clinical interpretation and next steps
    # ==========================================================================
    print("\n🏥 Clinical Interpretation:")
    print(" • MIMIC-trained model successfully applied to CLIF dataset")
    print(" • CLIF standardization facilitated cross-institutional deployment")
    print(" • Recommend manual review of high-confidence predictions")
    print(" • Consider validation against known ground truth if available")

    print("\n📋 Recommended Next Steps:")
    print(" 1. Review high-confidence predictions with clinical experts")
    print(" 2. Calculate performance metrics if ground truth available")
    print(" 3. Compare OHCA prevalence with MIMIC-IV baseline")
    print(" 4. Document any institutional differences observed")
    print(" 5. Consider CLIF-specific model fine-tuning if needed")

    # ==========================================================================
    # STEP 7: Save comprehensive results
    # ==========================================================================
    print("\n💾 Saving results...")

    summary = {
        'dataset_info': {
            'total_cases': total_cases,
            'data_source': 'CLIF Dataset',
            'data_format': 'Common Longitudinal ICU data Format (CLIF)',
            'model_used': model_path,
        },
        'predictions': {
            'ohca_predicted_05': predicted_ohca_05,
            'ohca_predicted_08': predicted_ohca_08,
            'ohca_predicted_09': predicted_ohca_09,
            'prevalence_05': float(share(predicted_ohca_05)),
            'prevalence_08': float(share(predicted_ohca_08)),
            'prevalence_09': float(share(predicted_ohca_09)),
        },
        # Only files that were actually written during this run.
        'files_created': files_created,
    }

    import json
    summary_path = 'clif_dataset_analysis_summary.json'
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    print("✅ CLIF dataset analysis complete! Files created:")
    print(f" 📄 {predictions_path}")
    if high_confidence_saved:
        print(f" 🎯 {high_confidence_path}")
    print(f" 📋 {summary_path}")

    return results
245
+
246
def create_sample_clif_data(n_samples=500, output_path="sample_clif_dataset.csv"):
    """
    Create a sample CLIF-formatted dataset for demonstration.

    The CSV contains the columns ``patient_id``, ``hospitalization_id``,
    ``discharge_summary``, ``clif_version`` and ``institution``. Discharge
    summaries cycle through ten fixed example notes.

    Args:
        n_samples: Number of rows to generate (default 500).
        output_path: Where to write the CSV (default ``sample_clif_dataset.csv``).

    Returns:
        The path the CSV was written to (``output_path``).

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples < 0:
        raise ValueError("n_samples must be non-negative")

    note_templates = (
        "Patient presented with cardiac arrest at home. Family initiated CPR, EMS transported.",
        "Chief complaint: Chest pain. Patient stable throughout admission, no arrest.",
        "Patient found down at workplace. Coworkers performed CPR until EMS arrival.",
        "Admission for pneumonia. Patient responded well to antibiotics, stable course.",
        "Transfer from outside hospital for post-arrest care. Originally arrested at restaurant.",
        "Chief complaint: Shortness of breath. CHF exacerbation managed with diuretics.",
        "Witnessed collapse at gym. Immediate bystander CPR, AED used, ROSC achieved.",
        "Routine admission for diabetes management. No acute events during stay.",
        "Patient arrested during family dinner. CPR by family, transported by EMS.",
        "Scheduled procedure. Patient stable pre and post procedure, no complications.",
    )

    # Cycling by index keeps every column the same length for any n_samples,
    # not just multiples of the template count.
    sample_df = pd.DataFrame({
        'patient_id': [f'CLIF_{i:06d}' for i in range(n_samples)],
        'hospitalization_id': [f'HOSP_{i:06d}' for i in range(n_samples)],
        'discharge_summary': [
            note_templates[i % len(note_templates)] for i in range(n_samples)
        ],
        'clif_version': ['2.1.0'] * n_samples,
        'institution': ['Sample_Hospital'] * n_samples,
    })
    sample_df.to_csv(output_path, index=False)

    print(f"📁 Created sample CLIF dataset: {output_path}")
    print(" Format: CLIF (Common Longitudinal ICU data Format)")
    print(f" Columns: {list(sample_df.columns)}")
    return output_path
277
+
278
def clif_validation_workflow():
    """
    Print the workflow for CLIF cross-institutional validation studies.

    Use this when you have CLIF datasets with ground truth labels from
    multiple institutions and want to measure model generalizability.
    The function only prints guidance and an example code snippet; it does
    not load data or run a model, and it returns ``None``.
    """
    print("🔬 CLIF Cross-Institutional Validation Workflow")
    print("=" * 45)

    print("\nThis workflow is for when you have:")
    print("• CLIF datasets from multiple institutions")
    print("• Known OHCA labels for validation")
    print("• Want to measure cross-institutional performance")
    print("• Need to assess CLIF standardization benefits")

    print("\nSteps:")
    print("1. Apply MIMIC-trained model to CLIF datasets (use apply_ohca_model_to_clif_dataset())")
    print("2. Compare predictions with ground truth labels")
    print("3. Calculate performance metrics across institutions")
    print("4. Analyze CLIF standardization benefits")
    print("5. Document institutional variations and model robustness")

    print("\nExample code for CLIF validation metrics:")

    # The snippet is shown to the user verbatim, so it is kept as plain
    # strings (not f-strings): the braces belong to the displayed code.
    # The empty first and last entries reproduce the blank line before and
    # after the snippet.
    example_snippet = "\n".join((
        "",
        "# After running inference on multiple CLIF datasets",
        "from sklearn.metrics import roc_auc_score, classification_report",
        "",
        "# Load CLIF ground truth",
        "clif_ground_truth = pd.read_csv('clif_ground_truth.csv')",
        "",
        "# Calculate cross-institutional metrics",
        "clif_auc = roc_auc_score(clif_ground_truth['true_label'], results['ohca_probability'])",
        'print(f"CLIF validation AUC: {clif_auc:.3f}")',
        "",
        "# Compare MIMIC vs CLIF performance",
        'print("Cross-institutional performance:")',
        'print(f"MIMIC training AUC: {mimic_auc:.3f}")',
        'print(f"CLIF validation AUC: {clif_auc:.3f}")',
        'print(f"CLIF standardization benefit: Minimal performance drop")',
        "",
    ))
    print(example_snippet)
320
+
321
def main():
    """
    Show the example menu, read a choice from stdin and run that example.

    "2" runs the validation workflow walkthrough. "1" runs the CLIF dataset
    application. Any other input, or a closed / non-interactive stdin, falls
    back to the CLIF dataset application.
    """
    print("CLIF Dataset Application Examples")
    print("=" * 35)

    print("\nChoose an example:")
    print("1. Apply MIMIC-trained model to CLIF dataset")
    print("2. CLIF cross-institutional validation workflow")

    try:
        choice = input("\nEnter choice (1-2): ").strip()
    except EOFError:
        # No stdin available (piped/cron/CI run): use the default example
        # instead of crashing with a traceback.
        choice = ""

    if choice == "2":
        clif_validation_workflow()
        return

    if choice != "1":
        print("Running CLIF dataset application by default...")
    apply_ohca_model_to_clif_dataset()


if __name__ == "__main__":
    main()