---
title: StrokeLine - Stroke Prediction Using Machine Learning
emoji: 🤖
colorFrom: indigo
colorTo: blue
sdk: streamlit
sdk_version: 1.30.0
app_file: app.py
pinned: false
license: mit
---


# Stroke Prediction Using Machine Learning



## About the Project

This project provides a comprehensive machine learning pipeline for predicting the risk of stroke in individuals based on clinical and demographic features. The goal is to enable early identification of high-risk patients, supporting healthcare professionals in making informed decisions and potentially reducing stroke-related morbidity and mortality. The project covers the full data science workflow: data exploration, preprocessing, feature engineering, model selection, hyperparameter optimization, evaluation, explainability, and deployment. The final solution includes a trained model and a Streamlit web application for real-time inference.

---

## About the Dataset

The dataset used is the [Stroke Prediction Dataset](https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset) from Kaggle. It contains 5110 records and 12 columns: an identifier, 10 clinical and demographic features, and a binary target variable (`stroke`). The columns are:

- **id**: Unique identifier (not used for modeling)
- **gender**: Patient gender (`Male`, `Female`, `Other`)
- **age**: Age in years
- **hypertension**: Hypertension status (`0`: No, `1`: Yes)
- **heart_disease**: Heart disease status (`0`: No, `1`: Yes)
- **ever_married**: Marital status (`Yes`, `No`)
- **work_type**: Type of work (`children`, `Govt_job`, `Never_worked`, `Private`, `Self-employed`)
- **Residence_type**: Living area (`Urban`, `Rural`)
- **avg_glucose_level**: Average glucose level
- **bmi**: Body mass index (may contain missing values)
- **smoking_status**: Smoking behavior (`formerly smoked`, `never smoked`, `smokes`, `Unknown`)
- **stroke**: Target variable (`1`: Stroke occurred, `0`: No stroke)

The dataset is imbalanced, with far fewer positive stroke cases than negatives, and contains missing values in the `bmi` column.
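Both properties can be checked quickly with pandas, as in the sketch below. The five CSV rows are invented for illustration and only mimic the column layout listed above. The sketch assumes missing `bmi` values are written as `N/A`, which `pandas.read_csv` converts to `NaN` by default. On the real data, pass the downloaded CSV path to `pd.read_csv` instead of the in-memory string.

```python
import io

import pandas as pd

# Illustrative rows only (not real records), using the column layout described above.
SAMPLE_CSV = """id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
1,Male,67,0,1,Yes,Private,Urban,210.50,33.1,formerly smoked,1
2,Female,61,0,0,Yes,Self-employed,Rural,190.30,N/A,never smoked,1
3,Female,49,0,0,Yes,Private,Urban,171.23,34.4,smokes,0
4,Male,30,0,0,No,Private,Rural,85.10,N/A,Unknown,0
5,Female,12,0,0,No,children,Urban,95.00,18.2,Unknown,0
"""


def summarize(df: pd.DataFrame) -> dict:
    """Report the class balance of `stroke` and the number of missing `bmi` values."""
    return {
        "n_rows": len(df),
        "stroke_ratio": df["stroke"].value_counts(normalize=True).to_dict(),
        "missing_bmi": int(df["bmi"].isna().sum()),
    }


# read_csv treats the literal string "N/A" as NaN by default, so `bmi` is parsed as float.
df = pd.read_csv(io.StringIO(SAMPLE_CSV))
print(summarize(df))
```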

---

## Notebook Summary

The notebook documents the entire process:

1. **Problem Definition**: Outlines the clinical motivation, dataset, and challenges.
2. **EDA**: Visualizes distributions, checks for missing values, and explores feature-target relationships.
3. **Feature Engineering**: Handles missing data, encodes categorical variables, and examines feature correlations.
4. **Data Balancing**: Uses RandomUnderSampler and SMOTE to address class imbalance.
5. **Model Selection**: Compares Random Forest, SVM, and XGBoost classifiers.
6. **Hyperparameter Tuning**: Uses Optuna for automated optimization of XGBoost.
7. **Evaluation**: Reports F1 score, confusion matrix, and classification report.
8. **Explainability**: Applies SHAP for model interpretation.
9. **Model Export**: Saves the trained model for deployment.

---

## Model Results

### Preprocessing

- **Missing Values**: Imputed missing `bmi` values with the mean.
- **Categorical Encoding**: Used `OrdinalEncoder` to convert categorical features to numeric.
- **Feature Selection**: Dropped the `id` column and checked for highly correlated features.
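These preprocessing steps could be written as a scikit-learn `ColumnTransformer`, as in the following sketch. The column groupings, the `preprocess` name, and the `handle_unknown` setting are assumptions for illustration, not taken from the notebook, and the three-row frame is made up. Fitting the transformer once on training data and reusing it for test data and app inputs keeps the imputed mean and the category codes consistent.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OrdinalEncoder

CATEGORICAL = ["gender", "ever_married", "work_type", "Residence_type", "smoking_status"]
NUMERIC = ["age", "hypertension", "heart_disease", "avg_glucose_level", "bmi"]

preprocess = ColumnTransformer(
    transformers=[
        # Mean imputation on the numeric columns; per the dataset notes, `bmi` is the one with gaps.
        ("num", SimpleImputer(strategy="mean"), NUMERIC),
        # Integer code per category; categories unseen during fit map to -1 (an assumed choice).
        ("cat", OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1), CATEGORICAL),
    ],
    remainder="drop",  # any column not listed, such as `id`, is dropped
)

# Made-up rows for demonstration.
df = pd.DataFrame({
    "id": [1, 2, 3],
    "gender": ["Male", "Female", "Female"],
    "age": [67.0, 61.0, 49.0],
    "hypertension": [0, 0, 1],
    "heart_disease": [1, 0, 0],
    "ever_married": ["Yes", "Yes", "No"],
    "work_type": ["Private", "Self-employed", "Govt_job"],
    "Residence_type": ["Urban", "Rural", "Urban"],
    "avg_glucose_level": [210.5, 190.3, 95.0],
    "bmi": [33.0, np.nan, 27.0],
    "smoking_status": ["formerly smoked", "never smoked", "smokes"],
    "stroke": [1, 1, 0],
})

X = df.drop(columns=["stroke"])
y = df["stroke"]
# Output columns: the five numeric features first, then the five encoded categoricals.
X_enc = preprocess.fit_transform(X)
print(X_enc)
```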

### Data Balancing

- **RandomUnderSampler**: Reduced the majority class to 10% of its original size.
- **SMOTE**: Oversampled the minority class to achieve a 1:1 ratio.
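The project performs these two steps with imbalanced-learn's `RandomUnderSampler` and `SMOTE`. To show what they do, here is a simplified NumPy re-implementation for a binary target. It is an illustration, not imbalanced-learn's code, and omits options such as categorical-feature handling. `keep_fraction=0.10` mirrors the "10% of its original size" step; the features are random numbers with an assumed 950:50 class split.

```python
import numpy as np


def random_undersample(X, y, majority_label, keep_fraction, rng):
    """Keep a random `keep_fraction` of the majority-class rows plus every other row."""
    maj_idx = np.flatnonzero(y == majority_label)
    other_idx = np.flatnonzero(y != majority_label)
    n_keep = int(round(len(maj_idx) * keep_fraction))
    kept = rng.choice(maj_idx, size=n_keep, replace=False)
    idx = np.concatenate([kept, other_idx])
    return X[idx], y[idx]


def smote_binary(X, y, minority_label, k=5, rng=None):
    """Simplified SMOTE for two classes: add synthetic minority rows until counts are 1:1.

    Each synthetic row is a random point on the segment between a minority row and
    one of its k nearest minority neighbours (Euclidean distance).
    """
    rng = np.random.default_rng() if rng is None else rng
    X_min = X[y == minority_label]
    n_needed = int((y != minority_label).sum()) - len(X_min)
    if n_needed <= 0:
        return X, y
    k = min(k, len(X_min) - 1)
    if k < 1:
        raise ValueError("need at least two minority rows to interpolate")
    # Pairwise distances inside the minority class; a row is never its own neighbour.
    dist = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)
    neighbours = np.argsort(dist, axis=1)[:, :k]
    base = rng.integers(0, len(X_min), size=n_needed)
    partner = neighbours[base, rng.integers(0, k, size=n_needed)]
    gap = rng.random((n_needed, 1))  # interpolation weight in [0, 1)
    X_new = X_min[base] + gap * (X_min[partner] - X_min[base])
    y_new = np.full(n_needed, minority_label, dtype=y.dtype)
    return np.vstack([X, X_new]), np.concatenate([y, y_new])


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 4))
y = np.concatenate([np.zeros(950, dtype=int), np.ones(50, dtype=int)])

X_u, y_u = random_undersample(X, y, majority_label=0, keep_fraction=0.10, rng=rng)
X_b, y_b = smote_binary(X_u, y_u, minority_label=1, k=5, rng=rng)
print(np.bincount(y_u), np.bincount(y_b))  # class counts after each step
```

Whichever implementation is used, resampling is usually applied to the training split only, so that the test split keeps the real-world class ratio.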

### Training

- **Train-Test Split**: Stratified split to preserve class distribution.
- **Model Comparison**: Evaluated Random Forest, SVM, and XGBoost on balanced data.
- **Best Model**: XGBoost achieved the highest F1 score.
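A minimal sketch of the split-then-compare step with scikit-learn follows. It uses synthetic data from `make_classification` in place of the resampled stroke data. Only Random Forest and SVM are included to keep dependencies small. `xgboost.XGBClassifier` implements the scikit-learn estimator interface and could be added to the `candidates` dictionary. The hyperparameters and the ranking by 5-fold cross-validated F1 are assumptions, not the notebook's settings.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic, roughly balanced stand-in for the resampled stroke data (10 features).
X, y = make_classification(n_samples=600, n_features=10, n_informative=5, random_state=42)

# Stratified split keeps the class ratio the same in the train and test parts.
# The test part is held out for the final evaluation and not used for model selection.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

candidates = {
    "random_forest": RandomForestClassifier(n_estimators=100, random_state=42),
    # SVMs are sensitive to feature scale, so standardize inside the pipeline.
    "svm": make_pipeline(StandardScaler(), SVC(kernel="rbf")),
}

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = {
    name: cross_val_score(model, X_train, y_train, cv=cv, scoring="f1").mean()
    for name, model in candidates.items()
}
best = max(scores, key=scores.get)
print(scores, "best:", best)
```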

### Hyperparameter Tuning

- **Optuna**: Ran 50 trials to optimize XGBoost hyperparameters (e.g., `n_estimators`, `max_depth`, `learning_rate`, `gamma`) using 5-fold cross-validation, with F1 score as the objective.

### Evaluation

- **F1 Score**: Achieved ~90% F1 score on the balanced test set.
- **Confusion Matrix**: Demonstrated balanced sensitivity and specificity.
- **Classification Report**: Provided detailed precision, recall, and F1 for each class.
- **Explainability**: SHAP analysis identified the most influential features and explained both individual predictions (local) and overall model behavior (global).
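The reported metrics can be computed with scikit-learn as in the sketch below. The label arrays are hand-made toy values and `evaluate` is a hypothetical helper. Sensitivity and specificity are derived from the confusion matrix cells. On real data, `y_pred` would come from the trained model's `predict` on the test features.

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, f1_score


def evaluate(y_true, y_pred):
    """Return F1, sensitivity (recall of class 1) and specificity (recall of class 0)."""
    # With labels=[0, 1], ravel() yields the cells in the order tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        "f1": f1_score(y_true, y_pred),
        "sensitivity": tp / (tp + fn),
        "specificity": tn / (tn + fp),
    }


# Toy labels: one false positive and one false negative out of eight cases.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 1, 1, 0])
metrics = evaluate(y_true, y_pred)
print(metrics)
print(classification_report(y_true, y_pred, target_names=["no stroke", "stroke"]))
```

For these toy arrays, F1, sensitivity, and specificity are all 0.75.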

---

## How to Install

Follow these steps to set up the project using a virtual environment:

```bash
# Clone or download the repository
git clone https://github.com/DeepActionPotential/StrokeLineAI
cd StrokeLineAI

# Create a virtual environment
python -m venv venv

# Activate the virtual environment
# On Windows:
venv\Scripts\activate
# On macOS/Linux:
source venv/bin/activate

# Upgrade pip
pip install --upgrade pip

# Install dependencies
pip install -r requirements.txt
```

---

## How to Use the Software

1. **Run the Web Application**  
   Start the Streamlit app:

   ```bash
   streamlit run app.py
   ```

2. **Demo**  
   [Demo video](demo/strokeline_demo.mp4)

   ![demo-screenshot](demo/strokeline_demo.jpeg)

---

## Technologies Used

### Data Science & Model Training


- **matplotlib, seaborn**: Data visualization.
- **scikit-learn**: Preprocessing, model selection, metrics, and pipelines.
- **imbalanced-learn**: Advanced resampling (SMOTE, RandomUnderSampler) for class balancing.
- **XGBoost**: High-performance gradient boosting for classification.
- **Optuna**: Automated hyperparameter optimization.
- **SHAP**: Model explainability and feature importance analysis.

### Deployment

- **Streamlit**: Rapid web app development for interactive model inference.
- **joblib**: Model serialization for deployment.
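Saving and reloading a fitted estimator with joblib takes two calls, `joblib.dump` and `joblib.load`. In the following sketch, a small `LogisticRegression` trained on random data stands in for the tuned XGBoost model. The `model.joblib` filename is hypothetical, because the README does not state which file `app.py` loads.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

# Stand-in model trained on random data; in the project this would be the tuned classifier.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
model = LogisticRegression().fit(X, y)

# Notebook side: serialize the fitted estimator to disk.
path = os.path.join(tempfile.mkdtemp(), "model.joblib")  # hypothetical filename
joblib.dump(model, path)

# App side: deserialize it once and reuse it for predictions.
loaded = joblib.load(path)
print(loaded.predict(X[:5]))
```

Joblib files are pickles, so load them only from trusted sources. In a Streamlit app, placing the load inside a function decorated with `st.cache_resource` prevents the model from being reloaded on every rerun.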

---

## License

This project is licensed under the MIT License.  
See the [LICENSE](LICENSE) file for details.