File size: 6,071 Bytes
493b03a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
#!/usr/bin/env python3
"""
Train an OHCA (out-of-hospital cardiac arrest) classifier from pre-labeled data.
This script trains a v3.0 OHCA classifier using your manually labeled discharge notes.
The input CSV must have the columns hadm_id, clean_text and ohca_label (0 or 1);
subject_id and confidence are optional and are filled in with defaults when absent.
"""
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
import pandas as pd
from sklearn.model_selection import train_test_split
from ohca_training_pipeline import prepare_training_data, train_ohca_model, find_optimal_threshold, save_model_with_metadata
def validate_labeled_data(df):
    """Validate that the labeled data has the required columns and labels.

    Prints a short summary of the label distribution when validation passes.

    Args:
        df: pandas DataFrame loaded from the labeled CSV.

    Raises:
        ValueError: if a required column (hadm_id, clean_text, ohca_label)
            is missing, the DataFrame has no rows, or ohca_label contains
            any value other than 0 or 1 (missing labels included).
    """
    required_cols = ['hadm_id', 'clean_text', 'ohca_label']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    # An empty frame would otherwise "pass" and report a NaN prevalence,
    # only to fail later during the train/validation split.
    if df.empty:
        raise ValueError("Labeled data contains no rows")

    # Collect only the offending label values. pd.isna is checked first so
    # missing labels (NaN / pd.NA) are reported instead of raising on the
    # ambiguous truth value of pd.NA in the membership test.
    labels = df['ohca_label']
    invalid_labels = [
        value for value in labels.unique()
        if pd.isna(value) or value not in (0, 1)
    ]
    if invalid_labels:
        raise ValueError(
            f"ohca_label must be 0 or 1, found invalid values: {invalid_labels}"
        )

    # Every label is now 0 or 1, so the non-OHCA count is the remainder.
    n_total = len(df)
    n_ohca = int((labels == 1).sum())
    print("Data validation passed:")
    print(f" Total cases: {n_total}")
    print(f" OHCA cases (label=1): {n_ohca}")
    print(f" Non-OHCA cases (label=0): {n_total - n_ohca}")
    print(f" OHCA prevalence: {n_ohca / n_total:.1%}")
def train_from_labeled_data(data_path, model_save_path="./trained_ohca_model",
                            test_size=0.2, num_epochs=3):
    """
    Train OHCA model from pre-labeled data.

    Loads the labeled CSV, validates it, makes a stratified train/validation
    split, trains via the ohca_training_pipeline helpers, tunes the decision
    threshold on the validation split, and saves the model with metadata.

    Args:
        data_path: Path to CSV with labeled data
        model_save_path: Where to save the trained model
        test_size: Fraction to use for validation (default 0.2 = 20%)
        num_epochs: Number of training epochs

    Returns:
        dict with keys 'model_path', 'optimal_threshold' and 'metrics'
        (the validation metrics returned by find_optimal_threshold).

    Raises:
        ValueError: if validate_labeled_data rejects the data (e.g. missing
            required columns or labels other than 0/1).
    """
    import tempfile

    print("OHCA Classifier Training from Pre-labeled Data")
    print("=" * 50)

    # Load and validate data
    print(f"Loading labeled data from: {data_path}")
    df = pd.read_csv(data_path)

    # Validate before touching any columns. The subject_id fallback below
    # reads df['hadm_id'], so without this a missing hadm_id column would
    # surface as a bare KeyError instead of the descriptive ValueError.
    validate_labeled_data(df)

    # Add optional columns if needed (presumably expected downstream by
    # prepare_training_data - TODO confirm against ohca_training_pipeline)
    if 'subject_id' not in df.columns:
        print("Adding subject_id column (using hadm_id as patient ID)")
        df['subject_id'] = df['hadm_id']
    if 'confidence' not in df.columns:
        print("Adding default confidence scores")
        df['confidence'] = 4  # Default confidence

    # Split into train/validation, stratified on the label so both splits
    # keep approximately the same OHCA prevalence; fixed seed for
    # reproducibility.
    print(f"\nSplitting data (train: {1-test_size:.0%}, validation: {test_size:.0%})")
    train_df, val_df = train_test_split(
        df, test_size=test_size,
        stratify=df['ohca_label'],
        random_state=42
    )
    print(f"Training data: {len(train_df)} cases ({(train_df['ohca_label']==1).sum()} OHCA)")
    print(f"Validation data: {len(val_df)} cases ({(val_df['ohca_label']==1).sum()} OHCA)")

    # prepare_training_data takes file paths, so the splits are written as
    # Excel files. A private temporary directory (rather than fixed names in
    # the current working directory) avoids overwriting user files and
    # collisions between concurrent runs, and is removed even if writing
    # either file or any training step fails.
    with tempfile.TemporaryDirectory(prefix="ohca_train_") as temp_dir:
        temp_train = os.path.join(temp_dir, 'temp_train_data.xlsx')
        temp_val = os.path.join(temp_dir, 'temp_val_data.xlsx')
        train_df.to_excel(temp_train, index=False)
        val_df.to_excel(temp_val, index=False)

        # Prepare training datasets
        print("\nPreparing training datasets...")
        train_dataset, val_dataset, train_df_balanced, val_df_clean, tokenizer = prepare_training_data(
            temp_train, temp_val
        )

        # Train the model
        print(f"\nTraining model for {num_epochs} epochs...")
        model, trained_tokenizer = train_ohca_model(
            train_dataset, val_dataset, train_df_balanced, tokenizer,
            num_epochs=num_epochs,
            save_path=model_save_path
        )

        # Find optimal threshold on the cleaned validation split
        print("\nFinding optimal threshold...")
        optimal_threshold, val_metrics = find_optimal_threshold(
            model, trained_tokenizer, val_df_clean
        )

        # Save model with metadata. There is no separate held-out test set
        # in this workflow, which is recorded explicitly in the metadata.
        print("\nSaving model with metadata...")
        test_metrics = {'message': 'Trained on user-provided labeled data', 'test_set_size': 0}
        save_model_with_metadata(
            model, trained_tokenizer, optimal_threshold,
            val_metrics, test_metrics, model_save_path
        )

        print("\nTraining completed successfully!")
        print(f"Model saved to: {model_save_path}")
        print(f"Optimal threshold: {optimal_threshold:.3f}")
        print(f"Validation F1-score: {val_metrics['f1_score']:.3f}")

        return {
            'model_path': model_save_path,
            'optimal_threshold': optimal_threshold,
            'metrics': val_metrics
        }
