monajm36
committed on
Commit
·
493b03a
1
Parent(s):
e2ef18e
Add user-friendly scripts for training and prediction workflows
Browse files
examples/scripts/predict_ohca.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Apply OHCA Classifier to New Discharge Notes
|
| 4 |
+
|
| 5 |
+
This script applies a trained OHCA classifier to new discharge notes.
|
| 6 |
+
Input data should have columns: hadm_id, clean_text
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from ohca_inference import quick_inference
|
| 15 |
+
|
| 16 |
+
def validate_discharge_data(df):
    """Check that a discharge-note DataFrame has the columns inference needs.

    Raises ValueError if 'hadm_id' or 'clean_text' is absent; prints
    warnings (without failing) for rows with missing values.
    """
    required_cols = ['hadm_id', 'clean_text']
    absent = [c for c in required_cols if c not in df.columns]

    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Report missing values per required column, but do not fail:
    # downstream code decides how to handle incomplete rows.
    for col in required_cols:
        n_missing = df[col].isna().sum()
        if n_missing > 0:
            print(f"Warning: {n_missing} rows have missing {col}")

    print("Data validation:")
    print(f"  Total discharge notes: {len(df)}")
    print(f"  Valid records: {len(df.dropna(subset=required_cols))}")
|
| 36 |
+
|
| 37 |
+
def predict_ohca(model_path, data_path, output_path=None):
    """
    Apply OHCA model to discharge notes.

    Args:
        model_path: Path to trained model
        data_path: Path to CSV with discharge notes (columns: hadm_id, clean_text)
        output_path: Where to save results (optional; defaults to
            "<data basename>_ohca_predictions.csv")

    Returns:
        DataFrame of predictions as produced by quick_inference.

    Raises:
        FileNotFoundError: if model_path does not exist.
        ValueError: if the input CSV lacks required columns.
    """
    print("OHCA Classifier Prediction")
    print("="*30)

    # Validate model exists before doing any data work
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    print(f"Model: {model_path}")
    print(f"Data: {data_path}")

    # Load and validate data
    df = pd.read_csv(data_path)
    validate_discharge_data(df)

    # Set default output path
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(data_path))[0]
        output_path = f"{base_name}_ohca_predictions.csv"

    print(f"Output: {output_path}")

    # Run inference.
    # NOTE(review): the already-loaded DataFrame is passed as data_path —
    # confirm quick_inference accepts in-memory DataFrames, not only paths.
    print(f"\nRunning OHCA prediction on {len(df)} discharge notes...")
    results = quick_inference(
        model_path=model_path,
        data_path=df,
        output_path=output_path
    )

    # Analyze results
    if 'ohca_prediction' in results.columns:
        ohca_detected = results['ohca_prediction'].sum()
        # Bug fix: results.get('optimal_threshold_used', [0.5]).iloc[0]
        # raised AttributeError whenever the column was absent, because
        # the plain-list default has no .iloc. Check the column instead.
        if 'optimal_threshold_used' in results.columns:
            threshold_used = results['optimal_threshold_used'].iloc[0]
        else:
            threshold_used = 0.5
    else:
        # Fallback for legacy models that only emit probabilities
        ohca_detected = (results['ohca_probability'] >= 0.5).sum()
        threshold_used = 0.5

    high_confidence = (results['ohca_probability'] >= 0.8).sum()
    very_high_confidence = (results['ohca_probability'] >= 0.9).sum()

    print(f"\nResults Summary:")
    print(f"  Total cases analyzed: {len(results)}")
    print(f"  OHCA detected: {ohca_detected} ({ohca_detected/len(results)*100:.1f}%)")
    print(f"  High confidence (≥0.8): {high_confidence}")
    print(f"  Very high confidence (≥0.9): {very_high_confidence}")
    print(f"  Threshold used: {threshold_used:.3f}")

    # Show highest probability cases for quick manual review
    print(f"\nTop 5 highest probability cases:")
    top_cases = results.nlargest(5, 'ohca_probability')
    for _, row in top_cases.iterrows():
        print(f"  {row['hadm_id']}: {row['ohca_probability']:.3f}")

    print(f"\nResults saved to: {output_path}")

    # Clinical recommendations, tiered by prediction confidence
    if very_high_confidence > 0:
        print(f"\nClinical Recommendations:")
        print(f"  → {very_high_confidence} cases need immediate review (≥90% probability)")
    if high_confidence > very_high_confidence:
        print(f"  → {high_confidence - very_high_confidence} cases need priority review (80-90% probability)")

    return results
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
    import argparse

    # CLI: positional model + data paths, optional output location
    cli = argparse.ArgumentParser(description='Apply OHCA classifier to discharge notes')
    cli.add_argument('model_path', help='Path to trained model directory')
    cli.add_argument('data_path', help='Path to CSV file with discharge notes')
    cli.add_argument('--output', help='Output CSV path (default: auto-generated)')
    opts = cli.parse_args()

    # Fail fast with actionable guidance when inputs are missing
    if not os.path.exists(opts.model_path):
        print(f"Error: Model not found: {opts.model_path}")
        print("Train a model first using: python scripts/train_from_labeled_data.py")
        sys.exit(1)

    if not os.path.exists(opts.data_path):
        print(f"Error: Data file not found: {opts.data_path}")
        print("\nYour CSV file should have columns:")
        print("  hadm_id: Unique admission identifier")
        print("  clean_text: Discharge note text")
        sys.exit(1)

    try:
        predict_ohca(opts.model_path, opts.data_path, opts.output)
    except Exception as e:
        print(f"Prediction failed: {e}")
        sys.exit(1)
