RealFishSam commited on
Commit
1cd24c3
·
verified ·
1 Parent(s): 4a1a806

Upload predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. predict.py +58 -0
predict.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import pandas as pd
3
+ import numpy as np
4
+ import os
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ # Constants
8
+ REPO_ID = "RealFishSam/DVAE26-proj"
9
+ FILENAME = "stacked_ensemble_model.pkl"
10
+
11
+ # Check locations
12
+ possible_paths = [
13
+ FILENAME,
14
+ os.path.join('models', FILENAME),
15
+ os.path.join('..', 'models', FILENAME)
16
+ ]
17
+
18
+ model_path = None
19
+ for p in possible_paths:
20
+ if os.path.exists(p):
21
+ model_path = p
22
+ break
23
+
24
+ with open(model_path, 'rb') as f:
25
+ components = pickle.load(f)
26
+
27
+ preprocessor = components['preprocessor']
28
+ base_models = components['base_models']
29
+ meta_model = components['meta_model']
30
+ threshold = components.get('threshold_stacked', 0.5)
31
+
32
+ patient = pd.DataFrame([{
33
+ 'gender': 'Male', # one of ['Male', 'Female'] # Other was dropped
34
+ 'age': 75,
35
+ 'hypertension': 1,
36
+ 'heart_disease': 1,
37
+ 'ever_married': 'Yes', # one of ['Yes', 'No']
38
+ 'work_type': 'Private', # one of ['Private', 'Self-employed', 'Govt_job', 'Children', 'Never_worked']
39
+ 'Residence_type': 'Urban', # one of ['Urban', 'Rural']
40
+ 'avg_glucose_level': 220.5,
41
+ 'bmi': 30.1,
42
+ 'smoking_status': 'formerly smoked' # one of ['formerly smoked', 'never smoked', 'smokes']
43
+ }])
44
+
45
+ # 1. Preprocess
46
+ X = preprocessor.transform(patient)
47
+
48
+ # 2. Base model predictions
49
+ preds = []
50
+ for name, m in base_models.items():
51
+ p = m.predict_proba(X)[:, 1]
52
+ preds.append(p)
53
+
54
+ # 3. Meta prediction
55
+ meta_X = np.column_stack(preds)
56
+ final_prob = meta_model.predict_proba(meta_X)[:, 1][0]
57
+
58
+ print(f"Stroke Probability: {final_prob}")