File size: 4,900 Bytes
493b03a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
"""
Apply OHCA Classifier to New Discharge Notes

This script applies a trained OHCA classifier to new discharge notes.
Input data should have columns: hadm_id, clean_text
"""

import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))

import pandas as pd
from ohca_inference import quick_inference

def validate_discharge_data(df):
    """Check that a discharge-note DataFrame has the columns this script needs.

    Prints a warning for each required column that contains missing values,
    followed by a short summary (total rows and rows with no missing value
    in any required column). The frame itself is not modified.

    Args:
        df: DataFrame expected to contain 'hadm_id' and 'clean_text' columns.

    Raises:
        ValueError: If one or more required columns are absent.
    """
    required_cols = ['hadm_id', 'clean_text']

    absent = [name for name in required_cols if name not in df.columns]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    # Report null counts column by column, in the order listed above.
    for name in required_cols:
        null_count = df[name].isna().sum()
        if null_count > 0:
            print(f"Warning: {null_count} rows have missing {name}")

    print("Data validation:")
    print(f"  Total discharge notes: {len(df)}")
    # A record is "valid" only when every required column is populated.
    complete_rows = df.dropna(subset=required_cols)
    print(f"  Valid records: {len(complete_rows)}")

def predict_ohca(model_path, data_path, output_path=None):
    """
    Apply OHCA model to discharge notes

    Loads the CSV, validates it, runs ``quick_inference`` and prints a
    summary of the predictions (counts, confidence buckets, top cases).

    Args:
        model_path: Path to trained model
        data_path: Path to CSV with discharge notes
        output_path: Where to save results (optional). Defaults to
            ``<data file stem>_ohca_predictions.csv`` in the current
            working directory.

    Returns:
        The object returned by ``quick_inference``. This function treats it
        as a DataFrame with 'hadm_id' and 'ohca_probability' columns and,
        optionally, 'ohca_prediction' and 'optimal_threshold_used'.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If the CSV lacks a required column (raised by
            ``validate_discharge_data``).
    """
    print("OHCA Classifier Prediction")
    print("="*30)

    # Validate model exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    print(f"Model: {model_path}")
    print(f"Data: {data_path}")

    # Load and validate data
    df = pd.read_csv(data_path)
    validate_discharge_data(df)

    # Set default output path
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(data_path))[0]
        output_path = f"{base_name}_ohca_predictions.csv"

    print(f"Output: {output_path}")

    # Run inference
    # NOTE(review): the already-loaded DataFrame is passed as `data_path`;
    # presumably quick_inference accepts a DataFrame there - confirm.
    print(f"\nRunning OHCA prediction on {len(df)} discharge notes...")
    results = quick_inference(
        model_path=model_path,
        data_path=df,
        output_path=output_path
    )

    # Analyze results
    total_cases = len(results)
    default_threshold = 0.5
    if 'ohca_prediction' in results.columns:
        ohca_detected = results['ohca_prediction'].sum()
        # Only read the stored threshold when the column exists and has at
        # least one row; otherwise fall back to the default instead of
        # calling .iloc on a plain list or indexing an empty column.
        if 'optimal_threshold_used' in results.columns and total_cases > 0:
            threshold_used = results['optimal_threshold_used'].iloc[0]
        else:
            threshold_used = default_threshold
    else:
        # Fallback for legacy models
        ohca_detected = (results['ohca_probability'] >= default_threshold).sum()
        threshold_used = default_threshold

    high_confidence = (results['ohca_probability'] >= 0.8).sum()
    very_high_confidence = (results['ohca_probability'] >= 0.9).sum()

    # Guard the percentage against an empty result set.
    detected_pct = ohca_detected / total_cases * 100 if total_cases else 0.0

    print("\nResults Summary:")
    print(f"  Total cases analyzed: {total_cases}")
    print(f"  OHCA detected: {ohca_detected} ({detected_pct:.1f}%)")
    print(f"  High confidence (≥0.8): {high_confidence}")
    print(f"  Very high confidence (≥0.9): {very_high_confidence}")
    print(f"  Threshold used: {threshold_used:.3f}")

    # Show highest probability cases
    print("\nTop 5 highest probability cases:")
    top_cases = results.nlargest(5, 'ohca_probability')
    # Iterate the two columns directly: iterrows() builds a per-row Series
    # and can upcast integer ids to float when all columns are numeric.
    for hadm_id, probability in zip(top_cases['hadm_id'], top_cases['ohca_probability']):
        print(f"  {hadm_id}: {probability:.3f}")

    print(f"\nResults saved to: {output_path}")

    # Clinical recommendations
    # Every >=0.9 case is also a >=0.8 case, so high_confidence > 0 covers
    # both lines below and the header is never missing above a recommendation.
    if high_confidence > 0:
        print("\nClinical Recommendations:")
    if very_high_confidence > 0:
        print(f"  → {very_high_confidence} cases need immediate review (≥90% probability)")
    if high_confidence > very_high_confidence:
        print(f"  → {high_confidence - very_high_confidence} cases need priority review (80-90% probability)")

    return results

if __name__ == "__main__":
    import argparse

    # Command-line interface: two required positionals and an optional
    # output location.
    cli = argparse.ArgumentParser(description='Apply OHCA classifier to discharge notes')
    cli.add_argument('model_path', help='Path to trained model directory')
    cli.add_argument('data_path', help='Path to CSV file with discharge notes')
    cli.add_argument('--output', help='Output CSV path (default: auto-generated)')
    opts = cli.parse_args()

    # Fail early, with guidance, when either input path is missing.
    model_found = os.path.exists(opts.model_path)
    if not model_found:
        print(
            f"Error: Model not found: {opts.model_path}",
            "Train a model first using: python scripts/train_from_labeled_data.py",
            sep="\n",
        )
        raise SystemExit(1)

    data_found = os.path.exists(opts.data_path)
    if not data_found:
        print(
            f"Error: Data file not found: {opts.data_path}",
            "\nYour CSV file should have columns:",
            "  hadm_id: Unique admission identifier",
            "  clean_text: Discharge note text",
            sep="\n",
        )
        raise SystemExit(1)

    # Top-level boundary: report any failure and exit with a non-zero status.
    try:
        predict_ohca(opts.model_path, opts.data_path, opts.output)
    except Exception as err:
        print(f"Prediction failed: {err}")
        raise SystemExit(1)