|
examples/scripts/prepare_data.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Data Preparation Helper for OHCA Classifier
|
| 4 |
+
|
| 5 |
+
This script helps prepare your data in the correct format for training or inference.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
def prepare_labeled_data(input_path, output_path=None):
    """Prepare manually labeled data for training.

    Interactively maps missing required columns (hadm_id, clean_text,
    ohca_label) onto existing ones, fills in optional columns
    (subject_id, confidence), drops incomplete rows, and writes the
    cleaned CSV.

    Args:
        input_path: Path to the raw labeled CSV.
        output_path: Destination CSV (default: "<input stem>_prepared.csv").

    Returns:
        The path the prepared CSV was written to.
    """
    import os  # local import: this module does not import os at top level

    print("Preparing labeled data for training...")

    df = pd.read_csv(input_path)
    print(f"Loaded {len(df)} records")
    print(f"Columns: {list(df.columns)}")

    # Interactive column mapping for any missing required column
    required_cols = ['hadm_id', 'clean_text', 'ohca_label']
    column_mapping = {}

    for req_col in required_cols:
        if req_col not in df.columns:
            print(f"\nColumn '{req_col}' not found.")
            print(f"Available columns: {list(df.columns)}")
            mapped_col = input(f"Which column should be used for '{req_col}'? ")
            if mapped_col in df.columns:
                column_mapping[mapped_col] = req_col
            else:
                print(f"Column '{mapped_col}' not found. Skipping...")

    # Apply mapping
    if column_mapping:
        df = df.rename(columns=column_mapping)
        print(f"Applied column mapping: {column_mapping}")

    # Add missing optional columns
    if 'subject_id' not in df.columns:
        df['subject_id'] = df['hadm_id']
        print("Added subject_id column (copied from hadm_id)")

    if 'confidence' not in df.columns:
        df['confidence'] = 4
        print("Added default confidence scores")

    # Drop rows missing any required value
    df = df.dropna(subset=required_cols)

    # Bug fix: input_path.replace('.csv', '') removed EVERY '.csv'
    # occurrence in the path (e.g. inside directory names); splitext
    # strips only the final extension.
    if output_path is None:
        base_name = os.path.splitext(input_path)[0]
        output_path = f"{base_name}_prepared.csv"

    df.to_csv(output_path, index=False)

    print(f"\nData prepared successfully:")
    print(f"  Output: {output_path}")
    print(f"  Records: {len(df)}")
    print(f"  OHCA cases: {(df['ohca_label']==1).sum()}")
    print(f"  Columns: {list(df.columns)}")

    return output_path
