---
title: StrokeLine - Stroke Prediction Using Machine Learning
emoji: 🤖
colorFrom: indigo
colorTo: blue
sdk: streamlit
sdk_version: 1.30.0
app_file: app.py
pinned: false
license: mit
---
# Stroke Prediction Using Machine Learning
## About the Project
This project provides a comprehensive machine learning pipeline for predicting the risk of stroke in individuals based on clinical and demographic features. The goal is to enable early identification of high-risk patients, supporting healthcare professionals in making informed decisions and potentially reducing stroke-related morbidity and mortality. The project covers the full data science workflow: data exploration, preprocessing, feature engineering, model selection, hyperparameter optimization, evaluation, explainability, and deployment. The final solution includes a trained model and a Streamlit web application for real-time inference.
---
## About the Dataset
The dataset used is the [Stroke Prediction Dataset](https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset) from Kaggle. It contains 5110 records and 12 columns: an identifier, 10 clinical and demographic features, and a binary target variable (`stroke`). The columns are:
- **id**: Unique identifier (not used for modeling)
- **gender**: Patient gender (`Male`, `Female`, `Other`)
- **age**: Age in years
- **hypertension**: Hypertension status (`0`: No, `1`: Yes)
- **heart_disease**: Heart disease status (`0`: No, `1`: Yes)
- **ever_married**: Marital status (`Yes`, `No`)
- **work_type**: Type of work (`children`, `Govt_job`, `Never_worked`, `Private`, `Self-employed`)
- **Residence_type**: Living area (`Urban`, `Rural`)
- **avg_glucose_level**: Average glucose level
- **bmi**: Body mass index (may contain missing values)
- **smoking_status**: Smoking behavior (`formerly smoked`, `never smoked`, `smokes`, `Unknown`)
- **stroke**: Target variable (`1`: Stroke occurred, `0`: No stroke)
The dataset is imbalanced, with far fewer positive stroke cases than negatives, and contains missing values in the `bmi` column.
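A minimal sketch of loading the CSV and confirming both issues using only the standard library. It assumes missing `bmi` entries appear as `N/A` or as an empty field; the toy rows, the column sets, and the `load_rows` helper are hypothetical. In pandas, `read_csv` treats `N/A` as missing by default, and `df["stroke"].value_counts()` shows the class balance.

```python
import csv
import io
from collections import Counter

# Hypothetical toy rows using the dataset's column layout (values are made up).
SAMPLE_CSV = """id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
1,Male,70,1,0,Yes,Private,Urban,180.0,31.0,formerly smoked,1
2,Female,65,0,1,Yes,Self-employed,Rural,150.0,N/A,never smoked,1
3,Female,45,0,0,Yes,Govt_job,Urban,95.0,24.0,smokes,0
4,Male,10,0,0,No,children,Rural,80.0,,Unknown,0
5,Female,30,0,0,No,Private,Urban,85.0,22.0,never smoked,0
"""

NUMERIC = {"age", "avg_glucose_level", "bmi"}
BINARY = {"hypertension", "heart_disease", "stroke"}
MISSING_MARKERS = {"", "N/A", "NA"}  # assumed spellings of a missing value


def load_rows(text):
    """Parse CSV text into dicts: drop `id`, convert numeric/binary columns, map missing markers to None."""
    rows = []
    for raw in csv.DictReader(io.StringIO(text)):
        row = {}
        for column, value in raw.items():
            if column == "id":
                continue  # identifier, not used for modeling
            if value is None or value.strip() in MISSING_MARKERS:
                row[column] = None
            elif column in NUMERIC:
                row[column] = float(value)
            elif column in BINARY:
                row[column] = int(value)
            else:
                row[column] = value  # categorical; "Unknown" stays a regular category
        rows.append(row)
    return rows


rows = load_rows(SAMPLE_CSV)
missing_bmi = sum(row["bmi"] is None for row in rows)
class_counts = Counter(row["stroke"] for row in rows)
print(f"{len(rows)} rows, {missing_bmi} missing bmi values, class counts: {dict(class_counts)}")
```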
---
## Notebook Summary
The notebook documents the entire process:
1. **Problem Definition**: Outlines the clinical motivation, dataset, and challenges.
2. **EDA**: Visualizes distributions, checks for missing values, and explores feature-target relationships.
3. **Feature Engineering**: Handles missing data, encodes categorical variables, and examines feature correlations.
4. **Data Balancing**: Uses RandomUnderSampler and SMOTE to address class imbalance.
5. **Model Selection**: Compares Random Forest, SVM, and XGBoost classifiers.
6. **Hyperparameter Tuning**: Uses Optuna for automated optimization of XGBoost.
7. **Evaluation**: Reports F1 score, confusion matrix, and classification report.
8. **Explainability**: Applies SHAP for model interpretation.
9. **Model Export**: Saves the trained model for deployment.
---
## Model Results
### Preprocessing
- **Missing Values**: Imputed missing `bmi` values with the mean.
- **Categorical Encoding**: Used `OrdinalEncoder` to convert categorical features to numeric.
- **Feature Selection**: Dropped the `id` column and checked for highly correlated features.
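The preprocessing steps above can be illustrated with a minimal pure-Python sketch. The notebook encodes categories with scikit-learn's `OrdinalEncoder`, and scikit-learn also offers mean imputation as `SimpleImputer(strategy="mean")`. The helpers `impute_mean`, `ordinal_encode`, and `pearson` and the toy columns below are hypothetical. Like `OrdinalEncoder` with its default settings, `ordinal_encode` numbers categories in sorted order.

```python
import math
from statistics import mean


def impute_mean(values):
    """Replace None entries with the mean of the observed values."""
    observed = [v for v in values if v is not None]
    fill = mean(observed)
    return [fill if v is None else v for v in values]


def ordinal_encode(values):
    """Map each category to its index among the sorted unique categories."""
    categories = sorted(set(values))
    mapping = {category: code for code, category in enumerate(categories)}
    return [mapping[v] for v in values], mapping


def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length numeric sequences."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


# Hypothetical toy columns (not taken from the real dataset).
bmi = [22.0, None, 30.0, 26.0]
ever_married = ["No", "Yes", "Yes", "No"]
age = [25.0, 60.0, 70.0, 40.0]

bmi_filled = impute_mean(bmi)
married_codes, married_map = ordinal_encode(ever_married)
print(bmi_filled, married_map, round(pearson(age, married_codes), 3))
```

In a leakage-free pipeline, the imputation mean and the category mapping would be computed on the training split and then reused for the test split and for inputs entered in the web app.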
### Data Balancing
- **RandomUnderSampler**: Reduced the majority class to 10% of its original size.
- **SMOTE**: Oversampled the minority class to achieve a 1:1 ratio.
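The two-stage balancing can be sketched as follows, assuming two classes and numeric feature vectors. This is a simplified illustration, not imbalanced-learn's implementation (which the notebook uses via `RandomUnderSampler` and `SMOTE`). Neighbour search is brute force, features are not scaled, and the toy data and the `random_undersample` and `smote` helpers are hypothetical. SMOTE creates each synthetic point on the line segment between a random minority sample and one of its k nearest minority neighbours.

```python
import math
import random


def random_undersample(X, y, majority_label, keep_fraction, rng):
    """Keep a random fraction of majority-class rows and every other row."""
    majority = [i for i, label in enumerate(y) if label == majority_label]
    others = [i for i, label in enumerate(y) if label != majority_label]
    kept = rng.sample(majority, max(1, round(len(majority) * keep_fraction)))
    idx = sorted(kept + others)
    return [X[i] for i in idx], [y[i] for i in idx]


def smote(X, y, minority_label, k, rng):
    """Binary case: add synthetic minority rows until both classes have the same count."""
    minority = [x for x, label in zip(X, y) if label == minority_label]
    n_majority = len(y) - len(minority)
    synthetic = []
    for _ in range(max(0, n_majority - len(minority))):
        base = rng.choice(minority)
        # k nearest minority neighbours of `base` (brute force, Euclidean distance)
        others = [m for m in minority if m is not base]
        neighbours = sorted(others, key=lambda m: math.dist(base, m))[:k]
        neighbour = rng.choice(neighbours)
        gap = rng.random()  # position along the segment base -> neighbour
        synthetic.append([b + gap * (n - b) for b, n in zip(base, neighbour)])
    return X + synthetic, y + [minority_label] * len(synthetic)


# Hypothetical toy data: 100 majority rows near (0, 0), 6 minority rows near (5, 5).
rng = random.Random(0)
X = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(100)]
X += [[rng.gauss(5, 1), rng.gauss(5, 1)] for _ in range(6)]
y = [0] * 100 + [1] * 6

X_u, y_u = random_undersample(X, y, majority_label=0, keep_fraction=0.10, rng=rng)
X_b, y_b = smote(X_u, y_u, minority_label=1, k=5, rng=rng)
print(y_u.count(0), y_u.count(1), y_b.count(0), y_b.count(1))
```

The helpers resample whatever rows they are given; to keep synthetic points out of the evaluation data, they would be called on the training split only.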
### Training
- **Train-Test Split**: Stratified split to preserve class distribution.
- **Model Comparison**: Evaluated Random Forest, SVM, and XGBoost on balanced data.
- **Best Model**: XGBoost achieved the highest F1 score.
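A stratified split can be sketched by shuffling and splitting each class separately. In scikit-learn this corresponds to `train_test_split(X, y, test_size=0.2, stratify=y)`; the 20% test fraction, the seed, and the `stratified_split` helper are assumptions for illustration.

```python
import random
from collections import defaultdict


def stratified_split(y, test_fraction, seed=42):
    """Return (train_idx, test_idx) so each class keeps roughly the same proportion in both parts."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, label in enumerate(y):
        by_class[label].append(i)
    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_fraction)  # same fraction taken from every class
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Hypothetical labels with a 5% positive rate.
y = [0] * 95 + [1] * 5
train_idx, test_idx = stratified_split(y, test_fraction=0.2)
print(len(train_idx), len(test_idx), sum(y[i] for i in test_idx))
```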
### Hyperparameter Tuning
- **Optuna**: Ran 50 trials to optimize XGBoost hyperparameters (e.g., `n_estimators`, `max_depth`, `learning_rate`, `gamma`) using 5-fold cross-validation, with the F1 score as the objective.
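The structure of the tuning loop can be sketched without external dependencies. This is a simplified stand-in, not the notebook's code. Plain random search replaces Optuna's default TPE sampler, and the search ranges are hypothetical. `toy_evaluate` is a placeholder for a function that would fit an XGBoost model with the sampled parameters on the training indices and return the F1 score on the validation indices. With Optuna itself, the loop is written as an objective passed to `optuna.create_study(direction="maximize").optimize(objective, n_trials=50)`.

```python
import math
import random


def stratified_kfold(y, k, seed=0):
    """Spread each class round-robin over k folds; return [(train_idx, val_idx), ...]."""
    rng = random.Random(seed)
    folds = [[] for _ in range(k)]
    for label in sorted(set(y)):
        members = [i for i, lab in enumerate(y) if lab == label]
        rng.shuffle(members)
        for position, i in enumerate(members):
            folds[position % k].append(i)
    splits = []
    for f in range(k):
        train = sorted(i for g in range(k) if g != f for i in folds[g])
        splits.append((train, sorted(folds[f])))
    return splits


def sample_params(rng):
    """Draw one candidate configuration from hypothetical ranges."""
    return {
        "n_estimators": rng.randint(100, 1000),
        "max_depth": rng.randint(3, 10),
        "learning_rate": 10 ** rng.uniform(-2, math.log10(0.3)),  # log-uniform in [0.01, 0.3]
        "gamma": rng.uniform(0.0, 5.0),
    }


def tune(y, evaluate, n_trials=50, k=5, seed=0):
    """Random search (stand-in for Optuna): keep the candidate with the best mean F1 over k folds."""
    rng = random.Random(seed)
    splits = stratified_kfold(y, k, seed)
    best_params, best_score = None, -1.0
    for _ in range(n_trials):
        params = sample_params(rng)
        score = sum(evaluate(params, tr, va) for tr, va in splits) / k
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_evaluate(params, train_idx, val_idx):
    """Hypothetical placeholder score; a real version would train XGBoost and return F1."""
    return 1.0 - abs(params["max_depth"] - 6) / 10


y = [0] * 20 + [1] * 5
best_params, best_score = tune(y, toy_evaluate)
print(best_params, round(best_score, 3))
```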
### Evaluation
- **F1 Score**: Achieved ~90% F1 score on the balanced test set.
- **Confusion Matrix**: Demonstrated balanced sensitivity and specificity.
- **Classification Report**: Provided detailed precision, recall, and F1 for each class.
- **Explainability**: SHAP analysis identified the most influential features and provided local/global interpretability.
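The reported metrics can be computed from scratch as below for binary labels `0` and `1`. scikit-learn provides them as `f1_score`, `confusion_matrix`, and `classification_report`; the helpers and toy labels here are hypothetical. Sensitivity is the recall of class `1` and specificity is the recall of class `0`, so both can be read off the per-class report.

```python
def confusion_matrix(y_true, y_pred, labels=(0, 1)):
    """Counts with rows = true label and columns = predicted label (scikit-learn's convention)."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def per_class_report(y_true, y_pred, labels=(0, 1)):
    """Precision, recall, and F1 for each label, treating that label as the positive class."""
    report = {}
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = {"precision": precision, "recall": recall, "f1": f1}
    return report


# Hypothetical labels and predictions.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 0, 1, 1]
matrix = confusion_matrix(y_true, y_pred)
report = per_class_report(y_true, y_pred)
print(matrix)
print(report)
```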
---
## How to Install
Follow these steps to set up the project using a virtual environment:
```bash
# Clone or download the repository
git clone https://github.com/DeepActionPotential/StrokeLineAI
cd StrokeLineAI
# Create a virtual environment
python -m venv venv
# Activate the virtual environment
# On Windows:
venv\Scripts\activate
# On macOS/Linux:
source venv/bin/activate
# Upgrade pip
pip install --upgrade pip
# Install dependencies
pip install -r requirements.txt
```
---
## How to Use the Software
1. **Run the Web Application**
   Start the Streamlit app:
   ```bash
   streamlit run app.py
   ```
2. **Demo**
   [Demo video](demo/strokeline_demo.mp4)
   ![Demo screenshot](demo/strokeline_demo.jpeg)
---
## Technologies Used
### Data Science & Model Training
- **matplotlib, seaborn**: Data visualization.
- **scikit-learn**: Preprocessing, model selection, metrics, and pipelines.
- **imbalanced-learn**: Advanced resampling (SMOTE, RandomUnderSampler) for class balancing.
- **XGBoost**: High-performance gradient boosting for classification.
- **Optuna**: Automated hyperparameter optimization.
- **SHAP**: Model explainability and feature importance analysis.
### Deployment
- **Streamlit**: Rapid web app development for interactive model inference.
- **joblib**: Model serialization for deployment.
---
## License
This project is licensed under the MIT License.
See the [LICENSE](LICENSE) file for details.