monajm36 committed on
Commit
493b03a
·
1 Parent(s): e2ef18e

Add user-friendly scripts for training and prediction workflows

Browse files
examples/scripts/predict_ohca.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Apply OHCA Classifier to New Discharge Notes
4
+
5
+ This script applies a trained OHCA classifier to new discharge notes.
6
+ Input data should have columns: hadm_id, clean_text
7
+ """
8
+
9
+ import sys
10
+ import os
11
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
12
+
13
+ import pandas as pd
14
+ from ohca_inference import quick_inference
15
+
16
def validate_discharge_data(df):
    """Check that discharge data has the columns inference needs.

    Raises ValueError when 'hadm_id' or 'clean_text' is absent. Missing
    values in those columns are only reported as warnings; rows are not
    removed. A short summary of total and fully populated rows is printed.
    """
    required_cols = ['hadm_id', 'clean_text']

    absent = [name for name in required_cols if name not in df.columns]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Report (but do not drop) rows with gaps in either required column.
    for column in required_cols:
        n_missing = df[column].isna().sum()
        if n_missing > 0:
            print(f"Warning: {n_missing} rows have missing {column}")

    complete_rows = df.dropna(subset=required_cols)
    total_rows = len(df)

    print("Data validation:")
    print(f" Total discharge notes: {total_rows}")
    print(f" Valid records: {len(complete_rows)}")
36
+
37
def predict_ohca(model_path, data_path, output_path=None):
    """
    Apply OHCA model to discharge notes and print a summary of the results.

    Args:
        model_path: Path to trained model
        data_path: Path to CSV with discharge notes (must contain the
            hadm_id and clean_text columns)
        output_path: Where to save results (optional). When omitted, the
            name "<input file stem>_ohca_predictions.csv" is used, relative
            to the current working directory.

    Returns:
        The object returned by ``quick_inference``; it is used here as a
        DataFrame with at least 'hadm_id' and 'ohca_probability' columns.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If the input CSV lacks a required column.
    """
    print("OHCA Classifier Prediction")
    print("="*30)

    # Validate model exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    print(f"Model: {model_path}")
    print(f"Data: {data_path}")

    # Load and validate data
    df = pd.read_csv(data_path)
    validate_discharge_data(df)

    # Set default output path
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(data_path))[0]
        output_path = f"{base_name}_ohca_predictions.csv"

    print(f"Output: {output_path}")

    # Run inference
    # NOTE(review): the loaded DataFrame (not the path) is passed as
    # data_path - confirm quick_inference accepts a DataFrame here.
    print(f"\nRunning OHCA prediction on {len(df)} discharge notes...")
    results = quick_inference(
        model_path=model_path,
        data_path=df,
        output_path=output_path
    )

    total = len(results)
    probabilities = results['ohca_probability']

    # Analyze results
    if 'ohca_prediction' in results.columns:
        ohca_detected = results['ohca_prediction'].sum()
        # The threshold column may be absent, or the result set may be empty;
        # fall back to 0.5 rather than indexing into nothing.
        if 'optimal_threshold_used' in results.columns and total > 0:
            threshold_used = results['optimal_threshold_used'].iloc[0]
        else:
            threshold_used = 0.5
    else:
        # Fallback for legacy models
        ohca_detected = (probabilities >= 0.5).sum()
        threshold_used = 0.5

    high_confidence = (probabilities >= 0.8).sum()
    very_high_confidence = (probabilities >= 0.9).sum()

    # Guard the percentage against an empty result set.
    detected_pct = ohca_detected / total * 100 if total else 0.0

    print("\nResults Summary:")
    print(f" Total cases analyzed: {total}")
    print(f" OHCA detected: {ohca_detected} ({detected_pct:.1f}%)")
    print(f" High confidence (≥0.8): {high_confidence}")
    print(f" Very high confidence (≥0.9): {very_high_confidence}")
    print(f" Threshold used: {threshold_used:.3f}")

    # Show highest probability cases
    print("\nTop 5 highest probability cases:")
    top_cases = results.nlargest(5, 'ohca_probability')
    # Iterate the two columns directly: row-wise iteration can upcast
    # integer identifiers to floats when every column is numeric.
    for hadm_id, probability in zip(top_cases['hadm_id'], top_cases['ohca_probability']):
        print(f" {hadm_id}: {probability:.3f}")

    print(f"\nResults saved to: {output_path}")

    # Clinical recommendations
    if very_high_confidence > 0:
        print("\nClinical Recommendations:")
        print(f" → {very_high_confidence} cases need immediate review (≥90% probability)")
        if high_confidence > very_high_confidence:
            print(f" → {high_confidence - very_high_confidence} cases need priority review (80-90% probability)")

    return results
110
+
111
if __name__ == "__main__":
    import argparse

    # Command-line entry point: positional model and data paths, optional output.
    cli = argparse.ArgumentParser(description='Apply OHCA classifier to discharge notes')
    cli.add_argument('model_path', help='Path to trained model directory')
    cli.add_argument('data_path', help='Path to CSV file with discharge notes')
    cli.add_argument('--output', help='Output CSV path (default: auto-generated)')
    opts = cli.parse_args()

    # Check the model first, then the data, so the user sees one actionable
    # message at a time.
    if not os.path.exists(opts.model_path):
        for line in (
            f"Error: Model not found: {opts.model_path}",
            "Train a model first using: python scripts/train_from_labeled_data.py",
        ):
            print(line)
        sys.exit(1)

    if not os.path.exists(opts.data_path):
        for line in (
            f"Error: Data file not found: {opts.data_path}",
            "\nYour CSV file should have columns:",
            " hadm_id: Unique admission identifier",
            " clean_text: Discharge note text",
        ):
            print(line)
        sys.exit(1)

    # Any failure during prediction is reported as a one-line message and a
    # non-zero exit status.
    try:
        predict_ohca(opts.model_path, opts.data_path, opts.output)
    except Exception as err:
        print(f"Prediction failed: {err}")
        sys.exit(1)
examples/scripts/prepare_data.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Data Preparation Helper for OHCA Classifier
4
+
5
+ This script helps prepare your data in the correct format for training or inference.
6
+ """
7
+
8
+ import pandas as pd
9
+ import sys
10
+
11
def prepare_labeled_data(input_path, output_path=None):
    """Prepare manually labeled data for training.

    Loads a CSV, interactively maps columns onto the required names
    (hadm_id, clean_text, ohca_label), fills in the optional subject_id and
    confidence columns when absent, drops rows with missing required values
    and writes the result to a new CSV.

    Args:
        input_path: Path to the CSV file with labeled data.
        output_path: Destination CSV path. Defaults to the input path with
            its extension replaced by "_prepared.csv".

    Raises:
        ValueError: If a required column is still missing after the
            interactive mapping step.
    """
    # Local import: this module does not import os at file level.
    import os

    print("Preparing labeled data for training...")

    df = pd.read_csv(input_path)
    print(f"Loaded {len(df)} records")
    print(f"Columns: {list(df.columns)}")

    # Interactive column mapping
    required_cols = ['hadm_id', 'clean_text', 'ohca_label']
    column_mapping = {}

    for req_col in required_cols:
        if req_col in df.columns:
            continue
        print(f"\nColumn '{req_col}' not found.")
        print(f"Available columns: {list(df.columns)}")
        mapped_col = input(f"Which column should be used for '{req_col}'? ")
        if mapped_col in df.columns:
            column_mapping[mapped_col] = req_col
        else:
            print(f"Column '{mapped_col}' not found. Skipping...")

    # Apply mapping
    if column_mapping:
        df = df.rename(columns=column_mapping)
        print(f"Applied column mapping: {column_mapping}")

    # Fail with a clear message instead of a KeyError further down when a
    # required column could not be resolved.
    unresolved = [col for col in required_cols if col not in df.columns]
    if unresolved:
        raise ValueError(f"Missing required columns: {unresolved}")

    # Add missing optional columns
    if 'subject_id' not in df.columns:
        df['subject_id'] = df['hadm_id']
        print("Added subject_id column (copied from hadm_id)")

    if 'confidence' not in df.columns:
        df['confidence'] = 4
        print("Added default confidence scores")

    # Validate and clean
    df = df.dropna(subset=required_cols)

    # Set output path
    if output_path is None:
        # splitext only strips the final extension, so a ".csv" appearing
        # elsewhere in the path (e.g. in a directory name) is left intact.
        base_name, _ = os.path.splitext(input_path)
        output_path = f"{base_name}_prepared.csv"

    df.to_csv(output_path, index=False)

    print("\nData prepared successfully:")
    print(f" Output: {output_path}")
    print(f" Records: {len(df)}")
    print(f" OHCA cases: {(df['ohca_label']==1).sum()}")
    print(f" Columns: {list(df.columns)}")