def main(argv=None):
    """Command-line entry point: parse arguments and run training.

    Args:
        argv: Argument list to parse; when None, argparse uses sys.argv[1:].

    Returns:
        Process exit status: 0 on success, 1 if the data file is missing or
        training raised an exception. Invalid arguments make argparse exit
        with status 2.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description='Train OHCA classifier from labeled data')
    parser.add_argument('data_path', help='Path to CSV file with labeled data')
    parser.add_argument('--model_path', default='./trained_ohca_model',
                        help='Where to save trained model (default: ./trained_ohca_model)')
    parser.add_argument('--epochs', type=int, default=3,
                        help='Number of training epochs (default: 3)')
    parser.add_argument('--test_size', type=float, default=0.2,
                        help='Validation split fraction (default: 0.2)')
    args = parser.parse_args(argv)

    # Reject nonsensical values before loading any data; parser.error prints
    # the usage message and exits with status 2.
    if not 0 < args.test_size < 1:
        parser.error('--test_size must be strictly between 0 and 1')
    if args.epochs < 1:
        parser.error('--epochs must be at least 1')

    # isfile (not exists) so a directory path is rejected here too.
    if not os.path.isfile(args.data_path):
        print(f"Error: Data file not found: {args.data_path}", file=sys.stderr)
        print("\nYour CSV file should have columns:", file=sys.stderr)
        print(" hadm_id: Unique admission identifier", file=sys.stderr)
        print(" clean_text: Discharge note text", file=sys.stderr)
        print(" ohca_label: 1 for OHCA, 0 for non-OHCA", file=sys.stderr)
        print(" subject_id: Patient ID (optional - will use hadm_id if missing)", file=sys.stderr)
        print(" confidence: Label confidence (optional - defaults to 4 if missing)", file=sys.stderr)
        return 1

    try:
        # Keyword arguments so this call cannot silently break if the
        # function's positional parameter order changes.
        train_from_labeled_data(
            args.data_path,
            model_save_path=args.model_path,
            test_size=args.test_size,
            num_epochs=args.epochs,
        )
    except Exception as e:
        # Top-level boundary: report the failure with its traceback so the
        # cause is debuggable, and signal failure through the exit status.
        print(f"Training failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
|