Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,12 @@
|
|
|
|
|
|
|
|
| 1 |
import torch.utils.data as _tud
|
| 2 |
from pytorch_tabular.tabular_datamodule import TabularDatamodule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
_OriginalDataLoader = _tud.DataLoader
|
| 5 |
class SafeDataLoader(_OriginalDataLoader):
|
|
@@ -43,49 +50,81 @@ with open("FTTransformerModel/threshold.json", "r") as f:
|
|
| 43 |
|
| 44 |
|
| 45 |
# User input
|
| 46 |
-
st.
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pip install streamlit-option-menu
|
| 2 |
+
|
| 3 |
import torch.utils.data as _tud
|
| 4 |
from pytorch_tabular.tabular_datamodule import TabularDatamodule
|
| 5 |
+
from streamlit_option_menu import option_menu
|
| 6 |
+
import google.generativeai as genai
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
genai.configure(api_key=os.getenv("AIzaSyD7x66J1Bvs3CclxROPsC6QWnXcQiBGYBg"))
|
| 10 |
|
| 11 |
_OriginalDataLoader = _tud.DataLoader
|
| 12 |
class SafeDataLoader(_OriginalDataLoader):
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
# User input
|
| 53 |
+
with st.sidebar:
|
| 54 |
+
selected = option_menu(
|
| 55 |
+
menu_title="Navigation",
|
| 56 |
+
options=["Prediction", "CVD Knowledge"],
|
| 57 |
+
icons=["activity", "book"],
|
| 58 |
+
menu_icon="cast",
|
| 59 |
+
default_index=0,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
if selected == "Prediction":
|
| 63 |
+
col1, col2, col3 = st.columns([2, 1, 2])
|
| 64 |
+
|
| 65 |
+
with col1:
|
| 66 |
+
st.title("Cardiovascular Disease Risk Prediction")
|
| 67 |
+
st.write("Please enter the following health information:")
|
| 68 |
+
|
| 69 |
+
age = st.number_input("Age", min_value=18, max_value=100, value=50)
|
| 70 |
+
height = st.number_input("Height (cm)", min_value=100, max_value=250, value=170)
|
| 71 |
+
weight = st.number_input("Weight (kg)", min_value=30, max_value=200, value=70)
|
| 72 |
+
systolic = st.number_input("Systolic BP", min_value=80, max_value=250, value=120)
|
| 73 |
+
diastolic = st.number_input("Diastolic BP", min_value=40, max_value=150, value=80)
|
| 74 |
+
cholesterol = st.selectbox("Cholesterol", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
|
| 75 |
+
gluc = st.selectbox("Glucose", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
|
| 76 |
+
gender = st.selectbox("Gender", [0, 1], format_func=lambda x: {0: "Female", 1: "Male"}[x])
|
| 77 |
+
smoke = st.selectbox("Do you smoke?", [0, 1], format_func=lambda x: "Yes" if x else "No")
|
| 78 |
+
alco = st.selectbox("Do you drink alcohol?", [0, 1], format_func=lambda x: "Yes" if x else "No")
|
| 79 |
+
active = st.selectbox("Are you physically active?", [0, 1], format_func=lambda x: "Yes" if x else "No")
|
| 80 |
+
|
| 81 |
+
input_data = pd.DataFrame([{
|
| 82 |
+
"age": age * 365,
|
| 83 |
+
"height": height,
|
| 84 |
+
"weight": weight,
|
| 85 |
+
"ap_hi": systolic,
|
| 86 |
+
"ap_lo": diastolic,
|
| 87 |
+
"cholesterol": cholesterol,
|
| 88 |
+
"gluc": gluc,
|
| 89 |
+
"gender": gender,
|
| 90 |
+
"smoke": smoke,
|
| 91 |
+
"alco": alco,
|
| 92 |
+
"active": active
|
| 93 |
+
}])
|
| 94 |
+
|
| 95 |
+
input_data["bmi"] = input_data["weight"] / ((input_data["height"]/100)**2)
|
| 96 |
+
input_data["pulse_pressure"] = input_data["ap_hi"] - input_data["ap_lo"]
|
| 97 |
+
input_data["hypertension"] = ((input_data["ap_hi"] > 140) | (input_data["ap_lo"] > 90)).astype(int)
|
| 98 |
+
|
| 99 |
+
if st.button("Predict CVD Risk"):
|
| 100 |
+
preds = model.predict(input_data)
|
| 101 |
+
proba = preds["cardio_1_probability"].iloc[0]
|
| 102 |
+
result = "❌ At Risk of CVD" if proba >= threshold else "✅ Not at Risk"
|
| 103 |
+
|
| 104 |
+
with col2:
|
| 105 |
+
st.markdown("### 🎯 Prediction Result")
|
| 106 |
+
st.markdown(f"**<div style='text-align: center; font-size: 24px'>{result}</div>**", unsafe_allow_html=True)
|
| 107 |
+
st.write(f"(Probability: {proba:.2%}, Threshold: {threshold:.2f})")
|
| 108 |
+
|
| 109 |
+
st.markdown("---")
|
| 110 |
+
st.markdown("### 🤖 Gemini Suggestion")
|
| 111 |
+
instruction = st.text_area("Custom Gemini Instruction", value="Give detailed lifestyle advice for this risk level")
|
| 112 |
+
|
| 113 |
+
if st.button("Get Gemini Advice"):
|
| 114 |
+
prompt = f"{instruction}. The predicted CVD risk is: {result}. Probability: {proba:.2%}."
|
| 115 |
+
model_gemini = genai.GenerativeModel("gemini-pro")
|
| 116 |
+
response = model_gemini.generate_content(prompt)
|
| 117 |
+
st.markdown("### Gemini's Advice:")
|
| 118 |
+
st.write(response.text)
|
| 119 |
+
|
| 120 |
+
elif selected == "CVD Knowledge":
|
| 121 |
+
st.title("Understanding Cardiovascular Disease")
|
| 122 |
+
st.markdown("""
|
| 123 |
+
- **What is CVD?** A group of heart and blood vessel disorders like heart attack and stroke.
|
| 124 |
+
- **Major Risk Factors:** Age, high blood pressure, high cholesterol, smoking, inactivity, obesity.
|
| 125 |
+
- **Prevention Tips:**
|
| 126 |
+
- Exercise regularly (30 mins/day)
|
| 127 |
+
- Eat low-sodium and low-fat diets
|
| 128 |
+
- Avoid tobacco and limit alcohol
|
| 129 |
+
- Monitor blood pressure and sugar
|
| 130 |
+
""")
|