|
|
|
|
|
""" |
|
|
Apply OHCA Classifier to New Discharge Notes |
|
|
|
|
|
This script applies a trained OHCA classifier to new discharge notes. |
|
|
Input data should have columns: hadm_id, clean_text |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src')) |
|
|
|
|
|
import pandas as pd |
|
|
from ohca_inference import quick_inference |
|
|
|
|
|
def validate_discharge_data(df):
    """Check a discharge-note table for the columns the classifier needs.

    Prints a warning for each required column that contains missing values,
    then a short summary with the total row count and the number of rows
    that have both required fields populated.

    Args:
        df: DataFrame of discharge notes; must contain the ``hadm_id`` and
            ``clean_text`` columns.

    Raises:
        ValueError: If either required column is absent from ``df``.
    """
    expected = ['hadm_id', 'clean_text']

    # Guard clause: nothing else can be checked without both columns.
    absent = [name for name in expected if name not in df.columns]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Null counts for both required columns, gathered in a single pass.
    null_counts = df[expected].isna().sum()
    n_null_ids = null_counts['hadm_id']
    n_null_text = null_counts['clean_text']

    if n_null_ids > 0:
        print(f"Warning: {n_null_ids} rows have missing hadm_id")
    if n_null_text > 0:
        print(f"Warning: {n_null_text} rows have missing clean_text")

    # A record counts as valid only when neither required field is null.
    usable = df.dropna(subset=expected)

    print("Data validation:")
    print(f" Total discharge notes: {len(df)}")
    print(f" Valid records: {len(usable)}")
|
|
|
|
|
def predict_ohca(model_path, data_path, output_path=None):
    """
    Apply OHCA model to discharge notes.

    Reads the CSV at ``data_path``, checks it with
    ``validate_discharge_data``, runs ``quick_inference`` on the loaded
    table and prints a summary of the predictions to stdout.

    Args:
        model_path: Path to trained model
        data_path: Path to CSV with discharge notes
        output_path: Where to save results (optional). When omitted, a name
            of the form ``<data file stem>_ohca_predictions.csv`` (relative
            to the current working directory) is used. The value is
            forwarded to ``quick_inference``.

    Returns:
        The results object returned by ``quick_inference``. It is read here
        as a DataFrame with ``hadm_id`` and ``ohca_probability`` columns
        and, optionally, ``ohca_prediction`` and ``optimal_threshold_used``.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If the CSV lacks a required column (raised by
            ``validate_discharge_data``).
    """
    # Fallback decision threshold, used when the results carry no
    # model-specific threshold or no explicit predictions.
    default_threshold = 0.5

    print("OHCA Classifier Prediction")
    print("=" * 30)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    print(f"Model: {model_path}")
    print(f"Data: {data_path}")

    df = pd.read_csv(data_path)
    validate_discharge_data(df)

    if output_path is None:
        base_name = os.path.splitext(os.path.basename(data_path))[0]
        output_path = f"{base_name}_ohca_predictions.csv"

    print(f"Output: {output_path}")

    print(f"\nRunning OHCA prediction on {len(df)} discharge notes...")
    # NOTE(review): the already-loaded DataFrame (not the file path) is
    # handed to the `data_path` parameter -- confirm quick_inference
    # accepts a DataFrame there.
    results = quick_inference(
        model_path=model_path,
        data_path=df,
        output_path=output_path,
    )

    total_cases = len(results)
    probabilities = results['ohca_probability']

    if 'ohca_prediction' in results.columns:
        ohca_detected = results['ohca_prediction'].sum()
        # Look the column up explicitly: DataFrame.get returns its fallback
        # object unchanged, so a list fallback has no `.iloc`, and `.iloc[0]`
        # itself fails on an empty result set.
        if 'optimal_threshold_used' in results.columns and total_cases > 0:
            threshold_used = results['optimal_threshold_used'].iloc[0]
        else:
            threshold_used = default_threshold
    else:
        # No explicit predictions: derive them from the probabilities.
        ohca_detected = (probabilities >= default_threshold).sum()
        threshold_used = default_threshold

    high_confidence = (probabilities >= 0.8).sum()
    very_high_confidence = (probabilities >= 0.9).sum()

    # Avoid ZeroDivisionError when there is nothing to analyze.
    detected_pct = ohca_detected / total_cases * 100 if total_cases else 0.0

    print("\nResults Summary:")
    print(f" Total cases analyzed: {total_cases}")
    print(f" OHCA detected: {ohca_detected} ({detected_pct:.1f}%)")
    print(f" High confidence (≥0.8): {high_confidence}")
    print(f" Very high confidence (≥0.9): {very_high_confidence}")
    print(f" Threshold used: {threshold_used:.3f}")

    print("\nTop 5 highest probability cases:")
    top_cases = results.nlargest(5, 'ohca_probability')
    # Iterate the two columns directly instead of iterrows(): row-wise
    # Series can upcast integer hadm_id values to float for display.
    for hadm_id, probability in zip(top_cases['hadm_id'],
                                    top_cases['ohca_probability']):
        print(f" {hadm_id}: {probability:.3f}")

    print(f"\nResults saved to: {output_path}")

    if very_high_confidence > 0:
        print("\nClinical Recommendations:")
        print(f" → {very_high_confidence} cases need immediate review (≥90% probability)")
        if high_confidence > very_high_confidence:
            print(f" → {high_confidence - very_high_confidence} cases need priority review (80-90% probability)")

    return results
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line interface: two required paths and an optional output file.
    cli = argparse.ArgumentParser(
        description='Apply OHCA classifier to discharge notes'
    )
    cli.add_argument('model_path', help='Path to trained model directory')
    cli.add_argument('data_path', help='Path to CSV file with discharge notes')
    cli.add_argument('--output', help='Output CSV path (default: auto-generated)')
    options = cli.parse_args()

    # Fail fast with an actionable hint before any data or model is loaded.
    model_available = os.path.exists(options.model_path)
    if not model_available:
        print(f"Error: Model not found: {options.model_path}")
        print("Train a model first using: python scripts/train_from_labeled_data.py")
        raise SystemExit(1)

    data_available = os.path.exists(options.data_path)
    if not data_available:
        print(f"Error: Data file not found: {options.data_path}")
        print("\nYour CSV file should have columns:")
        print(" hadm_id: Unique admission identifier")
        print(" clean_text: Discharge note text")
        raise SystemExit(1)

    # Top-level boundary: report any failure and exit with a non-zero status.
    try:
        predict_ohca(options.model_path, options.data_path, options.output)
    except Exception as err:
        print(f"Prediction failed: {err}")
        raise SystemExit(1)
|
|
|