|
| 62 |
+
|
| 63 |
+
def prepare_discharge_notes(input_path, output_path=None):
    """Prepare discharge notes for inference.

    Interactively maps missing required columns (hadm_id, clean_text)
    onto existing ones, drops incomplete rows, and writes the cleaned CSV.

    Args:
        input_path: Path to the raw discharge-note CSV.
        output_path: Destination CSV (default: "<input stem>_prepared.csv").

    Returns:
        The path the prepared CSV was written to.
    """
    import os  # local import: this module does not import os at top level

    print("Preparing discharge notes for inference...")

    df = pd.read_csv(input_path)
    print(f"Loaded {len(df)} records")
    print(f"Columns: {list(df.columns)}")

    # Interactive column mapping for any missing required column
    required_cols = ['hadm_id', 'clean_text']
    column_mapping = {}

    for req_col in required_cols:
        if req_col not in df.columns:
            print(f"\nColumn '{req_col}' not found.")
            print(f"Available columns: {list(df.columns)}")
            mapped_col = input(f"Which column should be used for '{req_col}'? ")
            if mapped_col in df.columns:
                column_mapping[mapped_col] = req_col

    # Apply mapping
    if column_mapping:
        df = df.rename(columns=column_mapping)
        print(f"Applied column mapping: {column_mapping}")

    # Clean data: drop rows missing any required value
    df = df.dropna(subset=required_cols)

    # Bug fix: input_path.replace('.csv', '') removed EVERY '.csv'
    # occurrence in the path; splitext strips only the final extension.
    if output_path is None:
        base_name = os.path.splitext(input_path)[0]
        output_path = f"{base_name}_prepared.csv"

    df.to_csv(output_path, index=False)

    print(f"\nDischarge notes prepared:")
    print(f"  Output: {output_path}")
    print(f"  Records: {len(df)}")

    return output_path
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
    # Bug fix: both a data type (argv[1]) and an input path (argv[2]) are
    # required, but the original checked only `len(sys.argv) < 2`, so
    # "python scripts/prepare_data.py labeled" crashed with IndexError
    # instead of printing usage.
    if len(sys.argv) < 3:
        print("Usage:")
        print("  python scripts/prepare_data.py labeled <input.csv>    # For training data")
        print("  python scripts/prepare_data.py discharge <input.csv>  # For inference data")
        sys.exit(1)

    data_type = sys.argv[1]
    input_path = sys.argv[2]

    if data_type == "labeled":
        prepare_labeled_data(input_path)
    elif data_type == "discharge":
        prepare_discharge_notes(input_path)
    else:
        print("Data type must be 'labeled' or 'discharge'")
        sys.exit(1)
|
examples/scripts/train_from_labeled_data.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Train OHCA Classifier from Pre-labeled Data
|
| 4 |
+
|
| 5 |
+
This script trains a v3.0 OHCA classifier using your manually labeled data.
|
| 6 |
+
Your data should have columns: hadm_id, clean_text, ohca_label (and optionally subject_id, confidence)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from sklearn.model_selection import train_test_split
|
| 15 |
+
from ohca_training_pipeline import prepare_training_data, train_ohca_model, find_optimal_threshold, save_model_with_metadata
|
| 16 |
+
|
| 17 |
+
def validate_labeled_data(df):
    """Ensure labeled training data has the required columns and binary labels.

    Raises ValueError when a required column is absent or when
    'ohca_label' contains values other than 0/1; otherwise prints a
    short summary of the dataset.
    """
    required_cols = ['hadm_id', 'clean_text', 'ohca_label']
    absent = [c for c in required_cols if c not in df.columns]

    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Labels must be strictly binary for the classifier
    unique_labels = df['ohca_label'].unique()
    if not set(unique_labels) <= {0, 1}:
        raise ValueError(f"ohca_label must be 0 or 1, found: {unique_labels}")

    is_ohca = df['ohca_label'] == 1
    print("Data validation passed:")
    print(f"  Total cases: {len(df)}")
    print(f"  OHCA cases (label=1): {is_ohca.sum()}")
    print(f"  Non-OHCA cases (label=0): {(df['ohca_label']==0).sum()}")
    print(f"  OHCA prevalence: {is_ohca.mean():.1%}")
