File size: 4,891 Bytes
66d45ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python
"""

CSV Prediction Script for SWAN Menopause Stage Forecasting



This script demonstrates how to use the trained forecasting module to make predictions

on a batch of individuals from a CSV file and save results with confidence scores

and performance metrics.



Usage:

    python predict_csv.py --input demo_individuals.csv --model RandomForest

    python predict_csv.py --input individuals.csv --output results.csv --model LogisticRegression



The script will:

1. Read input CSV with individual feature values

2. Make predictions using trained model

3. Save results with predicted stage, confidence, and probabilities

4. Display summary statistics

"""

import os
import sys
import argparse
import pandas as pd
import numpy as np
from pathlib import Path


def _fail(*messages):
    """Print each message line to stderr and terminate with exit status 1."""
    for message in messages:
        print(message, file=sys.stderr)
    sys.exit(1)


def _banner(title, *, leading_newline=True):
    """Print ``title`` framed above and below by 80-character '=' rules."""
    rule = "=" * 80
    print(("\n" + rule) if leading_newline else rule)
    print(title)
    print(rule)


def _parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    parser = argparse.ArgumentParser(
        description='Make menopause stage predictions from CSV file'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Path to input CSV file with individual feature values'
    )
    parser.add_argument(
        '--output', '-o',
        default=None,
        help='Path to output CSV file (default: input_predictions.csv)'
    )
    parser.add_argument(
        '--model', '-m',
        choices=['RandomForest', 'LogisticRegression'],
        default='RandomForest',
        help='Which model to use for predictions'
    )
    parser.add_argument(
        '--forecast-dir',
        default='swan_ml_output',
        help='Directory containing trained forecast models'
    )
    return parser.parse_args(argv)


def _print_summary(results):
    """Print per-stage counts and confidence statistics for ``results``.

    Reads the ``predicted_stage`` and ``confidence`` columns of the
    DataFrame returned by ``predict_from_csv``.
    """
    total = len(results)
    print(f"\nTotal Individuals: {total}")
    if total == 0:
        # An empty batch would otherwise print nan statistics and 0/0 percentages.
        print("\nNo individuals were predicted; skipping summary statistics.")
        return

    print("\nStage Distribution:")
    for stage, count in results['predicted_stage'].value_counts().items():
        print(f"  {stage}: {count} ({count / total * 100:.1f}%)")

    conf = results['confidence']
    print("\nConfidence Scores:")
    print(f"  Mean: {conf.mean():.3f}")
    print(f"  Min: {conf.min():.3f}")
    print(f"  Max: {conf.max():.3f}")
    # pandas' std defaults to the sample estimate (ddof=1), so one row prints nan.
    print(f"  Std Dev: {conf.std():.3f}")

    # Buckets: high > 0.8, 0.6 < medium <= 0.8, low <= 0.6.
    buckets = (
        ("High (>0.80)", (conf > 0.8).sum()),
        ("Medium (0.60-0.80)", ((conf > 0.6) & (conf <= 0.8)).sum()),
        ("Low (≤0.60)", (conf <= 0.6).sum()),
    )
    print("\nConfidence Distribution:")
    for label, count in buckets:
        print(f"  {label}: {count}/{total} ({count / total * 100:.1f}%)")


def main(argv=None):
    """Parse CLI arguments, predict stages for a CSV batch, and print a summary.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. When ``None`` (the default), argparse reads
            ``sys.argv[1:]``, as the original script always did.

    Exits:
        Status 1 (message on stderr) when the ``menopause`` module cannot be
        imported, the input file or the RandomForest pipeline file is
        missing, or ``predict_from_csv`` returns ``None``. argparse exits
        with status 2 on invalid arguments.
    """
    args = _parse_args(argv)

    # Import after parsing args so --help and usage errors still work when
    # the menopause module is not importable.
    try:
        from menopause import load_forecast_model, predict_from_csv
    except ImportError as exc:
        # Include the cause: an ImportError raised *inside* menopause (e.g. a
        # missing dependency) is otherwise indistinguishable from a missing file.
        _fail(
            "ERROR: Could not import menopause module.",
            "Make sure you're in the correct directory and menopause.py is available.",
            f"Details: {exc}",
        )

    # isfile (not exists) so a directory passed as --input is rejected up front.
    if not os.path.isfile(args.input):
        _fail(f"ERROR: Input file not found: {args.input}")

    forecast_dir = args.forecast_dir
    # NOTE(review): only the RandomForest pipeline file is checked, even when
    # --model LogisticRegression is selected — TODO confirm the LR pipeline's
    # filename and check the file for the chosen model instead.
    if not os.path.exists(os.path.join(forecast_dir, 'rf_pipeline.pkl')):
        _fail(
            f"ERROR: Forecast models not found in {forecast_dir}",
            "Please run 'python menopause.py' first to train models.",
        )

    _banner("MENOPAUSE STAGE PREDICTION FROM CSV", leading_newline=False)

    print(f"\nLoading forecaster from {forecast_dir}...")
    forecast = load_forecast_model(forecast_dir)

    print(f"\nUsing model: {args.model}")
    results = predict_from_csv(
        args.input,
        forecast,
        output_csv=args.output,
        model=args.model,
        output_dir='.',
    )
    if results is None:
        _fail("ERROR: Prediction failed.")

    _banner("PREDICTION RESULTS")
    print("\nDetailed Results:")
    print(results.to_string(index=False))

    _banner("PERFORMANCE SUMMARY")
    _print_summary(results)

    # NOTE(review): assumes predict_from_csv names its default output
    # "<input stem>_predictions.csv" — confirm it matches the file written.
    output_path = args.output if args.output else f"{Path(args.input).stem}_predictions.csv"
    print(f"\n✅ Results saved to: {output_path}")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    # Run the command-line interface only when executed directly, not on import.
    main()