Jotheeswaransakthivel commited on
Commit
254fd25
·
verified ·
1 Parent(s): b9593c3

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +56 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,59 @@
1
- import altair as alt
 
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import joblib
3
  import numpy as np
4
  import pandas as pd
 
5
 
6
+ # ----------------------------
7
+ # Load Saved Model & Encoders
8
+ # ----------------------------
9
+ model = joblib.load("best_stroke_model.pkl")
10
+ x_num_scaler = joblib.load("x_num_scaler.pkl")
11
+ x_cat_encoder = joblib.load("x_cat_encoder.pkl")
12
+ x_cat_ordinal_encoder = joblib.load("x_cat_ordinal_encoder.pkl")
13
+
14
+ # ----------------------------
15
+ # Streamlit App Layout
16
+ # ----------------------------
17
+ st.set_page_config(page_title="Stroke Prediction App", page_icon="🧠", layout="centered")
18
+ st.title("🧠 Stroke Prediction App")
19
+ st.write("Fill in the details below to check the risk of stroke.")
20
+
21
+ # ----------------------------
22
+ # Input Fields
23
+ # ----------------------------
24
+ gender = st.selectbox("Gender", ["Male", "Female", "Other"])
25
+ age = st.number_input("Age", min_value=1, max_value=120, value=30)
26
+ hypertension = st.selectbox("Hypertension (0=No, 1=Yes)", [0, 1])
27
+ heart_disease = st.selectbox("Heart Disease (0=No, 1=Yes)", [0, 1])
28
+ work_type = st.selectbox("Work Type", ["Private", "Self-employed", "Govt_job", "children", "Never_worked"])
29
+ avg_glucose_level = st.number_input("Average Glucose Level", min_value=50.0, max_value=300.0, value=100.0)
30
+ bmi = st.number_input("BMI", min_value=10.0, max_value=50.0, value=22.0)
31
+ smoking_status = st.selectbox("Smoking Status", ["never smoked", "formerly smoked", "smokes", "Unknown"])
32
+
33
+ # ----------------------------
34
+ # Prediction Button
35
+ # ----------------------------
36
+ if st.button("Predict"):
37
+ # Convert to DataFrame for transformation
38
+ input_df = pd.DataFrame([[gender, age, hypertension, heart_disease, work_type, avg_glucose_level, bmi, smoking_status]],
39
+ columns=["gender", "age", "hypertension", "heart_disease", "work_type", "avg_glucose_level", "bmi", "smoking_status"])
40
+
41
+ # Transform numeric columns
42
+ scaled_x_num = x_num_scaler.transform(input_df[["age", "avg_glucose_level", "bmi"]])
43
+
44
+ # Binary values (no transformation needed)
45
+ binary_x = input_df[["hypertension", "heart_disease"]].values
46
+
47
+ # Encode categorical columns
48
+ encoded_x_cat = x_cat_encoder.transform(input_df[["work_type", "smoking_status"]])
49
+ ordinal_encoded_x_cat = x_cat_ordinal_encoder.transform(input_df[["gender"]])
50
+
51
+ # Combine all into final input
52
+ final_x = np.hstack([scaled_x_num, binary_x, encoded_x_cat, ordinal_encoded_x_cat])
53
+
54
+ # Prediction
55
+ prediction = model.predict(final_x)
56
+ result = "⚠️ High Risk of Stroke" if prediction[0] == 1 else "✅ Low Risk of Stroke"
57
+
58
+ st.success(f"**Prediction:** {result}")
59
+ st.info("This prediction is based on machine learning and should not replace professional medical advice.")