|
| 35 |
+
|
| 36 |
+
def train_from_labeled_data(data_path, model_save_path="./trained_ohca_model",
                            test_size=0.2, num_epochs=3):
    """
    Train OHCA model from pre-labeled data.

    Args:
        data_path: Path to CSV with labeled data (hadm_id, clean_text,
            ohca_label; subject_id and confidence are filled in if absent)
        model_save_path: Where to save the trained model
        test_size: Fraction to use for validation (default 0.2 = 20%)
        num_epochs: Number of training epochs

    Returns:
        dict with 'model_path', 'optimal_threshold' and 'metrics'
        (the validation metrics from the threshold search).

    Raises:
        ValueError: if the labeled data fails validation.
    """
    print("OHCA Classifier Training from Pre-labeled Data")
    print("="*50)

    # Load and validate data
    print(f"Loading labeled data from: {data_path}")
    df = pd.read_csv(data_path)

    # Add missing optional columns before validation
    if 'subject_id' not in df.columns:
        print("Adding subject_id column (using hadm_id as patient ID)")
        df['subject_id'] = df['hadm_id']

    if 'confidence' not in df.columns:
        print("Adding default confidence scores")
        df['confidence'] = 4  # Default confidence

    validate_labeled_data(df)

    # Stratified split keeps OHCA prevalence equal in both halves;
    # fixed random_state makes the split reproducible
    print(f"\nSplitting data (train: {1-test_size:.0%}, validation: {test_size:.0%})")
    train_df, val_df = train_test_split(
        df, test_size=test_size,
        stratify=df['ohca_label'],
        random_state=42
    )

    print(f"Training data: {len(train_df)} cases ({(train_df['ohca_label']==1).sum()} OHCA)")
    print(f"Validation data: {len(val_df)} cases ({(val_df['ohca_label']==1).sum()} OHCA)")

    # Save as temporary Excel files — presumably prepare_training_data
    # reads file paths rather than in-memory DataFrames (TODO confirm).
    # NOTE(review): fixed names in the CWD can collide if two trainings
    # run concurrently in the same directory.
    temp_train = 'temp_train_data.xlsx'
    temp_val = 'temp_val_data.xlsx'
    train_df.to_excel(temp_train, index=False)
    val_df.to_excel(temp_val, index=False)

    try:
        # Prepare training datasets (tokenized datasets plus cleaned frames)
        print("\nPreparing training datasets...")
        train_dataset, val_dataset, train_df_balanced, val_df_clean, tokenizer = prepare_training_data(
            temp_train, temp_val
        )

        # Train the model
        print(f"\nTraining model for {num_epochs} epochs...")
        model, trained_tokenizer = train_ohca_model(
            train_dataset, val_dataset, train_df_balanced, tokenizer,
            num_epochs=num_epochs,
            save_path=model_save_path
        )

        # Find the decision threshold on the held-out validation split
        print("\nFinding optimal threshold...")
        optimal_threshold, val_metrics = find_optimal_threshold(
            model, trained_tokenizer, val_df_clean
        )

        # Save model with metadata; no separate test set is held out here,
        # so test metrics carry only a placeholder message
        print("\nSaving model with metadata...")
        test_metrics = {'message': 'Trained on user-provided labeled data', 'test_set_size': 0}
        save_model_with_metadata(
            model, trained_tokenizer, optimal_threshold,
            val_metrics, test_metrics, model_save_path
        )

        print(f"\nTraining completed successfully!")
        print(f"Model saved to: {model_save_path}")
        print(f"Optimal threshold: {optimal_threshold:.3f}")
        print(f"Validation F1-score: {val_metrics['f1_score']:.3f}")

        return {
            'model_path': model_save_path,
            'optimal_threshold': optimal_threshold,
            'metrics': val_metrics
        }

    finally:
        # Clean up temporary files even if training fails
        for temp_file in [temp_train, temp_val]:
            if os.path.exists(temp_file):
                os.remove(temp_file)
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
    import argparse

    # CLI: positional data path plus optional training hyperparameters
    cli = argparse.ArgumentParser(description='Train OHCA classifier from labeled data')
    cli.add_argument('data_path', help='Path to CSV file with labeled data')
    cli.add_argument('--model_path', default='./trained_ohca_model',
                     help='Where to save trained model (default: ./trained_ohca_model)')
    cli.add_argument('--epochs', type=int, default=3,
                     help='Number of training epochs (default: 3)')
    cli.add_argument('--test_size', type=float, default=0.2,
                     help='Validation split fraction (default: 0.2)')
    opts = cli.parse_args()

    if not os.path.exists(opts.data_path):
        # Missing input: describe the expected schema before bailing out
        print(f"Error: Data file not found: {opts.data_path}")
        print("\nYour CSV file should have columns:")
        print("  hadm_id: Unique admission identifier")
        print("  clean_text: Discharge note text")
        print("  ohca_label: 1 for OHCA, 0 for non-OHCA")
        print("  subject_id: Patient ID (optional - will use hadm_id if missing)")
        sys.exit(1)

    try:
        train_from_labeled_data(opts.data_path, opts.model_path, opts.test_size, opts.epochs)
    except Exception as e:
        print(f"Training failed: {e}")
        sys.exit(1)
|