monajm36 committed on
Commit
39e0406
·
unverified ·
1 Parent(s): 9f9eb02

Update training_example.py

Browse files
Files changed (1) hide show
  1. examples/training_example.py +399 -179
examples/training_example.py CHANGED
@@ -1,17 +1,28 @@
1
  """
2
- OHCA Training Pipeline Example
3
 
4
- This example shows how to train an OHCA classifier from scratch.
 
5
  """
6
 
7
  import pandas as pd
8
  import sys
9
  import os
10
 
11
- # Add src to path
12
  sys.path.append('../src')
13
 
 
14
  from ohca_training_pipeline import (
 
 
 
 
 
 
 
 
 
15
  create_training_sample,
16
  prepare_training_data,
17
  train_ohca_model,
@@ -20,134 +31,219 @@ from ohca_training_pipeline import (
20
  complete_annotation_and_train
21
  )
22
 
23
- def example_training_pipeline():
24
- """Complete example of training an OHCA classifier"""
25
 
26
- print("🚀 OHCA Training Pipeline Example")
27
- print("="*50)
28
 
29
  # ==========================================================================
30
- # STEP 1: Prepare your data
31
  # ==========================================================================
32
 
33
- # Your discharge notes should be in CSV format with columns:
34
- # - hadm_id: Unique identifier for each hospital admission
 
 
 
 
35
  # - clean_text: Cleaned discharge note text
36
 
37
- data_path = "path/to/your/discharge_notes.csv"
38
 
39
- # For demonstration, create sample data
40
  if not os.path.exists(data_path):
41
- print("Creating sample data for demonstration...")
42
 
43
- sample_data = {
44
- 'hadm_id': [f'HADM_{i:06d}' for i in range(2000)],
45
- 'clean_text': [
46
- "Chief complaint: Cardiac arrest at home. Patient found down by family members, CPR initiated immediately. EMS called, patient transported to ED.",
47
- "Chief complaint: Chest pain. Patient presents with acute onset chest pain, no loss of consciousness, no arrest occurred.",
48
- "Chief complaint: Shortness of breath. Patient has chronic heart failure exacerbation, stable vital signs throughout admission.",
49
- "Chief complaint: Patient found down, cardiac arrest in parking lot, bystander CPR given, ROSC achieved by EMS in field.",
50
- "Chief complaint: Syncope. Patient had brief loss of consciousness but no cardiac arrest, workup negative for cardiac causes.",
51
- "Chief complaint: Transfer from outside hospital. Patient had witnessed cardiac arrest at work, CPR by coworkers, transferred for cardiac catheterization.",
52
- ] * 334 # Repeat to get 2000+ samples
53
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  df = pd.DataFrame(sample_data)
56
  df.to_csv(data_path, index=False)
57
- print(f"Sample data saved to: {data_path}")
 
58
 
59
  # ==========================================================================
60
- # STEP 2: Create annotation sample
61
  # ==========================================================================
62
 
63
- print("\n📝 STEP 2: Creating Annotation Sample")
64
- print("-" * 40)
65
 
66
  df = pd.read_csv(data_path)
67
- print(f"Loaded {len(df):,} discharge notes")
68
 
69
- # Create balanced sample for annotation
70
- annotation_result = create_training_sample(
71
- df,
72
- output_dir="./training_annotation_interface"
 
 
73
  )
74
 
75
- print(f"\n✅ Annotation interface created!")
76
- print(f"📁 Files created:")
77
- print(f" - ./training_annotation_interface/ohca_annotation.xlsx")
78
- print(f" - ./training_annotation_interface/annotation_guidelines.md")
 
 
 
 
 
 
 
 
79
 
80
  # ==========================================================================
81
- # MANUAL ANNOTATION PHASE
82
  # ==========================================================================
83
 
84
- print("\n" + "="*60)
85
- print("⏸️ MANUAL ANNOTATION REQUIRED")
86
- print("="*60)
 
 
 
 
 
87
  print("Before continuing, you need to:")
88
- print("1. Open: ./training_annotation_interface/ohca_annotation.xlsx")
89
- print("2. Read: ./training_annotation_interface/annotation_guidelines.md")
90
- print("3. Manually label each case:")
91
- print(" - 1 = OHCA (out-of-hospital cardiac arrest)")
92
- print(" - 0 = Non-OHCA (everything else)")
93
- print("4. Fill in confidence scores (1-5)")
94
- print("5. Save the Excel file")
95
- print("6. Run continue_training_after_annotation()")
96
- print("="*60)
97
-
98
- # For demonstration, create mock annotations
99
- print("\n🔧 Creating mock annotations for demonstration...")
100
-
101
- annotation_df = pd.read_excel("./training_annotation_interface/ohca_annotation.xlsx")
102
-
103
- # Simple rule-based mock labeling (in practice, this is done manually)
104
- def mock_label(text):
105
- text_lower = str(text).lower()
106
- if 'cardiac arrest' in text_lower and any(word in text_lower for word in ['home', 'work', 'found down', 'parking lot']):
107
- return 1 # OHCA
108
- else:
109
- return 0 # Non-OHCA
110
-
111
- annotation_df['ohca_label'] = annotation_df['clean_text'].apply(mock_label)
112
- annotation_df['confidence'] = 4 # Mock confidence
113
- annotation_df['annotator'] = 'demo'
114
- annotation_df['annotation_date'] = '2025-01-01'
115
- annotation_df['notes'] = 'Mock annotation for demo'
116
-
117
- # Save completed annotations
118
- completed_file = "./training_annotation_interface/ohca_annotation_completed.xlsx"
119
- annotation_df.to_excel(completed_file, index=False)
120
-
121
- print(f"✅ Mock annotations created: {completed_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # Continue with training
124
- return continue_training_after_annotation(completed_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
- def continue_training_after_annotation(annotation_file):
127
- """Continue training after manual annotation is complete"""
128
 
129
- print("\n🔄 CONTINUING TRAINING AFTER ANNOTATION")
130
- print("="*50)
131
 
132
  # ==========================================================================
133
- # STEP 3: Prepare training data
134
  # ==========================================================================
135
 
136
- print("\n📊 STEP 3: Preparing Training Data")
137
- print("-" * 40)
 
 
 
138
 
139
- # Load completed annotations
140
- labeled_df = pd.read_excel(annotation_file)
 
 
141
 
142
- # Prepare training datasets
143
- train_dataset, val_dataset, train_df, tokenizer = prepare_training_data(labeled_df)
 
 
144
 
145
  # ==========================================================================
146
- # STEP 4: Train the model
147
  # ==========================================================================
148
 
149
- print("\n🏋️ STEP 4: Training Model")
150
- print("-" * 40)
151
 
152
  model, trained_tokenizer = train_ohca_model(
153
  train_dataset=train_dataset,
@@ -155,135 +251,259 @@ def continue_training_after_annotation(annotation_file):
155
  train_df=train_df,
156
  tokenizer=tokenizer,
157
  num_epochs=3,
158
- save_path="./trained_ohca_model"
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  )
160
 
 
 
 
161
  # ==========================================================================
162
- # STEP 5: Evaluate the model
163
  # ==========================================================================
164
 
165
- print("\n📈 STEP 5: Evaluating Model")
166
  print("-" * 40)
167
 
168
- evaluation_results = evaluate_model(
 
 
 
 
 
 
 
 
 
 
 
169
  model=model,
170
- val_dataset=val_dataset,
171
- save_results=True,
172
- results_path="./trained_ohca_model/evaluation_results.txt"
173
  )
174
 
175
  # ==========================================================================
176
- # STEP 6: Training complete summary
177
  # ==========================================================================
178
 
179
- print("\n" + "="*60)
180
- print("🎉 TRAINING COMPLETE!")
181
- print("="*60)
182
 
183
- print(f"📁 Model saved to: ./trained_ohca_model/")
184
- print(f"📊 Evaluation results: ./trained_ohca_model/evaluation_results.txt")
 
 
 
 
 
 
185
 
186
- print(f"\n📈 Performance Summary:")
187
- print(f" AUC-ROC: {evaluation_results['auc']:.3f}")
188
- print(f" F1-Score: {evaluation_results['optimal_metrics']['f1']:.3f}")
189
- print(f" Sensitivity: {evaluation_results['optimal_metrics']['recall']:.1%}")
190
- print(f" Specificity: {evaluation_results['optimal_metrics']['specificity']:.1%}")
191
 
192
- print(f"\n🎯 Next Steps:")
193
- print(f" 1. Review evaluation results")
194
- print(f" 2. Test model on new data using inference module")
195
- print(f" 3. Deploy model for clinical use")
196
- print(f" 4. Consider retraining with more data if needed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  return {
199
- 'model_path': "./trained_ohca_model/",
200
- 'evaluation_results': evaluation_results,
201
- 'training_data_size': len(train_dataset),
202
- 'validation_data_size': len(val_dataset)
 
 
 
 
 
 
 
 
203
  }
204
 
205
- def quick_training_example():
206
- """Simplified training example using the complete pipeline function"""
 
 
 
207
 
208
- print(" Quick Training Pipeline Example")
209
- print("="*40)
 
 
 
 
 
 
210
 
211
- # Use the complete pipeline function
212
- data_path = "path/to/your/discharge_notes.csv"
213
 
214
- # Step 1: Create annotation interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  result = complete_training_pipeline(
216
  data_path=data_path,
217
- annotation_dir="./quick_annotation_interface",
218
- model_save_path="./quick_trained_model"
219
  )
220
 
221
- print(f"Annotation files created:")
222
- print(f" 📄 {result['annotation_file']}")
223
- print(f" 📋 {result['guidelines_file']}")
224
 
225
- # After manual annotation, continue with:
226
- # final_result = complete_annotation_and_train(
227
- # annotation_file=result['annotation_file'],
228
- # model_save_path="./quick_trained_model",
229
- # num_epochs=3
230
- # )
231
 
232
  return result
233
 
234
- def training_tips_and_best_practices():
235
- """Tips for successful OHCA model training"""
236
 
237
- print("💡 OHCA Training Tips & Best Practices")
238
  print("="*45)
239
 
240
- print("\n📋 Data Preparation:")
241
- print(" • Ensure discharge notes are well-cleaned")
242
- print(" • Include diverse hospital systems if possible")
243
- print(" • Minimum 200-300 cases for reliable training")
244
- print(" • Aim for 10-30% OHCA prevalence in sample")
245
-
246
- print("\n🏷️ Annotation Guidelines:")
247
- print(" • Be consistent with OHCA definition")
248
- print(" • Focus on PRIMARY reason for admission")
249
- print(" • Use confidence scores to flag uncertain cases")
250
- print(" Consider inter-annotator agreement for quality")
251
-
252
- print("\n🔧 Model Training:")
253
- print(" • Start with 3 epochs, increase if underfitting")
254
- print(" • Monitor for overfitting in small datasets")
255
- print(" • Consider class balancing for imbalanced data")
256
- print(" • Use validation set to tune hyperparameters")
257
-
258
- print("\n📊 Model Evaluation:")
259
- print(" Prioritize sensitivity (catching OHCA cases)")
260
- print(" Balance sensitivity vs specificity for use case")
261
- print(" AUC > 0.8 indicates good performance")
262
- print(" F1-score > 0.7 suggests balanced performance")
263
-
264
- print("\n🎯 Model Deployment:")
265
- print(" Test on held-out dataset before deployment")
266
- print("Consider probability thresholds for clinical use")
267
- print(" Plan for model monitoring and retraining")
268
- print(" • Document model limitations and scope")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  if __name__ == "__main__":
271
- print("OHCA Training Examples")
272
- print("="*25)
273
 
274
- print("\nChoose an example:")
275
- print("1. Complete training pipeline")
276
- print("2. Quick training example")
277
- print("3. Training tips and best practices")
 
278
 
279
- choice = input("\nEnter choice (1-3): ").strip()
280
 
281
  if choice == "1":
282
- example_training_pipeline()
283
  elif choice == "2":
284
- quick_training_example()
285
  elif choice == "3":
286
- training_tips_and_best_practices()
 
 
287
  else:
288
- print("Running complete training pipeline by default...")
289
- example_training_pipeline()
 
1
  """
2
+ OHCA Training Pipeline Example v3.0 - Improved Methodology
3
 
4
+ This example shows how to train an OHCA classifier using the improved v3.0 methodology
5
+ that addresses data scientist feedback about bias, data leakage, and evaluation issues.
6
  """
7
 
8
  import pandas as pd
9
  import sys
10
  import os
11
 
12
+ # Add src to path for development
13
  sys.path.append('../src')
14
 
15
+ # v3.0 imports - improved functions
16
  from ohca_training_pipeline import (
17
+ # Recommended v3.0 functions
18
+ complete_improved_training_pipeline,
19
+ complete_annotation_and_train_v3,
20
+ create_patient_level_splits,
21
+ find_optimal_threshold,
22
+ evaluate_on_test_set,
23
+ save_model_with_metadata,
24
+
25
+ # Legacy functions (for backward compatibility examples)
26
  create_training_sample,
27
  prepare_training_data,
28
  train_ohca_model,
 
31
  complete_annotation_and_train
32
  )
33
 
34
def improved_training_example():
    """Run the complete v3.0 training demo (RECOMMENDED).

    Steps performed here:

    1. Create a synthetic discharge-note CSV (``hadm_id``, ``subject_id``,
       ``clean_text``) when ``data_path`` does not exist yet.
    2. Call ``complete_improved_training_pipeline`` to build the separate
       training/validation annotation files and the reserved test file.
    3. Print the manual-annotation instructions.
    4. Hand over to ``create_mock_annotations_v3`` so the demo can continue
       without a human annotator.

    Returns:
        The value returned by ``create_mock_annotations_v3`` for the
        generated annotation files.
    """
    # Imported locally: the sample-data generator below needs numpy, which
    # the module-level imports of this example do not provide. Without this
    # the first run (no CSV on disk yet) fails with NameError.
    import numpy as np

    print("OHCA Training Pipeline v3.0 - Improved Methodology Example")
    print("="*65)

    # ==========================================================================
    # STEP 1: Prepare your data with required columns
    # ==========================================================================

    print("\n1. Data Preparation with Patient-Level Information")
    print("-" * 55)

    # Your discharge notes need these columns:
    # - hadm_id: Unique admission identifier
    # - subject_id: Patient identifier (for preventing data leakage)
    # - clean_text: Cleaned discharge note text

    data_path = "enhanced_discharge_notes_v3.csv"

    if not os.path.exists(data_path):
        print("Creating enhanced sample data with patient IDs...")

        # Diverse clinical scenarios; defined once because the list does not
        # change between admissions.
        scenarios = [
            "Chief complaint: Cardiac arrest at home. Patient found down by family members, immediate CPR initiated, EMS transport with ROSC achieved.",
            "Chief complaint: Chest pain. Patient presents with acute onset substernal chest pain, troponins negative, no arrest occurred during stay.",
            "Chief complaint: Shortness of breath. Patient with chronic heart failure exacerbation, treated with diuretics, stable course.",
            "Chief complaint: Found down at work. Witnessed cardiac arrest, coworker CPR, AED shock delivered, transported by EMS.",
            "Chief complaint: Syncope. Patient had brief loss of consciousness, no cardiac arrest, extensive workup negative.",
            "Chief complaint: Transfer for cardiac catheterization. Patient had OHCA at restaurant, bystander CPR, achieved ROSC.",
            "Chief complaint: Diabetes management. Routine admission for hyperglycemia, no acute cardiac events during hospitalization.",
            "Chief complaint: Pneumonia. Community-acquired pneumonia, treated with antibiotics, good clinical response achieved.",
            "Chief complaint: Cardiac arrest in parking garage. Security guard CPR, EMS defibrillation, neurologically intact.",
            "Chief complaint: Routine elective surgery. Planned procedure completed successfully, no complications during stay.",
        ]

        # Create more realistic sample data with patient relationships
        sample_data = []

        # Generate patients with multiple admissions (realistic scenario)
        for patient_id in range(1, 501):  # 500 patients
            # Most patients have 1 admission
            num_admissions = np.random.choice([1, 2, 3], p=[0.7, 0.2, 0.1])
            subject_id = f'SUBJ_{patient_id:04d}'

            for admission in range(num_admissions):
                sample_data.append({
                    'hadm_id': f'HADM_{patient_id:04d}_{admission+1:02d}',
                    'subject_id': subject_id,  # This prevents data leakage
                    'clean_text': np.random.choice(scenarios),
                })

        df = pd.DataFrame(sample_data)
        df.to_csv(data_path, index=False)
        print(f"Enhanced sample data saved to: {data_path}")
        print(f"Created {len(df)} admissions from {df['subject_id'].nunique()} unique patients")

    # ==========================================================================
    # STEP 2: Create patient-level splits and annotation samples
    # ==========================================================================

    print("\n2. Patient-Level Splits and Annotation Sample Creation")
    print("-" * 60)

    df = pd.read_csv(data_path)
    print(f"Loaded {len(df):,} discharge notes from {df['subject_id'].nunique():,} patients")

    # Use improved pipeline that creates proper splits.
    # NOTE(review): the generated demo data averages ~700 admissions
    # (500 patients x 1.4), fewer than the 800 + 200 samples requested here;
    # confirm how the pipeline handles a request larger than the data.
    annotation_result = complete_improved_training_pipeline(
        data_path=data_path,
        annotation_dir="./v3_training_annotation",
        train_sample_size=800,  # Much larger than legacy 264 samples
        val_sample_size=200     # Separate validation sample
    )

    print("\nImproved annotation interface created!")
    print("Key improvements over legacy method:")
    print(" Patient-level splits prevent data leakage")
    print(" Larger training sample (800 vs 264 cases)")
    print(" Separate validation sample (200 cases)")
    print(" Independent test set reserved for final evaluation")

    print("\nFiles created:")
    print(f" Training: {annotation_result['train_annotation_file']}")
    print(f" Validation: {annotation_result['val_annotation_file']}")
    print(f" Guidelines: {annotation_result['guidelines_file']}")
    print(f" Test set: {annotation_result['test_file']} (DO NOT ANNOTATE)")

    # ==========================================================================
    # MANUAL ANNOTATION PHASE (ENHANCED)
    # ==========================================================================

    print("\n" + "="*70)
    print("MANUAL ANNOTATION REQUIRED - v3.0 METHODOLOGY")
    print("="*70)
    print("IMPORTANT CHANGES IN v3.0:")
    print("You now have TWO separate files to annotate:")
    print("1. Training file (800 cases) - Used for model training")
    print("2. Validation file (200 cases) - Used for threshold optimization")
    print()
    print("Before continuing, you need to:")
    print("1. Read guidelines: ./v3_training_annotation/annotation_guidelines_v3.md")
    print("2. Annotate TRAINING file: train_annotation.xlsx")
    print("3. Annotate VALIDATION file: validation_annotation.xlsx")
    print("4. For each case, label: 1=OHCA, 0=Non-OHCA")
    print("5. Fill confidence scores and notes")
    print("6. Save both Excel files")
    print("7. Run continue_v3_training_after_annotation()")
    print()
    print("Key benefits of separate annotation:")
    print(" Prevents threshold tuning bias")
    print(" Allows proper model evaluation")
    print(" Provides unbiased performance estimates")
    print("="*70)

    # Create mock annotations for demonstration
    print("\nCreating mock annotations for demonstration...")
    return create_mock_annotations_v3(annotation_result)
156
+
157
def create_mock_annotations_v3(annotation_result):
    """Fill in the training and validation annotation sheets with mock labels.

    Each sheet named in ``annotation_result`` is read, labelled with
    ``mock_label_function``, given random confidence scores and fixed
    annotator metadata, and written to a ``*_completed.xlsx`` file. Training
    then continues via ``continue_v3_training_after_annotation``.

    Args:
        annotation_result: Mapping that provides ``'train_annotation_file'``,
            ``'val_annotation_file'`` and ``'test_file'``.

    Returns:
        The value returned by ``continue_v3_training_after_annotation``.
    """
    import numpy as np

    def _annotate(source_key, destination):
        # Read one annotation sheet, add the mock columns, save it, and
        # return the frame so the summary below can report on it.
        sheet = pd.read_excel(annotation_result[source_key])
        sheet['ohca_label'] = sheet['clean_text'].apply(mock_label_function)
        sheet['confidence'] = np.random.choice(
            [3, 4, 5], size=len(sheet), p=[0.3, 0.5, 0.2]
        )
        sheet['annotator'] = 'demo_v3'
        sheet['annotation_date'] = '2025-01-01'
        sheet['notes'] = 'Mock annotation for v3.0 demo'
        sheet.to_excel(destination, index=False)
        return sheet

    train_completed = "./v3_training_annotation/train_annotation_completed.xlsx"
    val_completed = "./v3_training_annotation/validation_annotation_completed.xlsx"

    # Training sheet first, then validation.
    train_df = _annotate('train_annotation_file', train_completed)
    val_df = _annotate('val_annotation_file', val_completed)

    summary = (
        "Mock annotations created:",
        f" Training: {train_completed} ({len(train_df)} cases)",
        f" Validation: {val_completed} ({len(val_df)} cases)",
        f" Training OHCA prevalence: {train_df['ohca_label'].mean():.1%}",
        f" Validation OHCA prevalence: {val_df['ohca_label'].mean():.1%}",
    )
    for line in summary:
        print(line)

    # Continue with training
    return continue_v3_training_after_annotation(
        train_completed, val_completed, annotation_result['test_file']
    )
194
+
195
def mock_label_function(text):
    """Return a rule-based mock OHCA label (in practice, done manually).

    A note is labelled OHCA when it mentions an arrest/resuscitation term and
    an out-of-hospital location, and contains none of the exclusion terms.

    Args:
        text: Discharge note text. Any value is accepted; it is converted
            with ``str`` before matching.

    Returns:
        1 for OHCA, 0 for Non-OHCA.
    """
    import re

    text_lower = str(text).lower()

    # Drop explicitly negated arrest mentions first, so "no cardiac arrest"
    # or "no arrest" is not counted as evidence of an arrest.
    text_lower = re.sub(
        r'\b(?:no|without|denies)\s+(?:cardiac\s+)?arrest\b', ' ', text_lower
    )

    # Look for OHCA indicators
    ohca_terms = ['cardiac arrest', 'found down', 'cpr', 'rosc', 'aed shock', 'defibrillation']
    location_terms = ['home', 'work', 'restaurant', 'parking', 'gym', 'public']

    has_arrest = any(term in text_lower for term in ohca_terms)
    # Locations are matched as whole words: plain substring matching would
    # let 'work' hit "workup" or "coworker".
    has_location = any(
        re.search(rf'\b{re.escape(term)}\b', text_lower) for term in location_terms
    )

    # Exclude in-hospital events and non-primary reasons
    exclude_terms = ['transfer', 'routine', 'elective', 'diabetes', 'pneumonia']
    is_excluded = any(term in text_lower for term in exclude_terms)

    return 1 if has_arrest and has_location and not is_excluded else 0
214
 
215
+ def continue_v3_training_after_annotation(train_file, val_file, test_file):
216
+ """Continue v3.0 training after manual annotation is complete"""
217
 
218
+ print(f"\nCONTINUING v3.0 TRAINING AFTER ANNOTATION")
219
+ print("="*55)
220
 
221
  # ==========================================================================
222
+ # STEP 3: Prepare training data from separate files
223
  # ==========================================================================
224
 
225
+ print(f"\n3. Enhanced Data Preparation")
226
+ print("-" * 35)
227
+
228
+ # Use improved data preparation for separate files
229
+ from ohca_training_pipeline import prepare_training_data
230
 
231
+ # This function now handles separate train/val files
232
+ train_dataset, val_dataset, train_df, val_df, tokenizer = prepare_training_data(
233
+ train_file, val_file
234
+ )
235
 
236
+ print(f"Enhanced data preparation complete:")
237
+ print(f" Training samples: {len(train_dataset)} (after balancing)")
238
+ print(f" Validation samples: {len(val_dataset)}")
239
+ print(f" Separate files prevent data leakage")
240
 
241
  # ==========================================================================
242
+ # STEP 4: Train model
243
  # ==========================================================================
244
 
245
+ print(f"\n4. Model Training")
246
+ print("-" * 20)
247
 
248
  model, trained_tokenizer = train_ohca_model(
249
  train_dataset=train_dataset,
 
251
  train_df=train_df,
252
  tokenizer=tokenizer,
253
  num_epochs=3,
254
+ save_path="./trained_ohca_model_v3"
255
+ )
256
+
257
+ # ==========================================================================
258
+ # STEP 5: Find optimal threshold on validation set
259
+ # ==========================================================================
260
+
261
+ print(f"\n5. Optimal Threshold Finding (v3.0 Innovation)")
262
+ print("-" * 55)
263
+
264
+ optimal_threshold, val_metrics = find_optimal_threshold(
265
+ model=model,
266
+ tokenizer=trained_tokenizer,
267
+ val_df=val_df
268
  )
269
 
270
+ print(f"Optimal threshold found: {optimal_threshold:.3f}")
271
+ print(f"This addresses the data scientist's concern about threshold optimization!")
272
+
273
  # ==========================================================================
274
+ # STEP 6: Final evaluation on independent test set
275
  # ==========================================================================
276
 
277
+ print(f"\n6. Unbiased Test Set Evaluation")
278
  print("-" * 40)
279
 
280
+ # Load test set
281
+ test_df = pd.read_csv(test_file)
282
+
283
+ print(f"Independent test set: {len(test_df)} cases")
284
+ print(f"Note: In practice, you would manually annotate a subset of test cases")
285
+ print(f"For demonstration, we'll simulate this step")
286
+
287
+ # In practice, you would manually annotate test cases here
288
+ # For demo, we'll create mock test labels
289
+ test_df['label'] = test_df['clean_text'].apply(mock_label_function)
290
+
291
+ test_metrics = evaluate_on_test_set(
292
  model=model,
293
+ tokenizer=trained_tokenizer,
294
+ test_df=test_df,
295
+ optimal_threshold=optimal_threshold
296
  )
297
 
298
  # ==========================================================================
299
+ # STEP 7: Save model with metadata
300
  # ==========================================================================
301
 
302
+ print(f"\n7. Enhanced Model Saving with Metadata")
303
+ print("-" * 45)
 
304
 
305
+ save_model_with_metadata(
306
+ model=model,
307
+ tokenizer=trained_tokenizer,
308
+ optimal_threshold=optimal_threshold,
309
+ val_metrics=val_metrics,
310
+ test_metrics=test_metrics,
311
+ model_save_path="./trained_ohca_model_v3"
312
+ )
313
 
314
+ # ==========================================================================
315
+ # STEP 8: Training complete summary
316
+ # ==========================================================================
 
 
317
 
318
+ print(f"\n" + "="*70)
319
+ print("v3.0 TRAINING COMPLETE - METHODOLOGY IMPROVEMENTS IMPLEMENTED")
320
+ print("="*70)
321
+
322
+ print(f"Model and metadata saved to: ./trained_ohca_model_v3/")
323
+
324
+ print(f"\nPerformance Summary (Unbiased Evaluation):")
325
+ print(f" Validation F1-Score: {val_metrics['f1_score']:.3f}")
326
+ print(f" Validation Sensitivity: {val_metrics['sensitivity']:.1%}")
327
+ print(f" Validation Specificity: {val_metrics['specificity']:.1%}")
328
+ print(f" Test Accuracy: {test_metrics['test_accuracy']:.1%}")
329
+ print(f" Test F1-Score: {test_metrics['test_f1_score']:.3f}")
330
+
331
+ print(f"\nv3.0 Improvements Implemented:")
332
+ print(f" Patient-level splits prevent data leakage")
333
+ print(f" Proper train/validation/test methodology")
334
+ print(f" Optimal threshold: {optimal_threshold:.3f} (saved with model)")
335
+ print(f" Larger training set: {len(train_dataset)} samples")
336
+ print(f" Unbiased evaluation on independent test set")
337
+ print(f" Enhanced metadata and model versioning")
338
+
339
+ print(f"\nNext Steps:")
340
+ print(f" 1. Model automatically uses optimal threshold during inference")
341
+ print(f" 2. Enhanced clinical decision support available")
342
+ print(f" 3. Use quick_inference_with_optimal_threshold() for new data")
343
+ print(f" 4. Monitor performance and retrain as needed")
344
 
345
  return {
346
+ 'model_path': "./trained_ohca_model_v3/",
347
+ 'optimal_threshold': optimal_threshold,
348
+ 'val_metrics': val_metrics,
349
+ 'test_metrics': test_metrics,
350
+ 'training_methodology': 'v3.0',
351
+ 'improvements_implemented': [
352
+ 'Patient-level data splits',
353
+ 'Separate train/validation annotation',
354
+ 'Optimal threshold optimization',
355
+ 'Independent test set evaluation',
356
+ 'Enhanced model metadata'
357
+ ]
358
  }
359
 
360
def legacy_training_example():
    """Run the legacy (pre-v3.0) pipeline, kept for comparison only.

    Prints a warning about the old methodology, writes a small synthetic CSV
    when ``legacy_discharge_notes.csv`` is missing, and calls
    ``complete_training_pipeline`` to create the legacy annotation file.

    Returns:
        The result returned by ``complete_training_pipeline``.
    """
    banner = (
        "Legacy Training Pipeline Example (for comparison)",
        "="*55,
        "WARNING: This demonstrates the OLD methodology with known issues:",
        " Small sample size (330 total, 264 training)",
        " No patient-level splits (data leakage possible)",
        " Threshold optimization on same validation set used for evaluation",
        " No independent test set",
        "",
        "This is maintained for backward compatibility only.",
        "RECOMMENDATION: Use improved_training_example() instead!",
    )
    for line in banner:
        print(line)

    data_path = "legacy_discharge_notes.csv"

    # Create simple legacy data
    if not os.path.exists(data_path):
        note_templates = [
            "Chief complaint: Cardiac arrest at home.",
            "Chief complaint: Chest pain, no arrest.",
            "Chief complaint: Found down, cardiac arrest.",
            "Chief complaint: Shortness of breath.",
            "Chief complaint: Syncope, no arrest.",
        ]
        admission_ids = [f'LEG_{i:06d}' for i in range(1000)]
        legacy_frame = pd.DataFrame({
            'hadm_id': admission_ids,
            'clean_text': note_templates * 200,  # 5 templates x 200 = 1000 rows
        })
        legacy_frame.to_csv(data_path, index=False)

    # Use legacy pipeline
    result = complete_training_pipeline(
        data_path=data_path,
        annotation_dir="./legacy_annotation",
        model_save_path="./legacy_trained_model",
    )

    print(f"Legacy annotation file created: {result['annotation_file']}")
    print("Annotation sample size: 330 cases (small compared to v3.0's 1000)")

    print("\nLegacy method limitations:")
    for limitation in (
        " Single annotation file instead of separate train/val",
        " No optimal threshold finding",
        " No patient-level data protection",
        " Biased evaluation methodology",
    ):
        print(limitation)

    return result
408
 
409
def methodology_comparison():
    """Print a side-by-side comparison of the legacy and v3.0 methodologies.

    Output only: a comparison table, the list of concerns addressed in v3.0,
    and a closing recommendation. Returns ``None``.
    """
    print("v3.0 vs Legacy Methodology Comparison")
    print("="*45)

    table_rows = (
        "Aspect | Legacy Method | v3.0 Improved Method",
        "----------------------- | ---------------------- | ----------------------",
        "Sample Size | 330 total (264 train) | 1000+ total (800 train)",
        "Data Splits | Random note-level | Patient-level splits",
        "Annotation Files | 1 file (biased) | 2 files (unbiased)",
        "Threshold Selection | Static 0.5 or manual | Optimal from validation",
        "Evaluation | Same set for tuning | Independent test set",
        "Data Leakage Risk | High (same patients) | Prevented (patient-level)",
        "Performance Reliability| Inflated estimates | Unbiased estimates",
        "Clinical Integration | Basic confidence | Enhanced priorities",
        "Model Metadata | Limited | Comprehensive",
        "Methodology Validation | None | Peer-reviewed approach",
    )
    # The table text starts and ends with a newline, like a triple-quoted block.
    comparison_table = "\n" + "\n".join(table_rows) + "\n"
    print(comparison_table)

    closing_lines = (
        "\nKey Data Scientist Concerns Addressed in v3.0:",
        "1. BIAS: Patient-level splits prevent data leakage",
        "2. SAMPLE SIZE: 800 training cases vs 264 in legacy",
        "3. EVALUATION: Independent test set prevents threshold tuning bias",
        "4. THRESHOLD CONSISTENCY: Optimal threshold saved and reused",
        "5. METHODOLOGY: Follows ML best practices",
        "\nRecommendation:",
        " Use v3.0 methodology for all new model training",
        " Consider retraining legacy models with v3.0 approach",
        " Legacy functions maintained for backward compatibility only",
    )
    for line in closing_lines:
        print(line)
443
+
444
def training_best_practices_v3():
    """Print the best-practices guide for the v3.0 methodology.

    Output only: a title followed by one titled section per topic, each
    listing its recommendations. Returns ``None``.
    """
    sections = (
        ("Data Preparation (Enhanced)", (
            "Ensure you have patient IDs (subject_id column)",
            "Minimum 500+ unique patients for robust splits",
            "Clean and standardize discharge note text",
            "Include diverse hospital systems if possible",
        )),
        ("Annotation Strategy (v3.0)", (
            "Annotate BOTH training and validation files separately",
            "Training sample: 800+ cases for better performance",
            "Validation sample: 200+ cases for reliable threshold optimization",
            "Reserve test set for final unbiased evaluation",
            "Use consistent OHCA definition across all annotators",
        )),
        ("Model Training (Improved)", (
            "Patient-level splits prevent data leakage",
            "Class balancing handles imbalanced datasets",
            "Monitor training loss to prevent overfitting",
            "Use validation set only for threshold optimization",
        )),
        ("Model Evaluation (Unbiased)", (
            "Find optimal threshold on validation set",
            "Report final performance on independent test set",
            "Never use test set for model selection or tuning",
            "Focus on clinical metrics (sensitivity, specificity)",
        )),
        ("Deployment (Enhanced)", (
            "Model automatically uses optimal threshold",
            "Enhanced clinical decision support built-in",
            "Comprehensive model metadata for tracking",
            "Plan for continuous model monitoring",
        )),
        ("Quality Assurance", (
            "Validate performance on external datasets",
            "Monitor for distribution drift in new data",
            "Regular retraining with new annotated cases",
            "Document all methodology improvements",
        )),
    )

    print("OHCA Training Best Practices - v3.0 Methodology")
    print("="*55)

    for heading, tips in sections:
        # Blank line before each heading; every tip is indented by one space.
        print(f"\n{heading}:")
        for tip in tips:
            print(f" {tip}")
486
 
487
if __name__ == "__main__":
    # Interactive entry point: show the menu, read a choice, run the example.
    print("OHCA Training Examples v3.0 - Improved Methodology")
    print("="*55)

    menu_lines = (
        "\nAvailable examples:",
        "1. v3.0 Training with Improved Methodology (RECOMMENDED)",
        "2. Legacy Training (backward compatibility)",
        "3. Methodology Comparison (v3.0 vs Legacy)",
        "4. v3.0 Best Practices Guide",
    )
    for menu_line in menu_lines:
        print(menu_line)

    choice = input("\nEnter choice (1-4): ").strip()

    # Map each menu entry to the example it runs.
    examples = {
        "1": improved_training_example,
        "2": legacy_training_example,
        "3": methodology_comparison,
        "4": training_best_practices_v3,
    }
    selected = examples.get(choice)
    if selected is None:
        # Any other input falls back to the recommended v3.0 example.
        print("Running v3.0 training example by default...")
        selected = improved_training_example
    selected()