# DVAE26-proj / predict.py
# Uploaded with huggingface_hub by RealFishSam (commit ab8cf1f, verified).
import pickle
import pandas as pd
import numpy as np
import os
from huggingface_hub import hf_hub_download
# Constants
REPO_ID = "RealFishSam/DVAE26-proj"
FILENAME = "stacked_ensemble_model.pkl"

# Look for a local copy of the model first: the working directory, then
# ./models, then ../models.
possible_paths = [
    FILENAME,
    os.path.join('models', FILENAME),
    os.path.join('..', 'models', FILENAME),
]
model_path = next((p for p in possible_paths if os.path.isfile(p)), None)
if model_path is None:
    # No local copy found: download it from the Hugging Face Hub (this is what
    # REPO_ID / hf_hub_download are for). Without this fallback the script
    # crashed with TypeError from open(None, 'rb').
    model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load model files from a source you trust.
with open(model_path, 'rb') as f:
    components = pickle.load(f)

# The pickle is a dict holding the fitted pipeline pieces used below.
preprocessor = components['preprocessor']
base_models = components['base_models']
meta_model = components['meta_model']
# Stored decision threshold for the stacked model; 0.5 if the key is absent.
threshold = components.get('threshold_stacked', 0.5)
import argparse

# Command-line interface: one flag per input feature. Every flag has a default
# (a sample patient), so the script can be run with no arguments at all.
parser = argparse.ArgumentParser(description='Stroke Risk Predictor')

# (flag, type, default, choices or None, help) -- registration order matters:
# it fixes the order in which the parsed values are echoed later.
_ARG_SPECS = (
    ('--gender', str, 'Male', ['Male', 'Female'], 'Gender'),
    ('--age', float, 75, None, 'Age of the patient'),
    ('--hypertension', int, 1, [0, 1], '0: No, 1: Yes'),
    ('--heart_disease', int, 1, [0, 1], '0: No, 1: Yes'),
    ('--ever_married', str, 'Yes', ['Yes', 'No'], 'Ever married?'),
    ('--work_type', str, 'Private',
     ['Private', 'Self-employed', 'Govt_job', 'children', 'Never_worked'],
     'Work type'),
    ('--Residence_type', str, 'Urban', ['Urban', 'Rural'], 'Residence type'),
    ('--avg_glucose_level', float, 220.5, None, 'Average glucose level'),
    ('--bmi', float, 30.1, None, 'Body Mass Index'),
    ('--smoking_status', str, 'formerly smoked',
     ['formerly smoked', 'never smoked', 'smokes', 'Unknown'],
     'Smoking status'),
)
for _flag, _type, _default, _choices, _help in _ARG_SPECS:
    # choices=None is argparse's own default, i.e. "no restriction".
    parser.add_argument(_flag, type=_type, default=_default,
                        choices=_choices, help=_help)

args = parser.parse_args()
# Echo every parsed CLI value, one indented "name: value" line each.
print("\nModel Input:")
print("\n".join(f" {key}: {val}" for key, val in vars(args).items()))
# Feature columns, in the order the single-row input frame is built.
# NOTE(review): presumably these must match the column names the fitted
# preprocessor was trained on -- keep in sync with the training data.
_FEATURE_COLUMNS = (
    'gender',
    'age',
    'hypertension',
    'heart_disease',
    'ever_married',
    'work_type',
    'Residence_type',
    'avg_glucose_level',
    'bmi',
    'smoking_status',
)
# One-row DataFrame describing the patient, taken straight from the CLI args.
patient = pd.DataFrame([{col: getattr(args, col) for col in _FEATURE_COLUMNS}])
# 1. Preprocess the one-row patient frame with the fitted preprocessor.
X = preprocessor.transform(patient)

# 2. Base-model predictions: column 1 of predict_proba from each base model.
#    Dict iteration follows insertion order, so the stacked columns follow the
#    order the models were stored in -- presumably the order the meta-model
#    was trained on (TODO confirm against the training script).
preds = [m.predict_proba(X)[:, 1] for m in base_models.values()]

# 3. Meta prediction: the meta-model combines the base-model probabilities.
meta_X = np.column_stack(preds)
final_prob = meta_model.predict_proba(meta_X)[:, 1][0]
print(f"Stroke Probability: {final_prob}")

# The stored decision threshold was loaded but never used; also report the
# resulting classification. Assumes the positive class is assigned when
# probability >= threshold -- TODO confirm against the training code.
prediction = "Stroke" if final_prob >= threshold else "No stroke"
print(f"Prediction (threshold {threshold}): {prediction}")