62
+
63
def prepare_discharge_notes(input_path, output_path=None):
    """Prepare discharge notes for inference.

    Loads a CSV, interactively maps columns onto the required names
    (hadm_id, clean_text), drops rows with missing required values and
    writes the result to a new CSV.

    Args:
        input_path: Path to the CSV file with discharge notes.
        output_path: Destination CSV path. Defaults to the input path with
            its extension replaced by "_prepared.csv".

    Raises:
        ValueError: If a required column is still missing after the
            interactive mapping step.
    """
    # Local import: this module does not import os at file level.
    import os

    print("Preparing discharge notes for inference...")

    df = pd.read_csv(input_path)
    print(f"Loaded {len(df)} records")
    print(f"Columns: {list(df.columns)}")

    # Interactive column mapping
    required_cols = ['hadm_id', 'clean_text']
    column_mapping = {}

    for req_col in required_cols:
        if req_col in df.columns:
            continue
        print(f"\nColumn '{req_col}' not found.")
        print(f"Available columns: {list(df.columns)}")
        mapped_col = input(f"Which column should be used for '{req_col}'? ")
        if mapped_col in df.columns:
            column_mapping[mapped_col] = req_col
        else:
            # Tell the user their answer was not usable instead of
            # silently ignoring it.
            print(f"Column '{mapped_col}' not found. Skipping...")

    # Apply mapping
    if column_mapping:
        df = df.rename(columns=column_mapping)
        print(f"Applied column mapping: {column_mapping}")

    # Fail with a clear message instead of a KeyError from dropna when a
    # required column could not be resolved.
    unresolved = [col for col in required_cols if col not in df.columns]
    if unresolved:
        raise ValueError(f"Missing required columns: {unresolved}")

    # Clean data
    df = df.dropna(subset=required_cols)

    # Set output path
    if output_path is None:
        # splitext only strips the final extension, so a ".csv" appearing
        # elsewhere in the path (e.g. in a directory name) is left intact.
        base_name, _ = os.path.splitext(input_path)
        output_path = f"{base_name}_prepared.csv"

    df.to_csv(output_path, index=False)

    print("\nDischarge notes prepared:")
    print(f" Output: {output_path}")
    print(f" Records: {len(df)}")
