#!/usr/bin/env python
"""
CSV Prediction Script for SWAN Menopause Stage Forecasting

This script demonstrates how to use the trained forecasting module to make
predictions for a batch of individuals read from a CSV file, and saves the
results together with confidence scores and summary statistics.

Usage:
    python predict_csv.py --input demo_individuals.csv --model RandomForest
    python predict_csv.py --input individuals.csv --output results.csv --model LogisticRegression

The script will:
    1. Read an input CSV with per-individual feature values
    2. Make predictions using the selected trained model
    3. Save results with the predicted stage, confidence, and probabilities
    4. Display summary statistics (stage distribution and confidence scores)
"""
import os
import sys
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
def _build_parser():
    """Return the argument parser for the CSV prediction command line."""
    parser = argparse.ArgumentParser(
        description='Make menopause stage predictions from CSV file'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Path to input CSV file with individual feature values'
    )
    parser.add_argument(
        '--output', '-o',
        default=None,
        help='Path to output CSV file (default: input_predictions.csv)'
    )
    parser.add_argument(
        '--model', '-m',
        choices=['RandomForest', 'LogisticRegression'],
        default='RandomForest',
        help='Which model to use for predictions'
    )
    parser.add_argument(
        '--forecast-dir',
        default='swan_ml_output',
        help='Directory containing trained forecast models'
    )
    return parser


def _fail(*lines):
    """Print each message line to stderr and terminate with exit status 1."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(1)


def _print_summary(results):
    """Print the per-individual results table followed by summary statistics.

    Args:
        results: DataFrame returned by ``predict_from_csv``. This function
            reads its ``predicted_stage`` and ``confidence`` columns.
    """
    rule = "=" * 80

    print("\n" + rule)
    print("PREDICTION RESULTS")
    print(rule)
    print("\nDetailed Results:")
    print(results.to_string(index=False))

    print("\n" + rule)
    print("PERFORMANCE SUMMARY")
    print(rule)
    total = len(results)
    print(f"\nTotal Individuals: {total}")
    if total == 0:
        # Every percentage below divides by `total`; without this guard an
        # input CSV with no data rows raises ZeroDivisionError.
        print("\nNo predictions to summarize.")
        return

    print("\nStage Distribution:")
    for stage, count in results['predicted_stage'].value_counts().items():
        print(f" {stage}: {count} ({count / total * 100:.1f}%)")

    confidence = results['confidence']
    print("\nConfidence Scores:")
    print(f" Mean: {confidence.mean():.3f}")
    print(f" Min: {confidence.min():.3f}")
    print(f" Max: {confidence.max():.3f}")
    print(f" Std Dev: {confidence.std():.3f}")

    # Buckets: high > 0.8, medium in (0.6, 0.8], low <= 0.6. Rows whose
    # confidence is NaN compare False everywhere and land in no bucket.
    buckets = (
        ("High (>0.80)", (confidence > 0.8).sum()),
        ("Medium (0.60-0.80)", ((confidence > 0.6) & (confidence <= 0.8)).sum()),
        ("Low (≤0.60)", (confidence <= 0.6).sum()),
    )
    print("\nConfidence Distribution:")
    for label, count in buckets:
        print(f" {label}: {count}/{total} ({count / total * 100:.1f}%)")


def main(argv=None):
    """Run batch menopause-stage prediction on a CSV file.

    Parses the command line, checks that the input CSV and the forecast
    model directory exist, loads the forecaster, runs ``predict_from_csv``,
    and prints a results table with summary statistics.

    Args:
        argv: Optional list of command-line arguments, without the program
            name. ``None`` (the default) makes argparse read ``sys.argv``,
            which matches the original zero-argument call.

    Exits with status 1 in four cases: the ``menopause`` module cannot be
    imported, the input file is missing, the forecast models are missing, or
    prediction returns ``None``. argparse itself exits with status 2 on
    invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    # Import after parsing args so --help and argument errors are handled
    # without needing the menopause module to be importable.
    try:
        from menopause import load_forecast_model, predict_from_csv
    except ImportError as exc:
        _fail(
            "ERROR: Could not import menopause module.",
            "Make sure you're in the correct directory and menopause.py is available.",
            # Surface the real cause (e.g. a missing dependency inside menopause).
            f"Details: {exc}",
        )

    if not os.path.exists(args.input):
        _fail(f"ERROR: Input file not found: {args.input}")

    forecast_dir = args.forecast_dir
    # NOTE(review): only the RandomForest artifact is checked, even when
    # --model LogisticRegression is selected; confirm the LR artifact's file
    # name and check it as well.
    if not os.path.exists(os.path.join(forecast_dir, 'rf_pipeline.pkl')):
        _fail(
            f"ERROR: Forecast models not found in {forecast_dir}",
            "Please run 'python menopause.py' first to train models.",
        )

    rule = "=" * 80
    print(rule)
    print("MENOPAUSE STAGE PREDICTION FROM CSV")
    print(rule)

    print(f"\nLoading forecaster from {forecast_dir}...")
    forecast = load_forecast_model(forecast_dir)

    print(f"\nUsing model: {args.model}")
    results = predict_from_csv(
        args.input,
        forecast,
        output_csv=args.output,
        model=args.model,
        output_dir='.',
    )
    if results is None:
        _fail("ERROR: Prediction failed.")

    _print_summary(results)

    # Reconstructs the default name described in the --output help text; the
    # file itself is written by predict_from_csv — TODO confirm they agree.
    output_path = args.output if args.output else f"{Path(args.input).stem}_predictions.csv"
    print(f"\n✅ Results saved to: {output_path}")
    print("\n" + rule)
if __name__ == "__main__":
    # Run the command line only when executed as a script, not on import.
    main()