File size: 2,777 Bytes
1cd24c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab8cf1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd24c3
ab8cf1f
 
 
 
 
 
 
 
 
 
1cd24c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pickle
import pandas as pd
import numpy as np
import os
from huggingface_hub import hf_hub_download

# Constants
REPO_ID = "RealFishSam/DVAE26-proj"
FILENAME = "stacked_ensemble_model.pkl"

# Check locations (relative to the current working directory); the first one
# that exists wins.
possible_paths = [
    FILENAME,
    os.path.join('models', FILENAME),
    os.path.join('..', 'models', FILENAME)
]


def _resolve_model_path(candidates):
    """Return the first existing path in *candidates*, downloading as a fallback.

    Local copies take precedence. When none of the candidate paths exists, the
    model file ``FILENAME`` is fetched from the Hugging Face Hub repository
    ``REPO_ID`` with ``hf_hub_download``, which returns the path of the locally
    cached copy. Previously a missing local file led to ``open(None)`` and an
    unhelpful ``TypeError``.
    """
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return hf_hub_download(repo_id=REPO_ID, filename=FILENAME)


model_path = _resolve_model_path(possible_paths)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load model files from a source you trust (including the Hub repo above).
with open(model_path, 'rb') as f:
    components = pickle.load(f)

# The pickle is a dict bundling the fitted pipeline pieces.
preprocessor = components['preprocessor']
base_models = components['base_models']
meta_model = components['meta_model']
# Decision threshold for the stacked model; 0.5 when the pickle does not store one.
threshold = components.get('threshold_stacked', 0.5)

import argparse


def _build_parser():
    """Return the argparse parser for the stroke-risk command line.

    Every option maps to one raw patient feature. Each has a default, so the
    script runs without any arguments. Options are registered in the order
    listed below; that order is also the attribute order of the parsed
    namespace.
    """
    option_specs = (
        ('--gender', {'type': str, 'default': 'Male', 'choices': ['Male', 'Female'], 'help': 'Gender'}),
        ('--age', {'type': float, 'default': 75, 'help': 'Age of the patient'}),
        ('--hypertension', {'type': int, 'default': 1, 'choices': [0, 1], 'help': '0: No, 1: Yes'}),
        ('--heart_disease', {'type': int, 'default': 1, 'choices': [0, 1], 'help': '0: No, 1: Yes'}),
        ('--ever_married', {'type': str, 'default': 'Yes', 'choices': ['Yes', 'No'], 'help': 'Ever married?'}),
        ('--work_type', {'type': str, 'default': 'Private',
                         'choices': ['Private', 'Self-employed', 'Govt_job', 'children', 'Never_worked'],
                         'help': 'Work type'}),
        ('--Residence_type', {'type': str, 'default': 'Urban', 'choices': ['Urban', 'Rural'],
                              'help': 'Residence type'}),
        ('--avg_glucose_level', {'type': float, 'default': 220.5, 'help': 'Average glucose level'}),
        ('--bmi', {'type': float, 'default': 30.1, 'help': 'Body Mass Index'}),
        ('--smoking_status', {'type': str, 'default': 'formerly smoked',
                              'choices': ['formerly smoked', 'never smoked', 'smokes', 'Unknown'],
                              'help': 'Smoking status'}),
    )
    cli = argparse.ArgumentParser(description='Stroke Risk Predictor')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


# Parse arguments
parser = _build_parser()
args = parser.parse_args()

# Echo the inputs, assemble a one-row patient frame, and run the stacked ensemble.
print("\nModel Input:")
for arg, value in vars(args).items():
    print(f"  {arg}: {value}")

# One-row DataFrame of raw features; column names presumably must match those
# the preprocessor was fitted with — confirm if the CLI options change.
patient = pd.DataFrame([{
    'gender': args.gender,
    'age': args.age,
    'hypertension': args.hypertension,
    'heart_disease': args.heart_disease,
    'ever_married': args.ever_married,
    'work_type': args.work_type,
    'Residence_type': args.Residence_type,
    'avg_glucose_level': args.avg_glucose_level,
    'bmi': args.bmi,
    'smoking_status': args.smoking_status
}])

# 1. Preprocess
X = preprocessor.transform(patient)

# 2. Base model predictions: the column-1 (positive-class) probability from
#    each base model, taken in base_models' iteration order. The meta model is
#    assumed to expect its input columns in that same order.
preds = [m.predict_proba(X)[:, 1] for m in base_models.values()]

# 3. Meta prediction
meta_X = np.column_stack(preds)
final_prob = meta_model.predict_proba(meta_X)[:, 1][0]

print(f"Stroke Probability: {final_prob}")

# 4. Decision: apply the threshold loaded alongside the model.
# NOTE(review): assumes probabilities at or above the threshold count as
# positive — confirm against how 'threshold_stacked' was tuned.
predicted = "Yes" if final_prob >= threshold else "No"
print(f"Predicted Stroke (threshold {threshold}): {predicted}")