|
|
""" |
|
|
OHCA Training Pipeline Example v3.0 - Improved Methodology |
|
|
|
|
|
This example shows how to train an OHCA classifier using the improved v3.0 methodology |
|
|
that addresses data scientist feedback about bias, data leakage, and evaluation issues. |
|
|
""" |
|
|
|
|
|
import os
import sys

import numpy as np
import pandas as pd
|
|
|
|
|
|
|
|
sys.path.append('../src') |
|
|
|
|
|
|
|
|
from ohca_training_pipeline import ( |
|
|
|
|
|
complete_improved_training_pipeline, |
|
|
complete_annotation_and_train_v3, |
|
|
create_patient_level_splits, |
|
|
find_optimal_threshold, |
|
|
evaluate_on_test_set, |
|
|
save_model_with_metadata, |
|
|
|
|
|
|
|
|
create_training_sample, |
|
|
prepare_training_data, |
|
|
train_ohca_model, |
|
|
evaluate_model, |
|
|
complete_training_pipeline, |
|
|
complete_annotation_and_train |
|
|
) |
|
|
|
|
|
def improved_training_example():
    """Complete example using v3.0 methodology (RECOMMENDED).

    Walks through the improved workflow end to end:
      1. Build (or reuse) synthetic discharge-note data keyed by patient ID.
      2. Create patient-level splits plus separate train/validation
         annotation files via ``complete_improved_training_pipeline``.
      3. Hand off to mock annotation, which in turn runs the rest of the
         training pipeline for demonstration purposes.

    Returns:
        dict: The result of ``continue_v3_training_after_annotation`` as
        propagated through ``create_mock_annotations_v3``.
    """
    print("OHCA Training Pipeline v3.0 - Improved Methodology Example")
    print("="*65)

    # ---- Step 1: data preparation with patient identifiers ----
    print("\n1. Data Preparation with Patient-Level Information")
    print("-" * 55)

    data_path = "enhanced_discharge_notes_v3.csv"

    if not os.path.exists(data_path):
        print("Creating enhanced sample data with patient IDs...")

        # Hoisted out of the loops: this list is loop-invariant, so building
        # it once avoids re-creating it for every simulated admission.
        # Mix of OHCA-positive and OHCA-negative narratives.
        scenarios = [
            "Chief complaint: Cardiac arrest at home. Patient found down by family members, immediate CPR initiated, EMS transport with ROSC achieved.",
            "Chief complaint: Chest pain. Patient presents with acute onset substernal chest pain, troponins negative, no arrest occurred during stay.",
            "Chief complaint: Shortness of breath. Patient with chronic heart failure exacerbation, treated with diuretics, stable course.",
            "Chief complaint: Found down at work. Witnessed cardiac arrest, coworker CPR, AED shock delivered, transported by EMS.",
            "Chief complaint: Syncope. Patient had brief loss of consciousness, no cardiac arrest, extensive workup negative.",
            "Chief complaint: Transfer for cardiac catheterization. Patient had OHCA at restaurant, bystander CPR, achieved ROSC.",
            "Chief complaint: Diabetes management. Routine admission for hyperglycemia, no acute cardiac events during hospitalization.",
            "Chief complaint: Pneumonia. Community-acquired pneumonia, treated with antibiotics, good clinical response achieved.",
            "Chief complaint: Cardiac arrest in parking garage. Security guard CPR, EMS defibrillation, neurologically intact.",
            "Chief complaint: Routine elective surgery. Planned procedure completed successfully, no complications during stay."
        ]

        sample_data = []

        # 500 synthetic patients; most have a single admission so the data
        # exercises the patient-level (not note-level) split logic.
        for patient_id in range(1, 501):
            num_admissions = np.random.choice([1, 2, 3], p=[0.7, 0.2, 0.1])

            for admission in range(num_admissions):
                hadm_id = f'HADM_{patient_id:04d}_{admission+1:02d}'
                subject_id = f'SUBJ_{patient_id:04d}'

                text = np.random.choice(scenarios)

                sample_data.append({
                    'hadm_id': hadm_id,
                    'subject_id': subject_id,
                    'clean_text': text
                })

        df = pd.DataFrame(sample_data)
        df.to_csv(data_path, index=False)
        print(f"Enhanced sample data saved to: {data_path}")
        print(f"Created {len(df)} admissions from {df['subject_id'].nunique()} unique patients")

    # ---- Step 2: patient-level splits and annotation files ----
    print(f"\n2. Patient-Level Splits and Annotation Sample Creation")
    print("-" * 60)

    df = pd.read_csv(data_path)
    print(f"Loaded {len(df):,} discharge notes from {df['subject_id'].nunique():,} patients")

    # Creates train/validation annotation workbooks plus a held-out test set.
    annotation_result = complete_improved_training_pipeline(
        data_path=data_path,
        annotation_dir="./v3_training_annotation",
        train_sample_size=800,
        val_sample_size=200
    )

    print(f"\nImproved annotation interface created!")
    print(f"Key improvements over legacy method:")
    print(f" Patient-level splits prevent data leakage")
    print(f" Larger training sample (800 vs 264 cases)")
    print(f" Separate validation sample (200 cases)")
    print(f" Independent test set reserved for final evaluation")

    print(f"\nFiles created:")
    print(f" Training: {annotation_result['train_annotation_file']}")
    print(f" Validation: {annotation_result['val_annotation_file']}")
    print(f" Guidelines: {annotation_result['guidelines_file']}")
    print(f" Test set: {annotation_result['test_file']} (DO NOT ANNOTATE)")

    # ---- Manual-annotation instructions (informational banner) ----
    print("\n" + "="*70)
    print("MANUAL ANNOTATION REQUIRED - v3.0 METHODOLOGY")
    print("="*70)
    print("IMPORTANT CHANGES IN v3.0:")
    print("You now have TWO separate files to annotate:")
    print("1. Training file (800 cases) - Used for model training")
    print("2. Validation file (200 cases) - Used for threshold optimization")
    print()
    print("Before continuing, you need to:")
    print("1. Read guidelines: ./v3_training_annotation/annotation_guidelines_v3.md")
    print("2. Annotate TRAINING file: train_annotation.xlsx")
    print("3. Annotate VALIDATION file: validation_annotation.xlsx")
    print("4. For each case, label: 1=OHCA, 0=Non-OHCA")
    print("5. Fill confidence scores and notes")
    print("6. Save both Excel files")
    print("7. Run continue_v3_training_after_annotation()")
    print()
    print("Key benefits of separate annotation:")
    print(" Prevents threshold tuning bias")
    print(" Allows proper model evaluation")
    print(" Provides unbiased performance estimates")
    print("="*70)

    # For the demo, skip manual annotation and auto-label instead.
    print(f"\nCreating mock annotations for demonstration...")
    return create_mock_annotations_v3(annotation_result)
