|
|
|
|
|
import sys |
|
|
import os |
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src')) |
|
|
|
|
|
import pandas as pd |
|
|
from sklearn.model_selection import train_test_split |
|
|
from ohca_training_pipeline import prepare_training_data, train_ohca_model, find_optimal_threshold, save_model_with_metadata |
|
|
|
|
|
def train_from_labeled_data(data_path, model_save_path="./trained_ohca_model", test_size=0.2, num_epochs=3, random_state=42):
    """Train an OHCA classifier from a pre-labeled CSV file and save it.

    The CSV is loaded with pandas, split into stratified training and
    validation subsets, and handed to the ``ohca_training_pipeline`` helpers
    (data preparation, training, threshold search, saving with metadata).
    Progress and a short summary are printed to stdout.

    Args:
        data_path: Path to the labeled CSV file. It must contain an
            ``ohca_label`` column, and either a ``subject_id`` column or a
            ``hadm_id`` column (which is then copied into ``subject_id``).
        model_save_path: Location forwarded to ``train_ohca_model`` (as
            ``save_path``) and to ``save_model_with_metadata``.
        test_size: Forwarded to ``train_test_split`` as ``test_size``; it
            sets the size of the validation subset.
        num_epochs: Number of epochs forwarded to ``train_ohca_model``.
        random_state: Seed forwarded to ``train_test_split``. The default
            (42) reproduces the previously hard-coded split.

    Returns:
        None.

    Raises:
        KeyError: If ``ohca_label`` is missing, or if both ``subject_id``
            and ``hadm_id`` are missing from the loaded data.
    """
    # Local import: only this function needs it, so the file's import header
    # stays untouched.
    import tempfile

    print("OHCA Classifier Training from Pre-labeled Data")
    print("="*50)

    print(f"Loading labeled data from: {data_path}")
    df = pd.read_csv(data_path)

    if 'subject_id' not in df.columns:
        if 'hadm_id' not in df.columns:
            raise KeyError(
                "Labeled data must contain a 'subject_id' or 'hadm_id' column"
            )
        print("Adding subject_id column (using hadm_id as patient ID)")
        df['subject_id'] = df['hadm_id']

    if 'ohca_label' not in df.columns:
        raise KeyError("Labeled data must contain an 'ohca_label' column")

    labels = df['ohca_label']
    print(f"Data loaded: {len(df)} cases ({(labels == 1).sum()} OHCA, {(labels == 0).sum()} non-OHCA)")

    train_df, val_df = train_test_split(df, test_size=test_size, stratify=labels, random_state=random_state)
    print(f"Training: {len(train_df)}, Validation: {len(val_df)}")

    # prepare_training_data takes file paths, so both splits are written out
    # as spreadsheets first. A private temporary directory is used instead of
    # fixed names in the working directory: it cannot overwrite (and later
    # delete) a user's own temp_train.xlsx / temp_val.xlsx, concurrent runs
    # do not collide, and everything is removed even if a write or a
    # training step fails part-way through.
    with tempfile.TemporaryDirectory(prefix="ohca_train_") as tmp_dir:
        train_path = os.path.join(tmp_dir, 'temp_train.xlsx')
        val_path = os.path.join(tmp_dir, 'temp_val.xlsx')
        train_df.to_excel(train_path, index=False)
        val_df.to_excel(val_path, index=False)

        print("Preparing training data...")
        train_dataset, val_dataset, train_df_balanced, val_df_clean, tokenizer = prepare_training_data(train_path, val_path)

        print(f"Training model for {num_epochs} epochs...")
        model, trained_tokenizer = train_ohca_model(
            train_dataset, val_dataset, train_df_balanced, tokenizer,
            num_epochs=num_epochs, save_path=model_save_path
        )

        print("Finding optimal threshold...")
        optimal_threshold, val_metrics = find_optimal_threshold(model, trained_tokenizer, val_df_clean)

        print("Saving model with metadata...")
        # No separate test set is held out here, so a placeholder record is
        # passed in the test-metrics slot.
        test_metrics = {'message': 'Trained on user data', 'test_set_size': 0}
        save_model_with_metadata(model, trained_tokenizer, optimal_threshold, val_metrics, test_metrics, model_save_path)

        print("Training completed!")
        print(f"Model saved to: {model_save_path}")
        print(f"Optimal threshold: {optimal_threshold:.3f}")
        print(f"F1-score: {val_metrics['f1_score']:.3f}")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and run one training job.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('data_path', help='Path to labeled CSV file')

    # Optional flags, declared in the order they appear in --help.
    optional_flags = (
        ('--model_path', {'default': './trained_ohca_model', 'help': 'Model save path'}),
        ('--epochs', {'type': int, 'default': 3, 'help': 'Training epochs'}),
        ('--test_size', {'type': float, 'default': 0.2, 'help': 'Validation split'}),
    )
    for flag, spec in optional_flags:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    train_from_labeled_data(
        data_path=opts.data_path,
        model_save_path=opts.model_path,
        test_size=opts.test_size,
        num_epochs=opts.epochs,
    )
|
|
|