Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,23 +5,15 @@ import torch
|
|
| 5 |
import json
|
| 6 |
import pickle
|
| 7 |
from pytorch_tabular import TabularModel
|
| 8 |
-
|
| 9 |
|
| 10 |
# Load model
|
| 11 |
model = TabularModel.load_model("FTTransformerModel")
|
| 12 |
|
| 13 |
-
# Wrap datamodule.label_encoder into a list if it's a single LabelEncoder
|
| 14 |
-
dm = model.datamodule
|
| 15 |
-
if hasattr(dm, "label_encoder") and isinstance(dm.label_encoder, LabelEncoder):
|
| 16 |
-
dm.label_encoder = [dm.label_encoder]
|
| 17 |
-
|
| 18 |
# Load threshold
|
| 19 |
with open("FTTransformerModel/threshold.json", "r") as f:
|
| 20 |
threshold = json.load(f)["threshold"]
|
| 21 |
|
| 22 |
-
# Load encoders
|
| 23 |
-
with open("FTTransformerModel/encoder.pkl", "rb") as f:
|
| 24 |
-
label_encoders = pickle.load(f)
|
| 25 |
|
| 26 |
# User input
|
| 27 |
st.title("Cardiovascular Disease Risk Prediction")
|
|
@@ -32,8 +24,8 @@ height = st.number_input("Height (cm)", min_value=100, max_value=250, value=170)
|
|
| 32 |
weight = st.number_input("Weight (kg)", min_value=30, max_value=200, value=70)
|
| 33 |
systolic = st.number_input("Systolic BP", min_value=80, max_value=250, value=120)
|
| 34 |
diastolic = st.number_input("Diastolic BP", min_value=40, max_value=150, value=80)
|
| 35 |
-
cholesterol = st.selectbox("Cholesterol", [
|
| 36 |
-
gluc = st.selectbox("Glucose", [
|
| 37 |
gender = st.selectbox("Gender", [0, 1], format_func=lambda x: {0: "Female", 1: "Male"}[x])
|
| 38 |
smoke = st.selectbox("Do you smoke?", [0, 1], format_func=lambda x: "Yes" if x else "No")
|
| 39 |
alco = st.selectbox("Do you drink alcohol?", [0, 1], format_func=lambda x: "Yes" if x else "No")
|
|
@@ -53,10 +45,14 @@ input_data = pd.DataFrame([{
|
|
| 53 |
"active": active
|
| 54 |
}])
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
if st.button("Predict CVD Risk"):
|
| 62 |
preds = model.predict(input_data)
|
|
|
|
| 5 |
import json
|
| 6 |
import pickle
|
| 7 |
from pytorch_tabular import TabularModel
|
| 8 |
+
|
| 9 |
|
| 10 |
# Load model
|
| 11 |
model = TabularModel.load_model("FTTransformerModel")
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
# Load threshold
|
| 14 |
with open("FTTransformerModel/threshold.json", "r") as f:
|
| 15 |
threshold = json.load(f)["threshold"]
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# User input
|
| 19 |
st.title("Cardiovascular Disease Risk Prediction")
|
|
|
|
| 24 |
weight = st.number_input("Weight (kg)", min_value=30, max_value=200, value=70)
|
| 25 |
systolic = st.number_input("Systolic BP", min_value=80, max_value=250, value=120)
|
| 26 |
diastolic = st.number_input("Diastolic BP", min_value=40, max_value=150, value=80)
|
| 27 |
+
cholesterol = st.selectbox("Cholesterol", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
|
| 28 |
+
gluc = st.selectbox("Glucose", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
|
| 29 |
gender = st.selectbox("Gender", [0, 1], format_func=lambda x: {0: "Female", 1: "Male"}[x])
|
| 30 |
smoke = st.selectbox("Do you smoke?", [0, 1], format_func=lambda x: "Yes" if x else "No")
|
| 31 |
alco = st.selectbox("Do you drink alcohol?", [0, 1], format_func=lambda x: "Yes" if x else "No")
|
|
|
|
| 45 |
"active": active
|
| 46 |
}])
|
| 47 |
|
| 48 |
+
|
| 49 |
+
input_data["bmi"] = input_data["weight"] / ((input_data["height"]/100)**2)
|
| 50 |
+
|
| 51 |
+
input_data["pulse_pressure"] = input_data["ap_hi"] - input_data["ap_lo"]
|
| 52 |
+
|
| 53 |
+
input_data["hypertension"] = (
|
| 54 |
+
(input_data["ap_hi"] > 140) | (input_data["ap_lo"] > 90)
|
| 55 |
+
).astype(int)
|
| 56 |
|
| 57 |
if st.button("Predict CVD Risk"):
|
| 58 |
preds = model.predict(input_data)
|