|
|
|
|
|
def create_mock_annotations_v3(annotation_result):
    """Create mock annotations for both training and validation files.

    In a real project these labels come from manual chart review; for the
    demo we auto-label each note with ``mock_label_function`` and fill in
    plausible annotation metadata.

    Args:
        annotation_result: dict returned by
            ``complete_improved_training_pipeline`` containing the
            'train_annotation_file', 'val_annotation_file' and 'test_file'
            paths.

    Returns:
        dict: The result of ``continue_v3_training_after_annotation``.
    """

    def _mock_annotate(source_path, output_path):
        """Auto-label one annotation workbook and save the completed copy."""
        frame = pd.read_excel(source_path)
        frame['ohca_label'] = frame['clean_text'].apply(mock_label_function)
        # Random but skewed-toward-confident scores, mimicking a reviewer.
        frame['confidence'] = np.random.choice([3, 4, 5], size=len(frame), p=[0.3, 0.5, 0.2])
        frame['annotator'] = 'demo_v3'
        frame['annotation_date'] = '2025-01-01'
        frame['notes'] = 'Mock annotation for v3.0 demo'
        frame.to_excel(output_path, index=False)
        return frame

    # Training and validation files are annotated identically but kept
    # separate so threshold tuning stays unbiased downstream.
    train_completed = "./v3_training_annotation/train_annotation_completed.xlsx"
    train_df = _mock_annotate(annotation_result['train_annotation_file'], train_completed)

    val_completed = "./v3_training_annotation/validation_annotation_completed.xlsx"
    val_df = _mock_annotate(annotation_result['val_annotation_file'], val_completed)

    print(f"Mock annotations created:")
    print(f" Training: {train_completed} ({len(train_df)} cases)")
    print(f" Validation: {val_completed} ({len(val_df)} cases)")
    print(f" Training OHCA prevalence: {train_df['ohca_label'].mean():.1%}")
    print(f" Validation OHCA prevalence: {val_df['ohca_label'].mean():.1%}")

    return continue_v3_training_after_annotation(
        train_completed, val_completed, annotation_result['test_file']
    )
