monajm36 commited on
Commit
edda1fa
Β·
unverified Β·
1 Parent(s): 21d4c7c

Create training_example.py

Browse files
Files changed (1) hide show
  1. examples/training_example.py +289 -0
examples/training_example.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OHCA Training Pipeline Example
3
+
4
+ This example shows how to train an OHCA classifier from scratch.
5
+ """
6
+
7
+ import pandas as pd
8
+ import sys
9
+ import os
10
+
11
+ # Add src to path
12
+ sys.path.append('../src')
13
+
14
+ from ohca_training_pipeline import (
15
+ create_training_sample,
16
+ prepare_training_data,
17
+ train_ohca_model,
18
+ evaluate_model,
19
+ complete_training_pipeline,
20
+ complete_annotation_and_train
21
+ )
22
+
23
+ def example_training_pipeline():
24
+ """Complete example of training an OHCA classifier"""
25
+
26
+ print("πŸš€ OHCA Training Pipeline Example")
27
+ print("="*50)
28
+
29
+ # ==========================================================================
30
+ # STEP 1: Prepare your data
31
+ # ==========================================================================
32
+
33
+ # Your discharge notes should be in CSV format with columns:
34
+ # - hadm_id: Unique identifier for each hospital admission
35
+ # - clean_text: Cleaned discharge note text
36
+
37
+ data_path = "path/to/your/discharge_notes.csv"
38
+
39
+ # For demonstration, create sample data
40
+ if not os.path.exists(data_path):
41
+ print("Creating sample data for demonstration...")
42
+
43
+ sample_data = {
44
+ 'hadm_id': [f'HADM_{i:06d}' for i in range(2000)],
45
+ 'clean_text': [
46
+ "Chief complaint: Cardiac arrest at home. Patient found down by family members, CPR initiated immediately. EMS called, patient transported to ED.",
47
+ "Chief complaint: Chest pain. Patient presents with acute onset chest pain, no loss of consciousness, no arrest occurred.",
48
+ "Chief complaint: Shortness of breath. Patient has chronic heart failure exacerbation, stable vital signs throughout admission.",
49
+ "Chief complaint: Patient found down, cardiac arrest in parking lot, bystander CPR given, ROSC achieved by EMS in field.",
50
+ "Chief complaint: Syncope. Patient had brief loss of consciousness but no cardiac arrest, workup negative for cardiac causes.",
51
+ "Chief complaint: Transfer from outside hospital. Patient had witnessed cardiac arrest at work, CPR by coworkers, transferred for cardiac catheterization.",
52
+ ] * 334 # Repeat to get 2000+ samples
53
+ }
54
+
55
+ df = pd.DataFrame(sample_data)
56
+ df.to_csv(data_path, index=False)
57
+ print(f"Sample data saved to: {data_path}")
58
+
59
+ # ==========================================================================
60
+ # STEP 2: Create annotation sample
61
+ # ==========================================================================
62
+
63
+ print("\nπŸ“ STEP 2: Creating Annotation Sample")
64
+ print("-" * 40)
65
+
66
+ df = pd.read_csv(data_path)
67
+ print(f"Loaded {len(df):,} discharge notes")
68
+
69
+ # Create balanced sample for annotation
70
+ annotation_result = create_training_sample(
71
+ df,
72
+ output_dir="./training_annotation_interface"
73
+ )
74
+
75
+ print(f"\nβœ… Annotation interface created!")
76
+ print(f"πŸ“ Files created:")
77
+ print(f" - ./training_annotation_interface/ohca_annotation.xlsx")
78
+ print(f" - ./training_annotation_interface/annotation_guidelines.md")
79
+
80
+ # ==========================================================================
81
+ # MANUAL ANNOTATION PHASE
82
+ # ==========================================================================
83
+
84
+ print("\n" + "="*60)
85
+ print("⏸️ MANUAL ANNOTATION REQUIRED")
86
+ print("="*60)
87
+ print("Before continuing, you need to:")
88
+ print("1. Open: ./training_annotation_interface/ohca_annotation.xlsx")
89
+ print("2. Read: ./training_annotation_interface/annotation_guidelines.md")
90
+ print("3. Manually label each case:")
91
+ print(" - 1 = OHCA (out-of-hospital cardiac arrest)")
92
+ print(" - 0 = Non-OHCA (everything else)")
93
+ print("4. Fill in confidence scores (1-5)")
94
+ print("5. Save the Excel file")
95
+ print("6. Run continue_training_after_annotation()")
96
+ print("="*60)
97
+
98
+ # For demonstration, create mock annotations
99
+ print("\nπŸ”§ Creating mock annotations for demonstration...")
100
+
101
+ annotation_df = pd.read_excel("./training_annotation_interface/ohca_annotation.xlsx")
102
+
103
+ # Simple rule-based mock labeling (in practice, this is done manually)
104
+ def mock_label(text):
105
+ text_lower = str(text).lower()
106
+ if 'cardiac arrest' in text_lower and any(word in text_lower for word in ['home', 'work', 'found down', 'parking lot']):
107
+ return 1 # OHCA
108
+ else:
109
+ return 0 # Non-OHCA
110
+
111
+ annotation_df['ohca_label'] = annotation_df['clean_text'].apply(mock_label)
112
+ annotation_df['confidence'] = 4 # Mock confidence
113
+ annotation_df['annotator'] = 'demo'
114
+ annotation_df['annotation_date'] = '2025-01-01'
115
+ annotation_df['notes'] = 'Mock annotation for demo'
116
+
117
+ # Save completed annotations
118
+ completed_file = "./training_annotation_interface/ohca_annotation_completed.xlsx"
119
+ annotation_df.to_excel(completed_file, index=False)
120
+
121
+ print(f"βœ… Mock annotations created: {completed_file}")
122
+
123
+ # Continue with training
124
+ return continue_training_after_annotation(completed_file)
125
+
126
+ def continue_training_after_annotation(annotation_file):
127
+ """Continue training after manual annotation is complete"""
128
+
129
+ print("\nπŸ”„ CONTINUING TRAINING AFTER ANNOTATION")
130
+ print("="*50)
131
+
132
+ # ==========================================================================
133
+ # STEP 3: Prepare training data
134
+ # ==========================================================================
135
+
136
+ print("\nπŸ“Š STEP 3: Preparing Training Data")
137
+ print("-" * 40)
138
+
139
+ # Load completed annotations
140
+ labeled_df = pd.read_excel(annotation_file)
141
+
142
+ # Prepare training datasets
143
+ train_dataset, val_dataset, train_df, tokenizer = prepare_training_data(labeled_df)
144
+
145
+ # ==========================================================================
146
+ # STEP 4: Train the model
147
+ # ==========================================================================
148
+
149
+ print("\nπŸ‹οΈ STEP 4: Training Model")
150
+ print("-" * 40)
151
+
152
+ model, trained_tokenizer = train_ohca_model(
153
+ train_dataset=train_dataset,
154
+ val_dataset=val_dataset,
155
+ train_df=train_df,
156
+ tokenizer=tokenizer,
157
+ num_epochs=3,
158
+ save_path="./trained_ohca_model"
159
+ )
160
+
161
+ # ==========================================================================
162
+ # STEP 5: Evaluate the model
163
+ # ==========================================================================
164
+
165
+ print("\nπŸ“ˆ STEP 5: Evaluating Model")
166
+ print("-" * 40)
167
+
168
+ evaluation_results = evaluate_model(
169
+ model=model,
170
+ val_dataset=val_dataset,
171
+ save_results=True,
172
+ results_path="./trained_ohca_model/evaluation_results.txt"
173
+ )
174
+
175
+ # ==========================================================================
176
+ # STEP 6: Training complete summary
177
+ # ==========================================================================
178
+
179
+ print("\n" + "="*60)
180
+ print("πŸŽ‰ TRAINING COMPLETE!")
181
+ print("="*60)
182
+
183
+ print(f"πŸ“ Model saved to: ./trained_ohca_model/")
184
+ print(f"πŸ“Š Evaluation results: ./trained_ohca_model/evaluation_results.txt")
185
+
186
+ print(f"\nπŸ“ˆ Performance Summary:")
187
+ print(f" AUC-ROC: {evaluation_results['auc']:.3f}")
188
+ print(f" F1-Score: {evaluation_results['optimal_metrics']['f1']:.3f}")
189
+ print(f" Sensitivity: {evaluation_results['optimal_metrics']['recall']:.1%}")
190
+ print(f" Specificity: {evaluation_results['optimal_metrics']['specificity']:.1%}")
191
+
192
+ print(f"\n🎯 Next Steps:")
193
+ print(f" 1. Review evaluation results")
194
+ print(f" 2. Test model on new data using inference module")
195
+ print(f" 3. Deploy model for clinical use")
196
+ print(f" 4. Consider retraining with more data if needed")
197
+
198
+ return {
199
+ 'model_path': "./trained_ohca_model/",
200
+ 'evaluation_results': evaluation_results,
201
+ 'training_data_size': len(train_dataset),
202
+ 'validation_data_size': len(val_dataset)
203
+ }
204
+
205
+ def quick_training_example():
206
+ """Simplified training example using the complete pipeline function"""
207
+
208
+ print("⚑ Quick Training Pipeline Example")
209
+ print("="*40)
210
+
211
+ # Use the complete pipeline function
212
+ data_path = "path/to/your/discharge_notes.csv"
213
+
214
+ # Step 1: Create annotation interface
215
+ result = complete_training_pipeline(
216
+ data_path=data_path,
217
+ annotation_dir="./quick_annotation_interface",
218
+ model_save_path="./quick_trained_model"
219
+ )
220
+
221
+ print(f"Annotation files created:")
222
+ print(f" πŸ“„ {result['annotation_file']}")
223
+ print(f" πŸ“‹ {result['guidelines_file']}")
224
+
225
+ # After manual annotation, continue with:
226
+ # final_result = complete_annotation_and_train(
227
+ # annotation_file=result['annotation_file'],
228
+ # model_save_path="./quick_trained_model",
229
+ # num_epochs=3
230
+ # )
231
+
232
+ return result
233
+
234
+ def training_tips_and_best_practices():
235
+ """Tips for successful OHCA model training"""
236
+
237
+ print("πŸ’‘ OHCA Training Tips & Best Practices")
238
+ print("="*45)
239
+
240
+ print("\nπŸ“‹ Data Preparation:")
241
+ print(" β€’ Ensure discharge notes are well-cleaned")
242
+ print(" β€’ Include diverse hospital systems if possible")
243
+ print(" β€’ Minimum 200-300 cases for reliable training")
244
+ print(" β€’ Aim for 10-30% OHCA prevalence in sample")
245
+
246
+ print("\n🏷️ Annotation Guidelines:")
247
+ print(" β€’ Be consistent with OHCA definition")
248
+ print(" β€’ Focus on PRIMARY reason for admission")
249
+ print(" β€’ Use confidence scores to flag uncertain cases")
250
+ print(" β€’ Consider inter-annotator agreement for quality")
251
+
252
+ print("\nπŸ”§ Model Training:")
253
+ print(" β€’ Start with 3 epochs, increase if underfitting")
254
+ print(" β€’ Monitor for overfitting in small datasets")
255
+ print(" β€’ Consider class balancing for imbalanced data")
256
+ print(" β€’ Use validation set to tune hyperparameters")
257
+
258
+ print("\nπŸ“Š Model Evaluation:")
259
+ print(" β€’ Prioritize sensitivity (catching OHCA cases)")
260
+ print(" β€’ Balance sensitivity vs specificity for use case")
261
+ print(" β€’ AUC > 0.8 indicates good performance")
262
+ print(" β€’ F1-score > 0.7 suggests balanced performance")
263
+
264
+ print("\n🎯 Model Deployment:")
265
+ print(" β€’ Test on held-out dataset before deployment")
266
+ print(" β€’ Consider probability thresholds for clinical use")
267
+ print(" β€’ Plan for model monitoring and retraining")
268
+ print(" β€’ Document model limitations and scope")
269
+
270
+ if __name__ == "__main__":
271
+ print("OHCA Training Examples")
272
+ print("="*25)
273
+
274
+ print("\nChoose an example:")
275
+ print("1. Complete training pipeline")
276
+ print("2. Quick training example")
277
+ print("3. Training tips and best practices")
278
+
279
+ choice = input("\nEnter choice (1-3): ").strip()
280
+
281
+ if choice == "1":
282
+ example_training_pipeline()
283
+ elif choice == "2":
284
+ quick_training_example()
285
+ elif choice == "3":
286
+ training_tips_and_best_practices()
287
+ else:
288
+ print("Running complete training pipeline by default...")
289
+ example_training_pipeline()