monajm36 commited on
Commit
1b05cbb
·
1 Parent(s): 493b03a

Add user-friendly training and prediction scripts

Browse files
Files changed (1) hide show
  1. scripts/train_from_labeled_data.py +72 -0
scripts/train_from_labeled_data.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import os
4
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
5
+
6
+ import pandas as pd
7
+ from sklearn.model_selection import train_test_split
8
+ from ohca_training_pipeline import prepare_training_data, train_ohca_model, find_optimal_threshold, save_model_with_metadata
9
+
10
def train_from_labeled_data(data_path, model_save_path="./trained_ohca_model", test_size=0.2, num_epochs=3):
    """Train an OHCA classifier from a pre-labeled CSV file.

    Args:
        data_path: Path to a CSV with at least 'hadm_id' and 'ohca_label'
            columns (label 1 = OHCA, 0 = non-OHCA). A 'subject_id' column
            is derived from 'hadm_id' if absent.
        model_save_path: Directory where the trained model and metadata
            are saved.
        test_size: Fraction of rows held out for validation.
        num_epochs: Number of training epochs.

    Raises:
        ValueError: If a required column is missing from the input file.
    """
    import shutil
    import tempfile

    print("OHCA Classifier Training from Pre-labeled Data")
    print("="*50)

    # Load data
    print(f"Loading labeled data from: {data_path}")
    df = pd.read_csv(data_path)

    # Fail fast with a clear message instead of a KeyError deep inside
    # the split/training code.
    missing = [c for c in ('hadm_id', 'ohca_label') if c not in df.columns]
    if missing:
        raise ValueError(f"Input file is missing required column(s): {missing}")

    # Add subject_id if missing
    if 'subject_id' not in df.columns:
        print("Adding subject_id column (using hadm_id as patient ID)")
        df['subject_id'] = df['hadm_id']

    print(f"Data loaded: {len(df)} cases ({(df['ohca_label']==1).sum()} OHCA, {(df['ohca_label']==0).sum()} non-OHCA)")

    # Stratified split preserves the OHCA/non-OHCA ratio in both partitions.
    train_df, val_df = train_test_split(df, test_size=test_size, stratify=df['ohca_label'], random_state=42)
    print(f"Training: {len(train_df)}, Validation: {len(val_df)}")

    # The pipeline reads Excel files, so write the splits to a private
    # temp directory rather than fixed names in the CWD (fixed names
    # collide between concurrent runs and can clobber user files).
    tmp_dir = tempfile.mkdtemp(prefix='ohca_train_')
    train_path = os.path.join(tmp_dir, 'temp_train.xlsx')
    val_path = os.path.join(tmp_dir, 'temp_val.xlsx')
    train_df.to_excel(train_path, index=False)
    val_df.to_excel(val_path, index=False)

    try:
        # Train
        print("Preparing training data...")
        train_dataset, val_dataset, train_df_balanced, val_df_clean, tokenizer = prepare_training_data(train_path, val_path)

        print(f"Training model for {num_epochs} epochs...")
        model, trained_tokenizer = train_ohca_model(
            train_dataset, val_dataset, train_df_balanced, tokenizer,
            num_epochs=num_epochs, save_path=model_save_path
        )

        print("Finding optimal threshold...")
        optimal_threshold, val_metrics = find_optimal_threshold(model, trained_tokenizer, val_df_clean)

        print("Saving model with metadata...")
        test_metrics = {'message': 'Trained on user data', 'test_set_size': 0}
        save_model_with_metadata(model, trained_tokenizer, optimal_threshold, val_metrics, test_metrics, model_save_path)

        print("Training completed!")
        print(f"Model saved to: {model_save_path}")
        print(f"Optimal threshold: {optimal_threshold:.3f}")
        print(f"F1-score: {val_metrics['f1_score']:.3f}")

    finally:
        # Clean up the temp directory and everything in it, even if
        # training raised.
        shutil.rmtree(tmp_dir, ignore_errors=True)
62
+
63
if __name__ == "__main__":
    import argparse

    # Command-line entry point: one required positional argument (the
    # labeled CSV) plus optional knobs mirroring the defaults of
    # train_from_labeled_data.
    cli = argparse.ArgumentParser()
    cli.add_argument('data_path', help='Path to labeled CSV file')
    cli.add_argument('--model_path', default='./trained_ohca_model', help='Model save path')
    cli.add_argument('--epochs', type=int, default=3, help='Training epochs')
    cli.add_argument('--test_size', type=float, default=0.2, help='Validation split')

    opts = cli.parse_args()
    train_from_labeled_data(opts.data_path, opts.model_path, opts.test_size, opts.epochs)