|
|
|
|
|
def mock_label_function(text):
    """Simple rule-based mock labeling (in practice, done manually)"""
    lowered = str(text).lower()

    # Keyword groups: an OHCA label requires both an arrest cue and an
    # out-of-hospital setting cue, with no disqualifying context.
    arrest_keywords = ('cardiac arrest', 'found down', 'cpr', 'rosc', 'aed shock', 'defibrillation')
    setting_keywords = ('home', 'work', 'restaurant', 'parking', 'gym', 'public')
    disqualifiers = ('transfer', 'routine', 'elective', 'diabetes', 'pneumonia')

    mentions_arrest = any(kw in lowered for kw in arrest_keywords)
    mentions_setting = any(kw in lowered for kw in setting_keywords)
    disqualified = any(kw in lowered for kw in disqualifiers)

    return 1 if (mentions_arrest and mentions_setting and not disqualified) else 0
|
|
|
|
|
def continue_v3_training_after_annotation(train_file, val_file, test_file):
    """Continue v3.0 training after manual annotation is complete.

    Runs the remaining pipeline stages: data preparation, model training,
    threshold optimization on the validation set, unbiased evaluation on the
    independent test set, and model saving with metadata.

    Args:
        train_file: Path to the completed training annotation workbook.
        val_file: Path to the completed validation annotation workbook.
        test_file: Path to the reserved (unannotated) test-set CSV.

    Returns:
        dict: Summary with model path, optimal threshold, validation/test
        metrics, and the list of methodology improvements applied.
    """
    print(f"\nCONTINUING v3.0 TRAINING AFTER ANNOTATION")
    print("="*55)

    # ---- Step 3: data preparation from the two annotation files ----
    print(f"\n3. Enhanced Data Preparation")
    print("-" * 35)

    # NOTE: the redundant function-local re-import of prepare_training_data
    # was removed; it is already imported at module level.
    train_dataset, val_dataset, train_df, val_df, tokenizer = prepare_training_data(
        train_file, val_file
    )

    print(f"Enhanced data preparation complete:")
    print(f" Training samples: {len(train_dataset)} (after balancing)")
    print(f" Validation samples: {len(val_dataset)}")
    print(f" Separate files prevent data leakage")

    # ---- Step 4: model training ----
    print(f"\n4. Model Training")
    print("-" * 20)

    model, trained_tokenizer = train_ohca_model(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        train_df=train_df,
        tokenizer=tokenizer,
        num_epochs=3,
        save_path="./trained_ohca_model_v3"
    )

    # ---- Step 5: threshold tuning on the validation split only ----
    print(f"\n5. Optimal Threshold Finding (v3.0 Innovation)")
    print("-" * 55)

    optimal_threshold, val_metrics = find_optimal_threshold(
        model=model,
        tokenizer=trained_tokenizer,
        val_df=val_df
    )

    print(f"Optimal threshold found: {optimal_threshold:.3f}")
    print(f"This addresses the data scientist's concern about threshold optimization!")

    # ---- Step 6: evaluation on the independent test set ----
    print(f"\n6. Unbiased Test Set Evaluation")
    print("-" * 40)

    test_df = pd.read_csv(test_file)

    print(f"Independent test set: {len(test_df)} cases")
    print(f"Note: In practice, you would manually annotate a subset of test cases")
    print(f"For demonstration, we'll simulate this step")

    # Demo shortcut: simulate manual test-set annotation with the mock rule.
    test_df['label'] = test_df['clean_text'].apply(mock_label_function)

    test_metrics = evaluate_on_test_set(
        model=model,
        tokenizer=trained_tokenizer,
        test_df=test_df,
        optimal_threshold=optimal_threshold
    )

    # ---- Step 7: persist model together with threshold and metrics ----
    print(f"\n7. Enhanced Model Saving with Metadata")
    print("-" * 45)

    save_model_with_metadata(
        model=model,
        tokenizer=trained_tokenizer,
        optimal_threshold=optimal_threshold,
        val_metrics=val_metrics,
        test_metrics=test_metrics,
        model_save_path="./trained_ohca_model_v3"
    )

    # ---- Summary ----
    print(f"\n" + "="*70)
    print("v3.0 TRAINING COMPLETE - METHODOLOGY IMPROVEMENTS IMPLEMENTED")
    print("="*70)

    print(f"Model and metadata saved to: ./trained_ohca_model_v3/")

    print(f"\nPerformance Summary (Unbiased Evaluation):")
    print(f" Validation F1-Score: {val_metrics['f1_score']:.3f}")
    print(f" Validation Sensitivity: {val_metrics['sensitivity']:.1%}")
    print(f" Validation Specificity: {val_metrics['specificity']:.1%}")
    print(f" Test Accuracy: {test_metrics['test_accuracy']:.1%}")
    print(f" Test F1-Score: {test_metrics['test_f1_score']:.3f}")

    print(f"\nv3.0 Improvements Implemented:")
    print(f" Patient-level splits prevent data leakage")
    print(f" Proper train/validation/test methodology")
    print(f" Optimal threshold: {optimal_threshold:.3f} (saved with model)")
    print(f" Larger training set: {len(train_dataset)} samples")
    print(f" Unbiased evaluation on independent test set")
    print(f" Enhanced metadata and model versioning")

    print(f"\nNext Steps:")
    print(f" 1. Model automatically uses optimal threshold during inference")
    print(f" 2. Enhanced clinical decision support available")
    print(f" 3. Use quick_inference_with_optimal_threshold() for new data")
    print(f" 4. Monitor performance and retrain as needed")

    return {
        'model_path': "./trained_ohca_model_v3/",
        'optimal_threshold': optimal_threshold,
        'val_metrics': val_metrics,
        'test_metrics': test_metrics,
        'training_methodology': 'v3.0',
        'improvements_implemented': [
            'Patient-level data splits',
            'Separate train/validation annotation',
            'Optimal threshold optimization',
            'Independent test set evaluation',
            'Enhanced model metadata'
        ]
    }
