File size: 3,125 Bytes
1b05cbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))

import pandas as pd
from sklearn.model_selection import train_test_split
from ohca_training_pipeline import prepare_training_data, train_ohca_model, find_optimal_threshold, save_model_with_metadata

def train_from_labeled_data(data_path, model_save_path="./trained_ohca_model", test_size=0.2, num_epochs=3):
    """Train an OHCA classifier from a CSV file of pre-labeled rows.

    The CSV is split into stratified training/validation sets
    (``random_state=42``). Both splits are written as Excel files into a
    private temporary directory, because ``prepare_training_data`` is called
    with file paths here. They then go through the training pipeline:
    data preparation, model training, threshold selection, and saving with
    metadata.

    Args:
        data_path: Path to the labeled CSV file. It must contain an
            ``ohca_label`` column. It must also contain either a
            ``subject_id`` column or a ``hadm_id`` column, which is copied
            into ``subject_id`` when that column is absent.
        model_save_path: Location passed as ``save_path`` to
            ``train_ohca_model`` and to ``save_model_with_metadata``.
        test_size: Validation split passed as ``test_size`` to
            ``sklearn.model_selection.train_test_split`` (e.g. 0.2 for 20%).
        num_epochs: Number of training epochs.

    Returns:
        dict with keys ``model``, ``tokenizer``, ``optimal_threshold`` and
        ``val_metrics``. Callers that ignore the return value (the original
        behavior returned ``None``) are unaffected.

    Raises:
        ValueError: If the CSV lacks ``ohca_label``, or lacks both
            ``subject_id`` and ``hadm_id``.
    """
    # Local import, consistent with the file's local `import argparse`.
    import tempfile

    print("OHCA Classifier Training from Pre-labeled Data")
    print("=" * 50)

    # Load data
    print(f"Loading labeled data from: {data_path}")
    df = pd.read_csv(data_path)

    # Validate the schema up front so a malformed file fails with a clear
    # message instead of a bare KeyError partway through.
    if 'ohca_label' not in df.columns:
        raise ValueError(f"{data_path}: missing required column 'ohca_label'")

    # Add subject_id if missing
    if 'subject_id' not in df.columns:
        if 'hadm_id' not in df.columns:
            raise ValueError(
                f"{data_path}: needs a 'subject_id' column or a 'hadm_id' column to derive it from"
            )
        print("Adding subject_id column (using hadm_id as patient ID)")
        df['subject_id'] = df['hadm_id']

    print(f"Data loaded: {len(df)} cases ({(df['ohca_label']==1).sum()} OHCA, {(df['ohca_label']==0).sum()} non-OHCA)")

    # Split data; stratify keeps the label ratio equal in both splits.
    train_df, val_df = train_test_split(df, test_size=test_size, stratify=df['ohca_label'], random_state=42)
    print(f"Training: {len(train_df)}, Validation: {len(val_df)}")

    # Write the splits into a private temporary directory rather than the
    # CWD. This avoids clobbering user files and collisions between
    # concurrent runs. The directory (and both files) is removed even if
    # writing or training fails.
    with tempfile.TemporaryDirectory(prefix="ohca_train_") as tmp_dir:
        train_path = os.path.join(tmp_dir, 'temp_train.xlsx')
        val_path = os.path.join(tmp_dir, 'temp_val.xlsx')
        train_df.to_excel(train_path, index=False)
        val_df.to_excel(val_path, index=False)

        # Train
        print("Preparing training data...")
        train_dataset, val_dataset, train_df_balanced, val_df_clean, tokenizer = prepare_training_data(train_path, val_path)

        print(f"Training model for {num_epochs} epochs...")
        model, trained_tokenizer = train_ohca_model(
            train_dataset, val_dataset, train_df_balanced, tokenizer,
            num_epochs=num_epochs, save_path=model_save_path
        )

        print("Finding optimal threshold...")
        optimal_threshold, val_metrics = find_optimal_threshold(model, trained_tokenizer, val_df_clean)

    print("Saving model with metadata...")
    # No separate test split exists in this flow; record that explicitly.
    test_metrics = {'message': 'Trained on user data', 'test_set_size': 0}
    save_model_with_metadata(model, trained_tokenizer, optimal_threshold, val_metrics, test_metrics, model_save_path)

    print("Training completed!")
    print(f"Model saved to: {model_save_path}")
    print(f"Optimal threshold: {optimal_threshold:.3f}")
    print(f"F1-score: {val_metrics['f1_score']:.3f}")

    return {
        'model': model,
        'tokenizer': trained_tokenizer,
        'optimal_threshold': optimal_threshold,
        'val_metrics': val_metrics,
    }

if __name__ == "__main__":
    # Command-line entry point: parse options and run the training pipeline.
    import argparse

    cli = argparse.ArgumentParser()
    # Declaration order determines the order shown in --help; keep it stable.
    cli.add_argument('data_path', help='Path to labeled CSV file')
    cli.add_argument('--model_path', default='./trained_ohca_model',
                     help='Model save path')
    cli.add_argument('--epochs', type=int, default=3,
                     help='Training epochs')
    cli.add_argument('--test_size', type=float, default=0.2,
                     help='Validation split')
    opts = cli.parse_args()

    # Keyword arguments make the CLI-option -> parameter mapping explicit.
    train_from_labeled_data(
        opts.data_path,
        model_save_path=opts.model_path,
        test_size=opts.test_size,
        num_epochs=opts.epochs,
    )