"""Command-line stroke-risk predictor backed by a pickled stacked ensemble.

The pickle is expected to hold a dict with keys ``preprocessor``,
``base_models`` (a name -> model mapping), ``meta_model`` and, optionally,
``threshold_stacked``. The model file is looked up in a few local locations
first and, failing that, downloaded from the Hugging Face Hub.
"""
import argparse
import os
import pickle

import numpy as np
import pandas as pd

try:
    from huggingface_hub import hf_hub_download
except ImportError:  # optional: only needed when no local model copy exists
    hf_hub_download = None

# Constants
REPO_ID = "RealFishSam/DVAE26-proj"
FILENAME = "stacked_ensemble_model.pkl"

# Local locations checked, in order, before falling back to the Hugging Face Hub.
POSSIBLE_PATHS = (
    FILENAME,
    os.path.join("models", FILENAME),
    os.path.join("..", "models", FILENAME),
)

# Column names (and order) of the single-row frame handed to the preprocessor.
# They double as the argparse destination names below.
FEATURE_COLUMNS = (
    "gender",
    "age",
    "hypertension",
    "heart_disease",
    "ever_married",
    "work_type",
    "Residence_type",
    "avg_glucose_level",
    "bmi",
    "smoking_status",
)


def find_local_model(candidates=POSSIBLE_PATHS):
    """Return the first path in *candidates* that is an existing file, else None."""
    for path in candidates:
        if os.path.isfile(path):
            return path
    return None


def resolve_model_path():
    """Return a local path to the model file, downloading it if necessary.

    Raises:
        FileNotFoundError: no local copy exists and ``huggingface_hub`` is
            not installed, so the file cannot be fetched.
    """
    local_path = find_local_model()
    if local_path is not None:
        return local_path
    if hf_hub_download is None:
        raise FileNotFoundError(
            f"{FILENAME} not found in {list(POSSIBLE_PATHS)} and huggingface_hub "
            f"is not installed to download it from {REPO_ID}."
        )
    return hf_hub_download(repo_id=REPO_ID, filename=FILENAME)


def load_components(path):
    """Unpickle and return the model-components dict stored at *path*."""
    # SECURITY: pickle.load can execute arbitrary code. Only load model files
    # from a trusted source (this one may come from a remote Hub repository).
    with open(path, "rb") as f:
        return pickle.load(f)


def build_parser():
    """Return the argument parser describing one patient's features."""
    parser = argparse.ArgumentParser(description='Stroke Risk Predictor')
    parser.add_argument('--gender', type=str, default='Male', choices=['Male', 'Female'], help='Gender')
    parser.add_argument('--age', type=float, default=75, help='Age of the patient')
    parser.add_argument('--hypertension', type=int, default=1, choices=[0, 1], help='0: No, 1: Yes')
    parser.add_argument('--heart_disease', type=int, default=1, choices=[0, 1], help='0: No, 1: Yes')
    parser.add_argument('--ever_married', type=str, default='Yes', choices=['Yes', 'No'], help='Ever married?')
    parser.add_argument('--work_type', type=str, default='Private',
                        choices=['Private', 'Self-employed', 'Govt_job', 'children', 'Never_worked'],
                        help='Work type')
    parser.add_argument('--Residence_type', type=str, default='Urban', choices=['Urban', 'Rural'],
                        help='Residence type')
    parser.add_argument('--avg_glucose_level', type=float, default=220.5, help='Average glucose level')
    parser.add_argument('--bmi', type=float, default=30.1, help='Body Mass Index')
    parser.add_argument('--smoking_status', type=str, default='formerly smoked',
                        choices=['formerly smoked', 'never smoked', 'smokes', 'Unknown'],
                        help='Smoking status')
    return parser


def build_patient_frame(args):
    """Return a one-row DataFrame of the patient features held on *args*."""
    return pd.DataFrame([{column: getattr(args, column) for column in FEATURE_COLUMNS}])


def predict_stroke_probability(components, patient):
    """Return the stacked ensemble's positive-class probability for *patient*.

    Args:
        components: dict with ``preprocessor``, ``base_models`` and ``meta_model``.
        patient: one-row DataFrame as produced by ``build_patient_frame``.
    """
    # 1. Preprocess
    X = components["preprocessor"].transform(patient)
    # 2. Base model predictions (column 1 = positive class). The meta model
    #    presumably expects these columns in the same order as during
    #    training; this relies on the pickled dict preserving that order.
    base_preds = [model.predict_proba(X)[:, 1] for model in components["base_models"].values()]
    # 3. Meta prediction
    meta_X = np.column_stack(base_preds)
    return meta_model_positive(components["meta_model"], meta_X)


def meta_model_positive(meta_model, meta_X):
    """Return the meta model's positive-class probability for the first row of *meta_X*."""
    return meta_model.predict_proba(meta_X)[:, 1][0]


def main(argv=None):
    """Parse CLI arguments, run the ensemble, and print the result."""
    # Parse first so that --help or a bad argument never triggers a model load/download.
    args = build_parser().parse_args(argv)

    print("\nModel Input:")
    for arg, value in vars(args).items():
        print(f" {arg}: {value}")

    components = load_components(resolve_model_path())
    threshold = components.get('threshold_stacked', 0.5)

    final_prob = predict_stroke_probability(components, build_patient_frame(args))
    print(f"Stroke Probability: {final_prob}")
    # NOTE(review): assumes the stored threshold was tuned with a ">=" decision rule.
    label = "High risk" if final_prob >= threshold else "Low risk"
    print(f"Prediction (threshold {threshold}): {label}")


if __name__ == "__main__":
    main()