|
|
|
|
|
def legacy_training_example():
    """Legacy training example for comparison/backward compatibility"""
    # Warning banner: printed line-by-line from a single tuple.
    banner = (
        "Legacy Training Pipeline Example (for comparison)",
        "=" * 55,
        "WARNING: This demonstrates the OLD methodology with known issues:",
        " Small sample size (330 total, 264 training)",
        " No patient-level splits (data leakage possible)",
        " Threshold optimization on same validation set used for evaluation",
        " No independent test set",
        "",
        "This is maintained for backward compatibility only.",
        "RECOMMENDATION: Use improved_training_example() instead!",
    )
    for line in banner:
        print(line)

    data_path = "legacy_discharge_notes.csv"

    # Build the small legacy demo dataset only if it is not already on disk.
    if not os.path.exists(data_path):
        base_notes = [
            "Chief complaint: Cardiac arrest at home.",
            "Chief complaint: Chest pain, no arrest.",
            "Chief complaint: Found down, cardiac arrest.",
            "Chief complaint: Shortness of breath.",
            "Chief complaint: Syncope, no arrest.",
        ]
        frame = pd.DataFrame({
            'hadm_id': ['LEG_%06d' % idx for idx in range(1000)],
            'clean_text': base_notes * 200,
        })
        frame.to_csv(data_path, index=False)

    # Run the legacy single-file pipeline (kept for comparison only).
    result = complete_training_pipeline(
        data_path=data_path,
        annotation_dir="./legacy_annotation",
        model_save_path="./legacy_trained_model"
    )

    print(f"Legacy annotation file created: {result['annotation_file']}")
    print("Annotation sample size: 330 cases (small compared to v3.0's 1000)")

    for line in (
        "\nLegacy method limitations:",
        " Single annotation file instead of separate train/val",
        " No optimal threshold finding",
        " No patient-level data protection",
        " Biased evaluation methodology",
    ):
        print(line)

    return result
