yonghan93 commited on
Commit
6519728
·
verified ·
1 Parent(s): c770b25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -5,23 +5,15 @@ import torch
5
  import json
6
  import pickle
7
  from pytorch_tabular import TabularModel
8
- from sklearn.preprocessing import LabelEncoder
9
 
10
  # Load model
11
  model = TabularModel.load_model("FTTransformerModel")
12
 
13
- # Wrap datamodule.label_encoder into a list if it's a single LabelEncoder
14
- dm = model.datamodule
15
- if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
16
- dm.label_encoder = [dm.label_encoder]
17
-
18
  # Load threshold
19
  with open("FTTransformerModel/threshold.json", "r") as f:
20
  threshold = json.load(f)["threshold"]
21
 
22
- # Load encoders
23
- with open("FTTransformerModel/encoder.pkl", "rb") as f:
24
- label_encoders = pickle.load(f)
25
 
26
  # User input
27
  st.title("Cardiovascular Disease Risk Prediction")
@@ -32,8 +24,8 @@ height = st.number_input("Height (cm)", min_value=100, max_value=250, value=170)
32
  weight = st.number_input("Weight (kg)", min_value=30, max_value=200, value=70)
33
  systolic = st.number_input("Systolic BP", min_value=80, max_value=250, value=120)
34
  diastolic = st.number_input("Diastolic BP", min_value=40, max_value=150, value=80)
35
- cholesterol = st.selectbox("Cholesterol", [1, 2, 3], format_func=lambda x: {1: "Normal", 2: "High", 3: "Very High"}[x])
36
- gluc = st.selectbox("Glucose", [1, 2, 3], format_func=lambda x: {1: "Normal", 2: "High", 3: "Very High"}[x])
37
  gender = st.selectbox("Gender", [0, 1], format_func=lambda x: {0: "Female", 1: "Male"}[x])
38
  smoke = st.selectbox("Do you smoke?", [0, 1], format_func=lambda x: "Yes" if x else "No")
39
  alco = st.selectbox("Do you drink alcohol?", [0, 1], format_func=lambda x: "Yes" if x else "No")
@@ -53,10 +45,14 @@ input_data = pd.DataFrame([{
53
  "active": active
54
  }])
55
 
56
- categorical_cols = ['gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active']
57
- for col in categorical_cols:
58
- le = label_encoders[col]
59
- input_data[col] = le.transform(input_data[col])
 
 
 
 
60
 
61
  if st.button("Predict CVD Risk"):
62
  preds = model.predict(input_data)
 
5
  import json
6
  import pickle
7
  from pytorch_tabular import TabularModel
8
+
9
 
10
  # Load model
11
  model = TabularModel.load_model("FTTransformerModel")
12
 
 
 
 
 
 
13
  # Load threshold
14
  with open("FTTransformerModel/threshold.json", "r") as f:
15
  threshold = json.load(f)["threshold"]
16
 
 
 
 
17
 
18
  # User input
19
  st.title("Cardiovascular Disease Risk Prediction")
 
24
  weight = st.number_input("Weight (kg)", min_value=30, max_value=200, value=70)
25
  systolic = st.number_input("Systolic BP", min_value=80, max_value=250, value=120)
26
  diastolic = st.number_input("Diastolic BP", min_value=40, max_value=150, value=80)
27
+ cholesterol = st.selectbox("Cholesterol", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
28
+ gluc = st.selectbox("Glucose", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
29
  gender = st.selectbox("Gender", [0, 1], format_func=lambda x: {0: "Female", 1: "Male"}[x])
30
  smoke = st.selectbox("Do you smoke?", [0, 1], format_func=lambda x: "Yes" if x else "No")
31
  alco = st.selectbox("Do you drink alcohol?", [0, 1], format_func=lambda x: "Yes" if x else "No")
 
45
  "active": active
46
  }])
47
 
48
+
49
+ input_data["bmi"] = input_data["weight"] / ((input_data["height"]/100)**2)
50
+
51
+ input_data["pulse_pressure"] = input_data["ap_hi"] - input_data["ap_lo"]
52
+
53
+ input_data["hypertension"] = (
54
+ (input_data["ap_hi"] > 140) | (input_data["ap_lo"] > 90)
55
+ ).astype(int)
56
 
57
  if st.button("Predict CVD Risk"):
58
  preds = model.predict(input_data)