yonghan93 commited on
Commit
50334f0
·
verified ·
1 Parent(s): faa226c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -46
app.py CHANGED
@@ -1,5 +1,12 @@
 
 
1
  import torch.utils.data as _tud
2
  from pytorch_tabular.tabular_datamodule import TabularDatamodule
 
 
 
 
 
3
 
4
  _OriginalDataLoader = _tud.DataLoader
5
  class SafeDataLoader(_OriginalDataLoader):
@@ -43,49 +50,81 @@ with open("FTTransformerModel/threshold.json", "r") as f:
43
 
44
 
45
  # User input
46
- st.title("Cardiovascular Disease Risk Prediction")
47
- st.write("Please enter the following health information:")
48
-
49
- age = st.number_input("Age", min_value=18, max_value=100, value=50)
50
- height = st.number_input("Height (cm)", min_value=100, max_value=250, value=170)
51
- weight = st.number_input("Weight (kg)", min_value=30, max_value=200, value=70)
52
- systolic = st.number_input("Systolic BP", min_value=80, max_value=250, value=120)
53
- diastolic = st.number_input("Diastolic BP", min_value=40, max_value=150, value=80)
54
- cholesterol = st.selectbox("Cholesterol", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
55
- gluc = st.selectbox("Glucose", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
56
- gender = st.selectbox("Gender", [0, 1], format_func=lambda x: {0: "Female", 1: "Male"}[x])
57
- smoke = st.selectbox("Do you smoke?", [0, 1], format_func=lambda x: "Yes" if x else "No")
58
- alco = st.selectbox("Do you drink alcohol?", [0, 1], format_func=lambda x: "Yes" if x else "No")
59
- active = st.selectbox("Are you physically active?", [0, 1], format_func=lambda x: "Yes" if x else "No")
60
-
61
- input_data = pd.DataFrame([{
62
- "age": age * 365,
63
- "height": height,
64
- "weight": weight,
65
- "ap_hi": systolic,
66
- "ap_lo": diastolic,
67
- "cholesterol": cholesterol,
68
- "gluc": gluc,
69
- "gender": gender,
70
- "smoke": smoke,
71
- "alco": alco,
72
- "active": active
73
- }])
74
-
75
-
76
- input_data["bmi"] = input_data["weight"] / ((input_data["height"]/100)**2)
77
-
78
- input_data["pulse_pressure"] = input_data["ap_hi"] - input_data["ap_lo"]
79
-
80
- input_data["hypertension"] = (
81
- (input_data["ap_hi"] > 140) | (input_data["ap_lo"] > 90)
82
- ).astype(int)
83
-
84
- if st.button("Predict CVD Risk"):
85
- preds = model.predict(input_data)
86
- # 直接取 CVD(label=1)的概率
87
- proba = preds["cardio_1_probability"].iloc[0]
88
-
89
- result = "❌ At Risk of CVD" if proba >= threshold else " Not at Risk"
90
- st.subheader("Prediction Result")
91
- st.write(f"**{result}** (Probability: {proba:.2%}, Threshold: {threshold:.2f})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pip install streamlit-option-menu
2
+
3
  import torch.utils.data as _tud
4
  from pytorch_tabular.tabular_datamodule import TabularDatamodule
5
+ from streamlit_option_menu import option_menu
6
+ import google.generativeai as genai
7
+ import os
8
+
9
+ genai.configure(api_key=os.getenv("AIzaSyD7x66J1Bvs3CclxROPsC6QWnXcQiBGYBg"))
10
 
11
  _OriginalDataLoader = _tud.DataLoader
12
  class SafeDataLoader(_OriginalDataLoader):
 
50
 
51
 
52
  # User input
53
+ with st.sidebar:
54
+ selected = option_menu(
55
+ menu_title="Navigation",
56
+ options=["Prediction", "CVD Knowledge"],
57
+ icons=["activity", "book"],
58
+ menu_icon="cast",
59
+ default_index=0,
60
+ )
61
+
62
+ if selected == "Prediction":
63
+ col1, col2, col3 = st.columns([2, 1, 2])
64
+
65
+ with col1:
66
+ st.title("Cardiovascular Disease Risk Prediction")
67
+ st.write("Please enter the following health information:")
68
+
69
+ age = st.number_input("Age", min_value=18, max_value=100, value=50)
70
+ height = st.number_input("Height (cm)", min_value=100, max_value=250, value=170)
71
+ weight = st.number_input("Weight (kg)", min_value=30, max_value=200, value=70)
72
+ systolic = st.number_input("Systolic BP", min_value=80, max_value=250, value=120)
73
+ diastolic = st.number_input("Diastolic BP", min_value=40, max_value=150, value=80)
74
+ cholesterol = st.selectbox("Cholesterol", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
75
+ gluc = st.selectbox("Glucose", [0, 1, 2], format_func=lambda x: {0: "Normal", 1: "High", 2: "Very High"}[x])
76
+ gender = st.selectbox("Gender", [0, 1], format_func=lambda x: {0: "Female", 1: "Male"}[x])
77
+ smoke = st.selectbox("Do you smoke?", [0, 1], format_func=lambda x: "Yes" if x else "No")
78
+ alco = st.selectbox("Do you drink alcohol?", [0, 1], format_func=lambda x: "Yes" if x else "No")
79
+ active = st.selectbox("Are you physically active?", [0, 1], format_func=lambda x: "Yes" if x else "No")
80
+
81
+ input_data = pd.DataFrame([{
82
+ "age": age * 365,
83
+ "height": height,
84
+ "weight": weight,
85
+ "ap_hi": systolic,
86
+ "ap_lo": diastolic,
87
+ "cholesterol": cholesterol,
88
+ "gluc": gluc,
89
+ "gender": gender,
90
+ "smoke": smoke,
91
+ "alco": alco,
92
+ "active": active
93
+ }])
94
+
95
+ input_data["bmi"] = input_data["weight"] / ((input_data["height"]/100)**2)
96
+ input_data["pulse_pressure"] = input_data["ap_hi"] - input_data["ap_lo"]
97
+ input_data["hypertension"] = ((input_data["ap_hi"] > 140) | (input_data["ap_lo"] > 90)).astype(int)
98
+
99
+ if st.button("Predict CVD Risk"):
100
+ preds = model.predict(input_data)
101
+ proba = preds["cardio_1_probability"].iloc[0]
102
+ result = "❌ At Risk of CVD" if proba >= threshold else "✅ Not at Risk"
103
+
104
+ with col2:
105
+ st.markdown("### 🎯 Prediction Result")
106
+ st.markdown(f"**<div style='text-align: center; font-size: 24px'>{result}</div>**", unsafe_allow_html=True)
107
+ st.write(f"(Probability: {proba:.2%}, Threshold: {threshold:.2f})")
108
+
109
+ st.markdown("---")
110
+ st.markdown("### 🤖 Gemini Suggestion")
111
+ instruction = st.text_area("Custom Gemini Instruction", value="Give detailed lifestyle advice for this risk level")
112
+
113
+ if st.button("Get Gemini Advice"):
114
+ prompt = f"{instruction}. The predicted CVD risk is: {result}. Probability: {proba:.2%}."
115
+ model_gemini = genai.GenerativeModel("gemini-pro")
116
+ response = model_gemini.generate_content(prompt)
117
+ st.markdown("### Gemini's Advice:")
118
+ st.write(response.text)
119
+
120
+ elif selected == "CVD Knowledge":
121
+ st.title("Understanding Cardiovascular Disease")
122
+ st.markdown("""
123
+ - **What is CVD?** A group of heart and blood vessel disorders like heart attack and stroke.
124
+ - **Major Risk Factors:** Age, high blood pressure, high cholesterol, smoking, inactivity, obesity.
125
+ - **Prevention Tips:**
126
+ - Exercise regularly (30 mins/day)
127
+ - Eat low-sodium and low-fat diets
128
+ - Avoid tobacco and limit alcohol
129
+ - Monitor blood pressure and sugar
130
+ """)