|
|
|
|
|
def methodology_comparison():
    """Compare v3.0 vs legacy methodologies side by side"""
    print("v3.0 vs Legacy Methodology Comparison")
    print("=" * 45)

    # Static side-by-side table; printed verbatim.
    comparison_table = """
Aspect | Legacy Method | v3.0 Improved Method
----------------------- | ---------------------- | ----------------------
Sample Size | 330 total (264 train) | 1000+ total (800 train)
Data Splits | Random note-level | Patient-level splits
Annotation Files | 1 file (biased) | 2 files (unbiased)
Threshold Selection | Static 0.5 or manual | Optimal from validation
Evaluation | Same set for tuning | Independent test set
Data Leakage Risk | High (same patients) | Prevented (patient-level)
Performance Reliability| Inflated estimates | Unbiased estimates
Clinical Integration | Basic confidence | Enhanced priorities
Model Metadata | Limited | Comprehensive
Methodology Validation | None | Peer-reviewed approach
"""
    print(comparison_table)

    for line in (
        "\nKey Data Scientist Concerns Addressed in v3.0:",
        "1. BIAS: Patient-level splits prevent data leakage",
        "2. SAMPLE SIZE: 800 training cases vs 264 in legacy",
        "3. EVALUATION: Independent test set prevents threshold tuning bias",
        "4. THRESHOLD CONSISTENCY: Optimal threshold saved and reused",
        "5. METHODOLOGY: Follows ML best practices",
    ):
        print(line)

    for line in (
        "\nRecommendation:",
        " Use v3.0 methodology for all new model training",
        " Consider retraining legacy models with v3.0 approach",
        " Legacy functions maintained for backward compatibility only",
    ):
        print(line)
|
|
def training_best_practices_v3():
    """Updated best practices for v3.0 methodology"""
    print("OHCA Training Best Practices - v3.0 Methodology")
    print("=" * 55)

    # Each section is a (heading, tips) pair, printed in order.
    sections = (
        ("\nData Preparation (Enhanced):", (
            " Ensure you have patient IDs (subject_id column)",
            " Minimum 500+ unique patients for robust splits",
            " Clean and standardize discharge note text",
            " Include diverse hospital systems if possible",
        )),
        ("\nAnnotation Strategy (v3.0):", (
            " Annotate BOTH training and validation files separately",
            " Training sample: 800+ cases for better performance",
            " Validation sample: 200+ cases for reliable threshold optimization",
            " Reserve test set for final unbiased evaluation",
            " Use consistent OHCA definition across all annotators",
        )),
        ("\nModel Training (Improved):", (
            " Patient-level splits prevent data leakage",
            " Class balancing handles imbalanced datasets",
            " Monitor training loss to prevent overfitting",
            " Use validation set only for threshold optimization",
        )),
        ("\nModel Evaluation (Unbiased):", (
            " Find optimal threshold on validation set",
            " Report final performance on independent test set",
            " Never use test set for model selection or tuning",
            " Focus on clinical metrics (sensitivity, specificity)",
        )),
        ("\nDeployment (Enhanced):", (
            " Model automatically uses optimal threshold",
            " Enhanced clinical decision support built-in",
            " Comprehensive model metadata for tracking",
            " Plan for continuous model monitoring",
        )),
        ("\nQuality Assurance:", (
            " Validate performance on external datasets",
            " Monitor for distribution drift in new data",
            " Regular retraining with new annotated cases",
            " Document all methodology improvements",
        )),
    )

    for heading, tips in sections:
        print(heading)
        for tip in tips:
            print(tip)
|
|
|
|
if __name__ == "__main__":
    print("OHCA Training Examples v3.0 - Improved Methodology")
    print("=" * 55)

    for line in (
        "\nAvailable examples:",
        "1. v3.0 Training with Improved Methodology (RECOMMENDED)",
        "2. Legacy Training (backward compatibility)",
        "3. Methodology Comparison (v3.0 vs Legacy)",
        "4. v3.0 Best Practices Guide",
    ):
        print(line)

    choice = input("\nEnter choice (1-4): ").strip()

    # Dispatch table replaces the if/elif chain; any unrecognized input
    # falls back to the recommended v3.0 example.
    examples = {
        "1": improved_training_example,
        "2": legacy_training_example,
        "3": methodology_comparison,
        "4": training_best_practices_v3,
    }
    selected = examples.get(choice)
    if selected is not None:
        selected()
    else:
        print("Running v3.0 training example by default...")
        improved_training_example()
|
|