monajm36
commited on
Commit
·
1b05cbb
1
Parent(s):
493b03a
Add user-friendly training and prediction scripts
Browse files
scripts/train_from_labeled_data.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from sklearn.model_selection import train_test_split
|
| 8 |
+
from ohca_training_pipeline import prepare_training_data, train_ohca_model, find_optimal_threshold, save_model_with_metadata
|
| 9 |
+
|
| 10 |
+
def train_from_labeled_data(data_path, model_save_path="./trained_ohca_model", test_size=0.2, num_epochs=3):
|
| 11 |
+
print("OHCA Classifier Training from Pre-labeled Data")
|
| 12 |
+
print("="*50)
|
| 13 |
+
|
| 14 |
+
# Load data
|
| 15 |
+
print(f"Loading labeled data from: {data_path}")
|
| 16 |
+
df = pd.read_csv(data_path)
|
| 17 |
+
|
| 18 |
+
# Add subject_id if missing
|
| 19 |
+
if 'subject_id' not in df.columns:
|
| 20 |
+
print("Adding subject_id column (using hadm_id as patient ID)")
|
| 21 |
+
df['subject_id'] = df['hadm_id']
|
| 22 |
+
|
| 23 |
+
print(f"Data loaded: {len(df)} cases ({(df['ohca_label']==1).sum()} OHCA, {(df['ohca_label']==0).sum()} non-OHCA)")
|
| 24 |
+
|
| 25 |
+
# Split data
|
| 26 |
+
train_df, val_df = train_test_split(df, test_size=test_size, stratify=df['ohca_label'], random_state=42)
|
| 27 |
+
print(f"Training: {len(train_df)}, Validation: {len(val_df)}")
|
| 28 |
+
|
| 29 |
+
# Save temporary files
|
| 30 |
+
train_df.to_excel('temp_train.xlsx', index=False)
|
| 31 |
+
val_df.to_excel('temp_val.xlsx', index=False)
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
# Train
|
| 35 |
+
print("Preparing training data...")
|
| 36 |
+
train_dataset, val_dataset, train_df_balanced, val_df_clean, tokenizer = prepare_training_data('temp_train.xlsx', 'temp_val.xlsx')
|
| 37 |
+
|
| 38 |
+
print(f"Training model for {num_epochs} epochs...")
|
| 39 |
+
model, trained_tokenizer = train_ohca_model(
|
| 40 |
+
train_dataset, val_dataset, train_df_balanced, tokenizer,
|
| 41 |
+
num_epochs=num_epochs, save_path=model_save_path
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
print("Finding optimal threshold...")
|
| 45 |
+
optimal_threshold, val_metrics = find_optimal_threshold(model, trained_tokenizer, val_df_clean)
|
| 46 |
+
|
| 47 |
+
print("Saving model with metadata...")
|
| 48 |
+
test_metrics = {'message': 'Trained on user data', 'test_set_size': 0}
|
| 49 |
+
save_model_with_metadata(model, trained_tokenizer, optimal_threshold, val_metrics, test_metrics, model_save_path)
|
| 50 |
+
|
| 51 |
+
print(f"Training completed!")
|
| 52 |
+
print(f"Model saved to: {model_save_path}")
|
| 53 |
+
print(f"Optimal threshold: {optimal_threshold:.3f}")
|
| 54 |
+
print(f"F1-score: {val_metrics['f1_score']:.3f}")
|
| 55 |
+
|
| 56 |
+
finally:
|
| 57 |
+
# Clean up
|
| 58 |
+
if os.path.exists('temp_train.xlsx'):
|
| 59 |
+
os.remove('temp_train.xlsx')
|
| 60 |
+
if os.path.exists('temp_val.xlsx'):
|
| 61 |
+
os.remove('temp_val.xlsx')
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
|
| 64 |
+
import argparse
|
| 65 |
+
parser = argparse.ArgumentParser()
|
| 66 |
+
parser.add_argument('data_path', help='Path to labeled CSV file')
|
| 67 |
+
parser.add_argument('--model_path', default='./trained_ohca_model', help='Model save path')
|
| 68 |
+
parser.add_argument('--epochs', type=int, default=3, help='Training epochs')
|
| 69 |
+
parser.add_argument('--test_size', type=float, default=0.2, help='Validation split')
|
| 70 |
+
|
| 71 |
+
args = parser.parse_args()
|
| 72 |
+
train_from_labeled_data(args.data_path, args.model_path, args.test_size, args.epochs)
|