File size: 4,900 Bytes
493b03a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
#!/usr/bin/env python3
"""
Apply OHCA Classifier to New Discharge Notes
This script applies a trained OHCA classifier to new discharge notes.
Input data should have columns: hadm_id, clean_text
"""
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
import pandas as pd
from ohca_inference import quick_inference
def validate_discharge_data(df):
    """Check discharge data for the required columns and report gaps.

    Raises ValueError if either `hadm_id` or `clean_text` is not a column of
    `df`. Missing values in those columns are not fatal: each affected column
    produces a warning line, followed by a short summary of total and
    fully-populated rows. Nothing is returned.
    """
    expected = ['hadm_id', 'clean_text']

    absent = list(filter(lambda name: name not in df.columns, expected))
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Warn per column, ids first and text second
    for column in expected:
        null_count = df[column].isna().sum()
        if null_count > 0:
            print(f"Warning: {null_count} rows have missing {column}")

    print("Data validation:")
    print(f" Total discharge notes: {len(df)}")
    # A record counts as valid only when every required column is populated
    complete_rows = df.dropna(subset=expected)
    print(f" Valid records: {len(complete_rows)}")
def predict_ohca(model_path, data_path, output_path=None):
    """
    Apply OHCA model to discharge notes and print a summary of the results.

    Args:
        model_path: Path to trained model
        data_path: Path to CSV with discharge notes
        output_path: Where to save results (optional). When omitted, a
            relative name ``<input base name>_ohca_predictions.csv`` is used.

    Returns:
        The object returned by ``quick_inference``; it is used here as a
        DataFrame with at least ``hadm_id`` and ``ohca_probability`` columns.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If the CSV lacks a required column (raised by
            ``validate_discharge_data``).
    """
    print("OHCA Classifier Prediction")
    print("="*30)

    # Validate model exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    print(f"Model: {model_path}")
    print(f"Data: {data_path}")

    # Load and validate data
    df = pd.read_csv(data_path)
    validate_discharge_data(df)

    # Set default output path
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(data_path))[0]
        output_path = f"{base_name}_ohca_predictions.csv"
    print(f"Output: {output_path}")

    # Run inference. The already-loaded DataFrame is handed over through the
    # `data_path` keyword; presumably quick_inference accepts either form and
    # writes `output_path` itself -- confirm against ohca_inference.
    print(f"\nRunning OHCA prediction on {len(df)} discharge notes...")
    results = quick_inference(
        model_path=model_path,
        data_path=df,
        output_path=output_path
    )

    total_cases = len(results)

    # Analyze results
    if 'ohca_prediction' in results.columns:
        ohca_detected = results['ohca_prediction'].sum()
        # Read the threshold only when the column exists and has a row.
        # The previous `.get(col, [0.5]).iloc[0]` crashed on the list default
        # (lists have no `.iloc`) and on empty results.
        if 'optimal_threshold_used' in results.columns and total_cases > 0:
            threshold_used = results['optimal_threshold_used'].iloc[0]
        else:
            threshold_used = 0.5
    else:
        # Fallback for legacy models
        ohca_detected = (results['ohca_probability'] >= 0.5).sum()
        threshold_used = 0.5

    high_confidence = (results['ohca_probability'] >= 0.8).sum()
    very_high_confidence = (results['ohca_probability'] >= 0.9).sum()

    # Guard the percentage so an empty result set reports 0.0% instead of
    # raising ZeroDivisionError.
    detected_pct = ohca_detected / total_cases * 100 if total_cases else 0.0

    print("\nResults Summary:")
    print(f" Total cases analyzed: {total_cases}")
    print(f" OHCA detected: {ohca_detected} ({detected_pct:.1f}%)")
    print(f" High confidence (≥0.8): {high_confidence}")
    print(f" Very high confidence (≥0.9): {very_high_confidence}")
    print(f" Threshold used: {threshold_used:.3f}")

    # Show highest probability cases
    print("\nTop 5 highest probability cases:")
    top_cases = results.nlargest(5, 'ohca_probability')
    for _, row in top_cases.iterrows():
        print(f" {row['hadm_id']}: {row['ohca_probability']:.3f}")

    print(f"\nResults saved to: {output_path}")

    # Clinical recommendations
    if very_high_confidence > 0:
        print("\nClinical Recommendations:")
        print(f" → {very_high_confidence} cases need immediate review (≥90% probability)")
        # high_confidence includes the ≥0.9 cases, so the difference is the
        # 0.8-0.9 band.
        if high_confidence > very_high_confidence:
            print(f" → {high_confidence - very_high_confidence} cases need priority review (80-90% probability)")

    return results
if __name__ == "__main__":
    import argparse

    # Command-line interface: two positional paths and an optional output path
    cli = argparse.ArgumentParser(description='Apply OHCA classifier to discharge notes')
    cli.add_argument('model_path', help='Path to trained model directory')
    cli.add_argument('data_path', help='Path to CSV file with discharge notes')
    cli.add_argument('--output', help='Output CSV path (default: auto-generated)')
    options = cli.parse_args()

    # The model is checked first; a missing model exits before the data check
    if not os.path.exists(options.model_path):
        print(
            f"Error: Model not found: {options.model_path}",
            "Train a model first using: python scripts/train_from_labeled_data.py",
            sep="\n",
        )
        raise SystemExit(1)

    # A missing data file also gets a reminder of the expected CSV layout
    if not os.path.exists(options.data_path):
        print(
            f"Error: Data file not found: {options.data_path}",
            "\nYour CSV file should have columns:",
            " hadm_id: Unique admission identifier",
            " clean_text: Discharge note text",
            sep="\n",
        )
        raise SystemExit(1)

    # Top-level boundary: any failure is reported as one line and exit code 1
    try:
        predict_ohca(options.model_path, options.data_path, options.output)
    except Exception as err:
        print(f"Prediction failed: {err}")
        raise SystemExit(1)
|