---
title: StrokeLine - Stroke Prediction Using Machine Learning
emoji: 🤖
colorFrom: indigo
colorTo: blue
sdk: streamlit
sdk_version: 1.30.0
app_file: app.py
pinned: false
license: mit
---
# Stroke Prediction Using Machine Learning
## About the Project
This project provides a comprehensive machine learning pipeline for predicting the risk of stroke in individuals based on clinical and demographic features. The goal is to enable early identification of high-risk patients, supporting healthcare professionals in making informed decisions and potentially reducing stroke-related morbidity and mortality. The project covers the full data science workflow: data exploration, preprocessing, feature engineering, model selection, hyperparameter optimization, evaluation, explainability, and deployment. The final solution includes a trained model and a Streamlit web application for real-time inference.
---
## About the Dataset
The dataset used is the [Stroke Prediction Dataset](https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset) from Kaggle. It contains 5,110 records across 12 columns: an identifier, 10 clinical and demographic features, and a binary target variable (`stroke`). The columns are:
- **id**: Unique identifier (not used for modeling)
- **gender**: Patient gender (`Male`, `Female`, `Other`)
- **age**: Age in years
- **hypertension**: Hypertension status (`0`: No, `1`: Yes)
- **heart_disease**: Heart disease status (`0`: No, `1`: Yes)
- **ever_married**: Marital status (`Yes`, `No`)
- **work_type**: Type of work (`children`, `Govt_job`, `Never_worked`, `Private`, `Self-employed`)
- **Residence_type**: Living area (`Urban`, `Rural`)
- **avg_glucose_level**: Average glucose level
- **bmi**: Body mass index (may contain missing values)
- **smoking_status**: Smoking behavior (`formerly smoked`, `never smoked`, `smokes`, `Unknown`)
- **stroke**: Target variable (`1`: Stroke occurred, `0`: No stroke)
The dataset is imbalanced, with far fewer positive stroke cases than negatives, and contains missing values in the `bmi` column.
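The two data-quality issues above (class imbalance and missing `bmi` values) are the first things to check after loading. A minimal sketch, using a tiny synthetic frame in place of the real CSV:

```python
import numpy as np
import pandas as pd

# Toy frame mimicking the dataset's quirks: a rare positive class and missing bmi values.
df = pd.DataFrame({
    "age": [67, 61, 80, 49, 79, 81],
    "bmi": [36.6, np.nan, 32.5, 34.4, 24.0, np.nan],
    "stroke": [1, 0, 0, 0, 0, 0],
})

missing_bmi = df["bmi"].isna().sum()                      # how many bmi entries are missing
class_ratio = df["stroke"].value_counts(normalize=True)   # fraction of each class

print(missing_bmi)
print(class_ratio)
```

On the real data the same two calls reveal the missing `bmi` entries and the heavy skew toward the negative class.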
---
## Notebook Summary
The notebook documents the entire process:
1. **Problem Definition**: Outlines the clinical motivation, dataset, and challenges.
2. **EDA**: Visualizes distributions, checks for missing values, and explores feature-target relationships.
3. **Feature Engineering**: Handles missing data, encodes categorical variables, and examines feature correlations.
4. **Data Balancing**: Uses RandomUnderSampler and SMOTE to address class imbalance.
5. **Model Selection**: Compares Random Forest, SVM, and XGBoost classifiers.
6. **Hyperparameter Tuning**: Uses Optuna for automated optimization of XGBoost.
7. **Evaluation**: Reports F1 score, confusion matrix, and classification report.
8. **Explainability**: Applies SHAP for model interpretation.
9. **Model Export**: Saves the trained model for deployment.
---
## Model Results
### Preprocessing
- **Missing Values**: Imputed missing `bmi` values with the mean.
- **Categorical Encoding**: Used `OrdinalEncoder` to convert categorical features to numeric.
- **Feature Selection**: Dropped the `id` column and checked for highly correlated features.
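The three preprocessing steps above can be sketched as follows; the toy frame uses the Kaggle column names, and the imputation/encoding choices match the list, not necessarily the notebook's exact code:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

# Toy slice of the dataset using the Kaggle schema.
df = pd.DataFrame({
    "id": [1, 2, 3],
    "gender": ["Male", "Female", "Female"],
    "smoking_status": ["smokes", "never smoked", "Unknown"],
    "bmi": [28.1, np.nan, 31.4],
})

df = df.drop(columns=["id"])                    # identifier carries no signal
df["bmi"] = df["bmi"].fillna(df["bmi"].mean())  # mean imputation for missing bmi

# Map categorical strings to integer codes.
cat_cols = ["gender", "smoking_status"]
df[cat_cols] = OrdinalEncoder().fit_transform(df[cat_cols])
```

Note that `OrdinalEncoder` imposes an arbitrary (alphabetical) order on categories, which tree-based models like XGBoost tolerate well.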
### Data Balancing
- **RandomUnderSampler**: Reduced the majority class to 10% of its original size.
- **SMOTE**: Oversampled the minority class to achieve a 1:1 ratio.
### Training
- **Train-Test Split**: Stratified split to preserve class distribution.
- **Model Comparison**: Evaluated Random Forest, SVM, and XGBoost on balanced data.
- **Best Model**: XGBoost achieved the highest F1 score.
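The comparison step can be sketched with scikit-learn alone on synthetic data; an `XGBClassifier` from xgboost slots into the same dictionary:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = make_classification(n_samples=600, random_state=0)

# stratify=y preserves the class ratio in both partitions.
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)

models = {
    "RandomForest": RandomForestClassifier(random_state=0),
    "SVM": SVC(),
    # "XGBoost": xgboost.XGBClassifier(...)  # trained and scored the same way
}
scores = {name: f1_score(y_te, m.fit(X_tr, y_tr).predict(X_te)) for name, m in models.items()}
print(scores)
```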
### Hyperparameter Tuning
- **Optuna**: Ran 50 trials to optimize XGBoost hyperparameters (e.g., `n_estimators`, `max_depth`, `learning_rate`, `gamma`, etc.) using 5-fold cross-validation and F1 score as the metric.
### Evaluation
- **F1 Score**: Achieved ~90% F1 score on the balanced test set.
- **Confusion Matrix**: Demonstrated balanced sensitivity and specificity.
- **Classification Report**: Provided detailed precision, recall, and F1 for each class.
- **Explainability**: SHAP analysis identified the most influential features and provided local/global interpretability.
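The three metric calls are standard scikit-learn; a minimal sketch with hand-made labels in place of the real test-set predictions:

```python
from sklearn.metrics import classification_report, confusion_matrix, f1_score

# Hand-made labels/predictions just to show the three report calls.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 1, 0, 0, 0, 1, 1]

f1 = f1_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)  # rows: true class, columns: predicted class
report = classification_report(y_true, y_pred)

print(f1)      # 0.75 for these toy labels
print(cm)
print(report)
```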
---
## How to Install
Follow these steps to set up the project using a virtual environment:
```bash
# Clone or download the repository
git clone https://github.com/DeepActionPotential/StrokeLineAI
cd StrokeLineAI
# Create a virtual environment
python -m venv venv
# Activate the virtual environment
# On Windows:
venv\Scripts\activate
# On macOS/Linux:
source venv/bin/activate
# Upgrade pip
pip install --upgrade pip
# Install dependencies
pip install -r requirements.txt
```
---
## How to Use the Software
1. **Run the Web Application**
Start the Streamlit app:
```bash
streamlit run app.py
```
2. **Demo**
   [demo-video](demo/strokeline_demo.mp4)
---
## Technologies Used
### Data Science & Model Training
- **matplotlib, seaborn**: Data visualization.
- **scikit-learn**: Preprocessing, model selection, metrics, and pipelines.
- **imbalanced-learn**: Advanced resampling (SMOTE, RandomUnderSampler) for class balancing.
- **XGBoost**: High-performance gradient boosting for classification.
- **Optuna**: Automated hyperparameter optimization.
- **SHAP**: Model explainability and feature importance analysis.
### Deployment
- **Streamlit**: Rapid web app development for interactive model inference.
- **joblib**: Model serialization for deployment.
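The joblib round trip behind deployment can be sketched as follows; a `LogisticRegression` and a temp-directory path stand in for the exported XGBoost model and its real file location:

```python
import os
import tempfile

import joblib
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=100, random_state=0)
model = LogisticRegression(max_iter=500).fit(X, y)

# Serialize, then reload and confirm the restored model predicts identically.
path = os.path.join(tempfile.mkdtemp(), "model.joblib")
joblib.dump(model, path)
restored = joblib.load(path)
```

The Streamlit app loads the serialized model once at startup and calls `predict` on each user-submitted feature vector.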
---
## License
This project is licensed under the MIT License.
See the [LICENSE](LICENSE) file for details.