DeepActionPotential committed on
Commit
da0d126
·
verified ·
1 Parent(s): fb88264

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo/strokeline_demo.mp4 filter=lfs diff=lfs merge=lfs -text
LICENCE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Eslam Tarek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,19 +1,142 @@
1
- ---
2
- title: StrokeLineAI
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Predicting a stroke based on medical features
12
- ---
13
-
14
- # Welcome to Streamlit!
15
-
16
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
-
18
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
- forums](https://discuss.streamlit.io).
1
+ # Stroke Prediction Using Machine Learning
2
+
3
+ ## About the Project
4
+
5
+ This project provides a comprehensive machine learning pipeline for predicting the risk of stroke in individuals based on clinical and demographic features. The goal is to enable early identification of high-risk patients, supporting healthcare professionals in making informed decisions and potentially reducing stroke-related morbidity and mortality. The project covers the full data science workflow: data exploration, preprocessing, feature engineering, model selection, hyperparameter optimization, evaluation, explainability, and deployment. The final solution includes a trained model and a Streamlit web application for real-time inference.
6
+
7
+ ---
8
+
9
+ ## About the Dataset
10
+
11
+ The dataset used is the [Stroke Prediction Dataset](https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset) from Kaggle. It contains 5110 records and 12 columns: an `id`, 10 input features, and a binary target variable (`stroke`). The columns are:
12
+
13
+ - **id**: Unique identifier (not used for modeling)
14
+ - **gender**: Patient gender (`Male`, `Female`, `Other`)
15
+ - **age**: Age in years
16
+ - **hypertension**: Hypertension status (`0`: No, `1`: Yes)
17
+ - **heart_disease**: Heart disease status (`0`: No, `1`: Yes)
18
+ - **ever_married**: Marital status (`Yes`, `No`)
19
+ - **work_type**: Type of work (`children`, `Govt_job`, `Never_worked`, `Private`, `Self-employed`)
20
+ - **Residence_type**: Living area (`Urban`, `Rural`)
21
+ - **avg_glucose_level**: Average glucose level
22
+ - **bmi**: Body mass index (may contain missing values)
23
+ - **smoking_status**: Smoking behavior (`formerly smoked`, `never smoked`, `smokes`, `Unknown`)
24
+ - **stroke**: Target variable (`1`: Stroke occurred, `0`: No stroke)
25
+
26
+ The dataset is imbalanced, with far fewer positive stroke cases than negatives, and contains missing values in the `bmi` column.
27
+
28
+ ---
29
+
30
+ ## Notebook Summary
31
+
32
+ The notebook documents the entire process:
33
+
34
+ 1. **Problem Definition**: Outlines the clinical motivation, dataset, and challenges.
35
+ 2. **EDA**: Visualizes distributions, checks for missing values, and explores feature-target relationships.
36
+ 3. **Feature Engineering**: Handles missing data, encodes categorical variables, and examines feature correlations.
37
+ 4. **Data Balancing**: Uses RandomUnderSampler and SMOTE to address class imbalance.
38
+ 5. **Model Selection**: Compares Random Forest, SVM, and XGBoost classifiers.
39
+ 6. **Hyperparameter Tuning**: Uses Optuna for automated optimization of XGBoost.
40
+ 7. **Evaluation**: Reports F1 score, confusion matrix, and classification report.
41
+ 8. **Explainability**: Applies SHAP for model interpretation.
42
+ 9. **Model Export**: Saves the trained model for deployment.
43
+
44
+ ---
45
+
46
+ ## Model Results
47
+
48
+ ### Preprocessing
49
+
50
+ - **Missing Values**: Imputed missing `bmi` values with the mean.
51
+ - **Categorical Encoding**: Used `OrdinalEncoder` to convert categorical features to numeric.
52
+ - **Feature Selection**: Dropped the `id` column and checked for highly correlated features.
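One point worth checking against the notebook: assuming `OrdinalEncoder` was fitted with its default `categories="auto"`, each column's codes follow the sorted order of its unique values (uppercase letters sort before lowercase), so the numeric codes sent at inference time need to follow the same order. The sketch below reproduces that ordering with plain `sorted()`. The helper name `ordinal_codes` is hypothetical, and the category lists are the ones given in the dataset description.

```python
# Sketch: reproduce the code assignment of a default-fitted OrdinalEncoder.
# Assumption: categories="auto", which orders each column's unique values by sorting.

def ordinal_codes(categories):
    """Map each category to its index in sorted order (hypothetical helper)."""
    return {cat: float(i) for i, cat in enumerate(sorted(categories))}

work_type = ["children", "Govt_job", "Never_worked", "Private", "Self-employed"]
smoking_status = ["formerly smoked", "never smoked", "smokes", "Unknown"]

work_codes = ordinal_codes(work_type)
smoking_codes = ordinal_codes(smoking_status)

for name, codes in [("work_type", work_codes), ("smoking_status", smoking_codes)]:
    print(name, codes)
```

If the notebook passed an explicit `categories` list to the encoder instead, the codes follow that list and this ordering does not apply.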
53
+
54
+ ### Data Balancing
55
+
56
+ - **RandomUnderSampler**: Reduced the majority class to 10% of its original size.
57
+ - **SMOTE**: Oversampled the minority class to achieve a 1:1 ratio.
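The core step of SMOTE is interpolation: each synthetic minority sample lies on the line segment between a real minority sample and one of its nearest minority-class neighbours. The following is a minimal pure-Python sketch of that single step, not the imbalanced-learn implementation. The function `smote_point` and the toy feature values are hypothetical.

```python
import random

def smote_point(sample, neighbour, rng):
    """Return a synthetic point on the segment between sample and neighbour."""
    gap = rng.random()  # uniform in [0, 1)
    return [s + gap * (n - s) for s, n in zip(sample, neighbour)]

rng = random.Random(0)

# Toy minority-class samples as (age, avg_glucose_level); values are made up.
a = [67.0, 228.7]
b = [80.0, 105.9]

synthetic = smote_point(a, b, rng)
print(synthetic)
```

A full SMOTE pass repeats this for randomly chosen minority samples and neighbours until the requested class ratio is reached.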
58
+
59
+ ### Training
60
+
61
+ - **Train-Test Split**: Stratified split to preserve class distribution.
62
+ - **Model Comparison**: Evaluated Random Forest, SVM, and XGBoost on balanced data.
63
+ - **Best Model**: XGBoost achieved the highest F1 score.
64
+
65
+ ### Hyperparameter Tuning
66
+
67
+ - **Optuna**: Ran 50 trials to optimize XGBoost hyperparameters (including `n_estimators`, `max_depth`, `learning_rate`, and `gamma`) using 5-fold cross-validation with F1 score as the metric.
68
+
69
+ ### Evaluation
70
+
71
+ - **F1 Score**: Achieved ~90% F1 score on the balanced test set.
72
+ - **Confusion Matrix**: Demonstrated balanced sensitivity and specificity.
73
+ - **Classification Report**: Provided detailed precision, recall, and F1 for each class.
74
+ - **Explainability**: SHAP analysis identified the most influential features and provided local/global interpretability.
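For reference, F1 is the harmonic mean of precision and recall, both of which can be read off the confusion matrix. The counts below are hypothetical and chosen only to illustrate the arithmetic; they are not the notebook's results.

```python
def f1_from_counts(tp, fp, fn):
    """Compute F1 from true positives, false positives and false negatives."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Hypothetical confusion-matrix counts for the positive (stroke) class.
tp, fp, fn = 45, 5, 5
f1 = f1_from_counts(tp, fp, fn)
print(f"F1 = {f1:.3f}")
```

With these counts precision and recall are both 0.9, so F1 is also 0.9. The sketch assumes at least one true positive; a production version would guard against division by zero.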
75
+
76
+ ---
77
+
78
+ ## How to Install
79
+
80
+ Follow these steps to set up the project using a virtual environment:
81
+
82
+ ```bash
83
+ # Clone or download the repository
84
+ git clone https://github.com/DeepActionPotential/StrokeLineAI
85
+ cd StrokeLineAI
86
+
87
+ # Create a virtual environment
88
+ python -m venv venv
89
+
90
+ # Activate the virtual environment
91
+ # On Windows:
92
+ venv\Scripts\activate
93
+ # On macOS/Linux:
94
+ source venv/bin/activate
95
+
96
+ # Upgrade pip
97
+ pip install --upgrade pip
98
+
99
+ # Install dependencies
100
+ pip install -r requirements.txt
101
+ ```
102
+
103
+ ---
104
+
105
+ ## How to Use the Software
106
+
107
+ 1. **Run the Web Application**
108
+ Start the Streamlit app:
109
+
110
+ ```bash
111
+ streamlit run app.py
112
+ ```
113
+
114
+ 2. **Demo**
115
+ [demo-video](demo/strokeline_demo.mp4)
116
+ ![demo-screenshot](demo/strokeline_demo.jpeg)
117
+
118
+ ---
119
+
120
+ ## Technologies Used
121
+
122
+ ### Data Science & Model Training
123
+
124
+
125
+ - **matplotlib, seaborn**: Data visualization.
126
+ - **scikit-learn**: Preprocessing, model selection, metrics, and pipelines.
127
+ - **imbalanced-learn**: Advanced resampling (SMOTE, RandomUnderSampler) for class balancing.
128
+ - **XGBoost**: High-performance gradient boosting for classification.
129
+ - **Optuna**: Automated hyperparameter optimization.
130
+ - **SHAP**: Model explainability and feature importance analysis.
131
+
132
+ ### Deployment
133
+
134
+ - **Streamlit**: Rapid web app development for interactive model inference.
135
+ - **joblib**: Model serialization for deployment.
136
+
137
+ ---
138
+
139
+ ## License
140
+
141
+ This project is licensed under the MIT License.
142
+ See the [LICENSE](LICENCE) file for details.
app.py ADDED
@@ -0,0 +1,34 @@
+ import streamlit as st
+ import joblib
+ from utils import preprocess_input, predict_stroke
+ from ui import input_form, display_result
+
+ @st.cache_resource
+ def load_model(path: str = "./models/model.pkl"):
+     """Load the trained classifier from disk."""
+     return joblib.load(path)
+
+
+ def local_css(file_name):
+     with open(file_name) as f:
+         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
+
+ local_css("styles.css")
+
+ def main():
+     st.title("Stroke Prediction Demo")
+     st.write("Enter patient metrics to predict stroke risk.")
+
+     # Get numeric-encoded inputs from the sidebar
+     data = input_form()
+
+     # Preprocess and predict
+     model = load_model()
+     X = preprocess_input(data)
+     label, proba = predict_stroke(model, X)
+
+     # Show result
+     display_result(label, proba)
+
+ if __name__ == "__main__":
+     main()
data/healthcare-dataset-stroke-data.csv ADDED
The diff for this file is too large to render. See raw diff
 
demo/strokeline_demo.jpeg ADDED
demo/strokeline_demo.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b35c7ccc4990cd87a60cf139ba0c628d36e91bc54dfefc7523e6a1f5b4ebafe3
3
+ size 2894759
models/model.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca2bc023cf8a92424c7cb37655ec4bcbd60e69dc7f6d74fb1c0937eb14a597cd
3
+ size 676565
requirements.txt CHANGED
@@ -1,3 +1,6 @@
- altair
- pandas
- streamlit
+ streamlit>=1.20.0
+ scikit-learn>=1.2.0
+ pandas>=1.5.0
+ numpy>=1.22.0
+ xgboost>=2.1.0
+ joblib>=1.2.0
run.py ADDED
@@ -0,0 +1,3 @@
1
+ import subprocess
2
+
3
+ subprocess.run(['streamlit', 'run', 'app.py'])
stroke-prediction-using-smote-90-f1-score.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
styles.css ADDED
@@ -0,0 +1,146 @@
1
+ /* Hide Streamlit default UI elements */
2
+ #MainMenu, header, footer {
3
+ visibility: hidden;
4
+ }
5
+
6
+ /* Full-screen center layout */
7
+ .stApp {
8
+ display: flex;
9
+ justify-content: center;
10
+ align-items: center;
11
+ min-height: 100vh;
12
+ margin: 10px;
13
+ padding: 10px;
14
+ }
15
+
16
+ /* Global dark theme base */
17
+ body {
18
+ background-color: #343541; /* ChatGPT dark gray */
19
+ color: #ececf1; /* Light neutral for text */
20
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
21
+ margin: 10px;
22
+ padding: 10px;
23
+ }
24
+
25
+ /* Container centering */
26
+ .centered-container {
27
+ display: flex;
28
+ align-items: center;
29
+ justify-content: center;
30
+ height: 100vh;
31
+ width: 100vw;
32
+ }
33
+
34
+ /* ChatGPT-style button */
35
+ .stButton > button {
36
+ background-color: #444654 !important;
37
+ color: #ececf1 !important;
38
+ border: 1px solid #5c5f72 !important;
39
+ border-radius: 999px !important;
40
+ padding: 0.5rem 1.25rem !important;
41
+ font-weight: 500;
42
+ transition: background-color 0.2s ease, transform 0.1s ease;
43
+ position: relative;
44
+ }
45
+
46
+ .stButton > button:hover {
47
+ background-color: #565869 !important;
48
+ transform: scale(1.03);
49
+ }
50
+
51
+ /* Sidebar styling */
52
+ [data-testid="stSidebar"] {
53
+ background-color: #202123;
54
+ color: #ececf1;
55
+ border-right: 1px solid #2d2f36;
56
+ min-width: 140px;
57
+ max-width: 250px;
58
+ transition: all 0.3s ease;
59
+ }
60
+
61
+ [data-testid="stSidebar"][aria-expanded="false"] {
62
+ margin-left: -250px;
63
+ }
64
+
65
+ [data-testid="stSidebar"] h1,
66
+ [data-testid="stSidebar"] h2,
67
+ [data-testid="stSidebar"] h3 {
68
+ color: #ececf1;
69
+ }
70
+
71
+ /* Markdown and text elements */
72
+ .stMarkdown, .stCaption, .stHeader {
73
+ color: #ececf1;
74
+ }
75
+
76
+ /* Dropdown styling */
77
+ select {
78
+ background-color: #3e3f4b;
79
+ color: #ececf1;
80
+ border: 1px solid #5c5f72;
81
+ border-radius: 6px;
82
+ padding: 6px 10px;
83
+ }
84
+
85
+ /* Selectbox refinements */
86
+ .stSelectbox {
87
+ cursor: pointer !important;
88
+ }
89
+ .stSelectbox input {
90
+ cursor: pointer !important;
91
+ caret-color: transparent !important;
92
+ }
93
+ .stSelectbox div[data-baseweb="select"] {
94
+ cursor: pointer !important;
95
+ }
96
+ .stSelectbox [role="option"] {
97
+ cursor: pointer !important;
98
+ }
99
+ .stSelectbox ::selection {
100
+ background: transparent !important;
101
+ }
102
+
103
+ /* General container */
104
+ .block-container {
105
+ padding: 15px !important;
106
+ margin: 15px !important;
107
+ max-width: 100% !important;
108
+ }
109
+
110
+ /* Progress bar */
111
+ .stProgress > div > div > div {
112
+ background-color: #10a37f !important; /* ChatGPT green */
113
+ }
114
+ .stProgress > div > div {
115
+ background-color: #3e3f4b !important;
116
+ height: 10px !important;
117
+ border-radius: 5px;
118
+ }
119
+
120
+ /* Loading or status text */
121
+ .st-emotion-cache-1q7spjk {
122
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
123
+ color: #ececf1 !important;
124
+ font-size: 1.1rem;
125
+ margin-bottom: 15px;
126
+ }
127
+
128
+ /* Optional rotation animation */
129
+ .rotate {
130
+ display: inline-block;
131
+ color: #10a37f;
132
+ animation: rotation 2s infinite linear;
133
+ }
134
+ @keyframes rotation {
135
+ from { transform: rotate(0deg); }
136
+ to { transform: rotate(359deg); }
137
+ }
138
+
139
+ /* Centered button containers */
140
+ .centered-button-container,
141
+ .button-container {
142
+ display: flex;
143
+ justify-content: center;
144
+ align-items: center;
145
+ text-align: center;
146
+ }
ui.py ADDED
@@ -0,0 +1,34 @@
+ import streamlit as st
+
+ def input_form() -> dict:
+     """Collect numeric-encoded patient features via sidebar widgets."""
+     st.sidebar.header("Patient Information")
+
+     return {
+         "gender": st.sidebar.selectbox("Gender", [(0.0, "Male"), (1.0, "Female")])[0],
+         "age": st.sidebar.slider("Age", 0.0, 100.0, 50.0),
+         "hypertension": st.sidebar.selectbox("Hypertension", [(0, "No"), (1, "Yes")])[0],
+         "heart_disease": st.sidebar.selectbox("Heart Disease", [(0, "No"), (1, "Yes")])[0],
+         "ever_married": st.sidebar.selectbox("Ever Married", [(0.0, "No"), (1.0, "Yes")])[0],
+         "work_type": st.sidebar.selectbox(
+             "Work Type",
+             [(0.0, "Private"), (1.0, "Self-employed"), (2.0, "Govt_job"), (3.0, "children"), (4.0, "Never_worked")]
+         )[0],
+         "Residence_type": st.sidebar.selectbox(
+             "Residence Type", [(0.0, "Urban"), (1.0, "Rural")]
+         )[0],
+         "avg_glucose_level": st.sidebar.number_input("Avg Glucose Level", 40.0, 300.0, 100.0),
+         "bmi": st.sidebar.number_input("BMI", 10.0, 60.0, 25.0),
+         "smoking_status": st.sidebar.selectbox(
+             "Smoking Status",
+             [(0.0, "formerly smoked"), (1.0, "never smoked"), (2.0, "smokes"), (3.0, "Unknown")]
+         )[0]
+     }
+
+ def display_result(label: str, proba: float):
+     """Render prediction and confidence."""
+     st.header("Prediction Result")
+     st.markdown(f"**Prediction:** {label}")
+     st.markdown(f"**Confidence:** {proba:.1%}")
+     if proba < 0.5:
+         st.info("Model confidence is low; consider additional evaluation.")
utils.py ADDED
@@ -0,0 +1,26 @@
+ import pandas as pd
+
+ def preprocess_input(data: dict) -> pd.DataFrame:
+     """
+     Build a single-row DataFrame matching the training schema:
+     ['gender','age','hypertension','heart_disease','ever_married',
+      'work_type','Residence_type','avg_glucose_level','bmi',
+      'smoking_status']
+     """
+     # Note: 'stroke' column is not included as a feature
+     feature_cols = [
+         "gender","age","hypertension","heart_disease","ever_married",
+         "work_type","Residence_type","avg_glucose_level","bmi",
+         "smoking_status"
+     ]
+     df = pd.DataFrame([{k: data[k] for k in feature_cols}])
+     return df
+
+ def predict_stroke(model, X: pd.DataFrame):
+     """
+     Returns human-readable label and probability for the top class.
+     """
+     proba = model.predict_proba(X)[0]
+     idx = int(proba.argmax())  # plain int so the dict lookup does not rely on NumPy scalar hashing
+     label_map = {0: "No Stroke", 1: "Stroke"}
+     return label_map[idx], float(proba[idx])
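To see what `predict_stroke` returns without loading the pickled model, the sketch below applies the same logic to a stand-in object. `StubModel`, `predict_label`, and the fixed probabilities are hypothetical, and a plain list replaces the NumPy array a real classifier would return.

```python
class StubModel:
    """Hypothetical stand-in exposing predict_proba like a scikit-learn classifier."""

    def __init__(self, p_stroke):
        self.p_stroke = p_stroke

    def predict_proba(self, X):
        # One row per input sample: [P(no stroke), P(stroke)]
        return [[1.0 - self.p_stroke, self.p_stroke] for _ in X]


def predict_label(model, X):
    """Same logic as predict_stroke, written against plain lists."""
    proba = model.predict_proba(X)[0]
    idx = max(range(len(proba)), key=proba.__getitem__)  # argmax over the row
    label_map = {0: "No Stroke", 1: "Stroke"}
    return label_map[idx], proba[idx]


label, confidence = predict_label(StubModel(0.8), [{"age": 70.0}])
print(label, confidence)
```

Because the returned probability belongs to the winning class, it is never below 0.5 for a binary model, which is worth keeping in mind when reading the low-confidence check in `display_result`.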