101
+
102
if __name__ == "__main__":
    # Both the data type and the input path are required. Requiring three
    # argv entries (script name + two arguments) avoids an IndexError when
    # only the data type is supplied.
    if len(sys.argv) < 3:
        print("Usage:")
        print(" python scripts/prepare_data.py labeled <input.csv> # For training data")
        print(" python scripts/prepare_data.py discharge <input.csv> # For inference data")
        sys.exit(1)

    data_type = sys.argv[1]
    input_path = sys.argv[2]

    if data_type == "labeled":
        prepare_labeled_data(input_path)
    elif data_type == "discharge":
        prepare_discharge_notes(input_path)
    else:
        print("Data type must be 'labeled' or 'discharge'")
        sys.exit(1)
examples/scripts/train_from_labeled_data.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train OHCA Classifier from Pre-labeled Data
4
+
5
+ This script trains a v3.0 OHCA classifier using your manually labeled data.
6
+ Your data should have columns: hadm_id, clean_text, ohca_label (and optionally subject_id, confidence)
7
+ """
8
+
9
+ import sys
10
+ import os
11
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
12
+
13
+ import pandas as pd
14
+ from sklearn.model_selection import train_test_split
15
+ from ohca_training_pipeline import prepare_training_data, train_ohca_model, find_optimal_threshold, save_model_with_metadata
16
+
17
def validate_labeled_data(df):
    """Check that labeled data has the required columns and binary labels.

    Raises ValueError when 'hadm_id', 'clean_text' or 'ohca_label' is
    absent, or when 'ohca_label' holds anything other than 0 or 1. On
    success a short class-balance summary is printed.
    """
    absent = [
        name
        for name in ('hadm_id', 'clean_text', 'ohca_label')
        if name not in df.columns
    ]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Any label value outside {0, 1} is rejected.
    unique_labels = df['ohca_label'].unique()
    unexpected = set(unique_labels) - {0, 1}
    if unexpected:
        raise ValueError(f"ohca_label must be 0 or 1, found: {unique_labels}")

    labels = df['ohca_label']
    is_ohca = labels == 1
    is_not_ohca = labels == 0

    print("Data validation passed:")
    print(f" Total cases: {len(df)}")
    print(f" OHCA cases (label=1): {is_ohca.sum()}")
    print(f" Non-OHCA cases (label=0): {is_not_ohca.sum()}")
    print(f" OHCA prevalence: {is_ohca.mean():.1%}")
35
+
36
def train_from_labeled_data(data_path, model_save_path="./trained_ohca_model",
                            test_size=0.2, num_epochs=3):
    """
    Train OHCA model from pre-labeled data

    Loads the labeled CSV, validates it, makes a stratified train/validation
    split, hands both splits to the training pipeline as temporary Excel
    files, trains the model, picks a decision threshold on the validation
    split and saves the model together with its metadata.

    Args:
        data_path: Path to CSV with labeled data
        model_save_path: Where to save the trained model
        test_size: Fraction to use for validation (default 0.2 = 20%)
        num_epochs: Number of training epochs

    Returns:
        dict with 'model_path', 'optimal_threshold' and 'metrics'
        (the validation metrics returned by find_optimal_threshold).

    Raises:
        ValueError: If the data fails validate_labeled_data (missing
            required column or non-binary labels).
    """
    # Local import so this function stays self-contained.
    import tempfile

    print("OHCA Classifier Training from Pre-labeled Data")
    print("="*50)

    # Load and validate data
    print(f"Loading labeled data from: {data_path}")
    df = pd.read_csv(data_path)

    # Validate before touching any column, so a missing hadm_id is reported
    # as a clear ValueError instead of a KeyError from the defaults below.
    validate_labeled_data(df)

    # Add missing columns if needed
    if 'subject_id' not in df.columns:
        print("Adding subject_id column (using hadm_id as patient ID)")
        df['subject_id'] = df['hadm_id']

    if 'confidence' not in df.columns:
        print("Adding default confidence scores")
        df['confidence'] = 4  # Default confidence

    # Split into train/validation
    print(f"\nSplitting data (train: {1-test_size:.0%}, validation: {test_size:.0%})")
    train_df, val_df = train_test_split(
        df, test_size=test_size,
        stratify=df['ohca_label'],
        random_state=42
    )

    print(f"Training data: {len(train_df)} cases ({(train_df['ohca_label']==1).sum()} OHCA)")
    print(f"Validation data: {len(val_df)} cases ({(val_df['ohca_label']==1).sum()} OHCA)")

    # The pipeline is given Excel file paths. Writing them into a private
    # temporary directory avoids overwriting or deleting same-named files in
    # the working directory, keeps concurrent runs apart, and guarantees
    # cleanup even if writing or training fails part-way.
    with tempfile.TemporaryDirectory(prefix="ohca_train_") as temp_dir:
        temp_train = os.path.join(temp_dir, 'temp_train_data.xlsx')
        temp_val = os.path.join(temp_dir, 'temp_val_data.xlsx')
        train_df.to_excel(temp_train, index=False)
        val_df.to_excel(temp_val, index=False)

        # Prepare training datasets
        print("\nPreparing training datasets...")
        train_dataset, val_dataset, train_df_balanced, val_df_clean, tokenizer = prepare_training_data(
            temp_train, temp_val
        )

        # Train the model
        print(f"\nTraining model for {num_epochs} epochs...")
        model, trained_tokenizer = train_ohca_model(
            train_dataset, val_dataset, train_df_balanced, tokenizer,
            num_epochs=num_epochs,
            save_path=model_save_path
        )

        # Find optimal threshold
        print("\nFinding optimal threshold...")
        optimal_threshold, val_metrics = find_optimal_threshold(
            model, trained_tokenizer, val_df_clean
        )

        # Save model with metadata
        print("\nSaving model with metadata...")
        # No separate held-out test set exists in this workflow.
        test_metrics = {'message': 'Trained on user-provided labeled data', 'test_set_size': 0}
        save_model_with_metadata(
            model, trained_tokenizer, optimal_threshold,
            val_metrics, test_metrics, model_save_path
        )

        print("\nTraining completed successfully!")
        print(f"Model saved to: {model_save_path}")
        print(f"Optimal threshold: {optimal_threshold:.3f}")
        print(f"Validation F1-score: {val_metrics['f1_score']:.3f}")

        return {
            'model_path': model_save_path,
            'optimal_threshold': optimal_threshold,
            'metrics': val_metrics
        }
127
+
128
if __name__ == "__main__":
    import argparse

    # Command-line entry point: positional data path plus optional tuning flags.
    cli = argparse.ArgumentParser(description='Train OHCA classifier from labeled data')
    cli.add_argument('data_path', help='Path to CSV file with labeled data')
    cli.add_argument(
        '--model_path',
        default='./trained_ohca_model',
        help='Where to save trained model (default: ./trained_ohca_model)',
    )
    cli.add_argument(
        '--epochs',
        type=int,
        default=3,
        help='Number of training epochs (default: 3)',
    )
    cli.add_argument(
        '--test_size',
        type=float,
        default=0.2,
        help='Validation split fraction (default: 0.2)',
    )
    opts = cli.parse_args()

    # Explain the expected CSV layout when the input file is missing.
    if not os.path.exists(opts.data_path):
        for line in (
            f"Error: Data file not found: {opts.data_path}",
            "\nYour CSV file should have columns:",
            " hadm_id: Unique admission identifier",
            " clean_text: Discharge note text",
            " ohca_label: 1 for OHCA, 0 for non-OHCA",
            " subject_id: Patient ID (optional - will use hadm_id if missing)",
        ):
            print(line)
        sys.exit(1)

    # Any failure during training is reported as a one-line message and a
    # non-zero exit status.
    try:
        train_from_labeled_data(opts.data_path, opts.model_path, opts.test_size, opts.epochs)
    except Exception as err:
        print(f"Training failed: {err}")
        sys.exit(1)