|
|
import pickle |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import os |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Hugging Face Hub location of the pickled stacked-ensemble bundle; used as a
# fallback when no local copy is found.
REPO_ID = "RealFishSam/DVAE26-proj"
FILENAME = "stacked_ensemble_model.pkl"

# Local locations checked in order before falling back to the Hub: the current
# directory, ./models/, and ../models/ (all relative to the working directory).
possible_paths = [
    FILENAME,
    os.path.join('models', FILENAME),
    os.path.join('..', 'models', FILENAME),
]

# First existing local candidate, or None if there is no local copy.
model_path = next((p for p in possible_paths if os.path.exists(p)), None)

if model_path is None:
    # No local copy: download (or reuse the cached copy of) the model from the
    # Hub. Without this fallback, open(None, ...) below raised a TypeError.
    model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# SECURITY: pickle.load can execute arbitrary code embedded in the file. Only
# load model files from a source you trust (the local paths above or REPO_ID).
with open(model_path, 'rb') as f:
    components = pickle.load(f)

# The bundle is a mapping holding the fitted pipeline pieces:
#   'preprocessor'      - transforms the raw patient DataFrame into model input
#   'base_models'       - mapping of level-0 models (each has predict_proba)
#   'meta_model'        - level-1 model fed the base models' positive-class probs
#   'threshold_stacked' - optional decision threshold; defaults to 0.5 if absent
preprocessor = components['preprocessor']
base_models = components['base_models']
meta_model = components['meta_model']
threshold = components.get('threshold_stacked', 0.5)
|
|
|
|
|
import argparse |
|
|
|
|
|
|
|
|
# Command-line interface: one flag per raw patient feature. Keep the flag names
# in sync with the patient DataFrame columns built from `args` below. The
# defaults describe one example patient, so running with no flags still yields
# a prediction.
# NOTE(review): parsing happens after the model has been loaded above, so even
# `--help` pays the model-load cost; consider moving the CLI setup earlier.
parser = argparse.ArgumentParser(description='Stroke Risk Predictor')

_CLI_OPTIONS = (
    ('--gender', dict(type=str, default='Male',
                      choices=['Male', 'Female'], help='Gender')),
    ('--age', dict(type=float, default=75, help='Age of the patient')),
    ('--hypertension', dict(type=int, default=1,
                            choices=[0, 1], help='0: No, 1: Yes')),
    ('--heart_disease', dict(type=int, default=1,
                             choices=[0, 1], help='0: No, 1: Yes')),
    ('--ever_married', dict(type=str, default='Yes',
                            choices=['Yes', 'No'], help='Ever married?')),
    ('--work_type', dict(type=str, default='Private',
                         choices=['Private', 'Self-employed', 'Govt_job',
                                  'children', 'Never_worked'],
                         help='Work type')),
    ('--Residence_type', dict(type=str, default='Urban',
                              choices=['Urban', 'Rural'],
                              help='Residence type')),
    ('--avg_glucose_level', dict(type=float, default=220.5,
                                 help='Average glucose level')),
    ('--bmi', dict(type=float, default=30.1, help='Body Mass Index')),
    ('--smoking_status', dict(type=str, default='formerly smoked',
                              choices=['formerly smoked', 'never smoked',
                                       'smokes', 'Unknown'],
                              help='Smoking status')),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()
|
|
|
|
|
# Echo every parsed CLI value so the user sees exactly what is fed to the model.
print("\nModel Input:")
for option_name, option_value in vars(args).items():
    print(f" {option_name}: {option_value}")

# Raw feature columns for the single-row patient frame. Each column name is also
# the name of the argparse attribute it is read from.
FEATURE_COLUMNS = [
    'gender',
    'age',
    'hypertension',
    'heart_disease',
    'ever_married',
    'work_type',
    'Residence_type',
    'avg_glucose_level',
    'bmi',
    'smoking_status',
]
patient = pd.DataFrame([{column: getattr(args, column) for column in FEATURE_COLUMNS}])
|
|
|
|
|
|
|
|
# Turn the raw patient row into the feature matrix the base models expect.
X = preprocessor.transform(patient)

# Level 0: the positive-class probability from each base model becomes one
# column of the meta-model input.
# NOTE(review): the column order follows base_models' iteration order and must
# match the order used when the meta-model was trained. Presumably the pickled
# dict preserves that insertion order; verify against the training code.
preds = [m.predict_proba(X)[:, 1] for m in base_models.values()]
meta_X = np.column_stack(preds)

# Level 1: stacked probability for the single patient row.
final_prob = meta_model.predict_proba(meta_X)[:, 1][0]

# Apply the decision threshold stored with the model. It was previously loaded
# but ignored, so no class decision was reported.
# NOTE(review): assumes the threshold was tuned with `>=` semantics; confirm.
predicted_stroke = final_prob >= threshold

print(f"Stroke Probability: {final_prob}")
print(f"Predicted Stroke: {'Yes' if predicted_stroke else 'No'} (threshold = {threshold})")
|
|
|