---
title: StrokeLine - Stroke Prediction Using Machine Learning
emoji: 🤖
colorFrom: indigo
colorTo: blue
sdk: streamlit
sdk_version: 1.30.0
app_file: app.py
pinned: false
license: mit
---
# Stroke Prediction Using Machine Learning

## About the Project
This project provides a comprehensive machine learning pipeline for predicting the risk of stroke in individuals based on clinical and demographic features. The goal is to enable early identification of high-risk patients, supporting healthcare professionals in making informed decisions and potentially reducing stroke-related morbidity and mortality. The project covers the full data science workflow: data exploration, preprocessing, feature engineering, model selection, hyperparameter optimization, evaluation, explainability, and deployment. The final solution includes a trained model and a Streamlit web application for real-time inference.
## About the Dataset

The dataset used is the Stroke Prediction Dataset from Kaggle. It contains 5110 records with 11 features and a binary target variable (`stroke`). The features include:
- `id`: Unique identifier (not used for modeling)
- `gender`: Patient gender (`Male`, `Female`, `Other`)
- `age`: Age in years
- `hypertension`: Hypertension status (`0`: No, `1`: Yes)
- `heart_disease`: Heart disease status (`0`: No, `1`: Yes)
- `ever_married`: Marital status (`Yes`, `No`)
- `work_type`: Type of work (`children`, `Govt_job`, `Never_worked`, `Private`, `Self-employed`)
- `Residence_type`: Living area (`Urban`, `Rural`)
- `avg_glucose_level`: Average glucose level
- `bmi`: Body mass index (may contain missing values)
- `smoking_status`: Smoking behavior (`formerly smoked`, `never smoked`, `smokes`, `Unknown`)
- `stroke`: Target variable (`1`: Stroke occurred, `0`: No stroke)
The dataset is imbalanced, with far fewer positive stroke cases than negatives, and contains missing values in the bmi column.
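Before modeling, the two issues above (class imbalance and missing `bmi` values) are worth verifying directly. A minimal sketch, using a small synthetic frame as a stand-in for the Kaggle CSV (the column names match the schema above, but the values here are illustrative only):

```python
import pandas as pd

# Toy stand-in for the Kaggle stroke dataset; column names follow the schema above.
df = pd.DataFrame({
    "age": [67, 80, 49, 31, 58, 74],
    "bmi": [36.6, None, 34.4, 28.1, None, 27.3],
    "smoking_status": ["formerly smoked", "never smoked", "smokes",
                       "never smoked", "Unknown", "never smoked"],
    "stroke": [1, 1, 0, 0, 0, 0],
})

# Two quick checks to run first: class balance and missing bmi values.
print(df["stroke"].value_counts(normalize=True))  # class proportions
print(df["bmi"].isna().sum())                     # count of missing bmi entries
```

On the real dataset, the same two lines reveal the roughly 95:5 class skew and the missing `bmi` entries that the pipeline must handle.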
## Notebook Summary
The notebook documents the entire process:
- Problem Definition: Outlines the clinical motivation, dataset, and challenges.
- EDA: Visualizes distributions, checks for missing values, and explores feature-target relationships.
- Feature Engineering: Handles missing data, encodes categorical variables, and examines feature correlations.
- Data Balancing: Uses RandomUnderSampler and SMOTE to address class imbalance.
- Model Selection: Compares Random Forest, SVM, and XGBoost classifiers.
- Hyperparameter Tuning: Uses Optuna for automated optimization of XGBoost.
- Evaluation: Reports F1 score, confusion matrix, and classification report.
- Explainability: Applies SHAP for model interpretation.
- Model Export: Saves the trained model for deployment.
## Model Results

### Preprocessing
- Missing Values: Imputed missing `bmi` values with the mean.
- Categorical Encoding: Used `OrdinalEncoder` to convert categorical features to numeric.
- Feature Selection: Dropped the `id` column and checked for highly correlated features.
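These two preprocessing steps can be sketched with scikit-learn on a small illustrative frame (the values are made up; the real pipeline runs on the full dataset):

```python
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

# Illustrative subset of the stroke schema.
df = pd.DataFrame({
    "gender": ["Male", "Female", "Female", "Male"],
    "bmi": [36.6, None, 34.4, None],
    "ever_married": ["Yes", "Yes", "No", "No"],
})

# Mean-impute the missing bmi values.
df["bmi"] = df["bmi"].fillna(df["bmi"].mean())

# Map categorical columns to integer codes with OrdinalEncoder.
cat_cols = ["gender", "ever_married"]
df[cat_cols] = OrdinalEncoder().fit_transform(df[cat_cols])

print(df)
```

Note that `OrdinalEncoder` assigns arbitrary integer order to categories, which is acceptable for tree-based models like the XGBoost classifier used here.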
### Data Balancing
- RandomUnderSampler: Reduced the majority class to 10% of its original size.
- SMOTE: Oversampled the minority class to achieve a 1:1 ratio.
### Training
- Train-Test Split: Stratified split to preserve class distribution.
- Model Comparison: Evaluated Random Forest, SVM, and XGBoost on balanced data.
- Best Model: XGBoost achieved the highest F1 score.
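A minimal sketch of the comparison loop, on synthetic balanced data; scikit-learn's `GradientBoostingClassifier` stands in for XGBoost here so the sketch has no extra dependency:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score

# Balanced synthetic data standing in for the resampled stroke features.
X, y = make_classification(n_samples=400, n_features=8,
                           weights=[0.5, 0.5], random_state=0)

models = {
    "RandomForest": RandomForestClassifier(random_state=0),
    "SVM": SVC(random_state=0),
    # Stand-in for XGBoost; the notebook uses xgboost.XGBClassifier.
    "GradientBoosting": GradientBoostingClassifier(random_state=0),
}

# Rank candidates by mean cross-validated F1, as in the notebook.
scores = {name: cross_val_score(m, X, y, cv=5, scoring="f1").mean()
          for name, m in models.items()}
print(scores)
```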
### Hyperparameter Tuning

- Optuna: Ran 50 trials to optimize XGBoost hyperparameters (`n_estimators`, `max_depth`, `learning_rate`, `gamma`, and others) using 5-fold cross-validation with the F1 score as the objective.
### Evaluation
- F1 Score: Achieved ~90% F1 score on the balanced test set.
- Confusion Matrix: Demonstrated balanced sensitivity and specificity.
- Classification Report: Provided detailed precision, recall, and F1 for each class.
- Explainability: SHAP analysis identified the most influential features and provided local/global interpretability.
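The three metrics come straight from scikit-learn. A sketch on toy predictions (the labels below are invented; the notebook evaluates the tuned XGBoost on the held-out test set):

```python
from sklearn.metrics import f1_score, confusion_matrix, classification_report

# Toy labels/predictions standing in for the model's test-set output.
y_true = [0, 0, 0, 1, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 1, 0, 0, 1]

print(f1_score(y_true, y_pred))                 # single-number summary
print(confusion_matrix(y_true, y_pred))         # TN/FP/FN/TP counts
print(classification_report(y_true, y_pred))    # per-class precision/recall/F1
```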
## How to Install
Follow these steps to set up the project using a virtual environment:
```bash
# Clone or download the repository
git clone https://github.com/DeepActionPotential/StrokeLineAI
cd StrokeLineAI

# Create a virtual environment
python -m venv venv

# Activate the virtual environment
# On Windows:
venv\Scripts\activate
# On macOS/Linux:
source venv/bin/activate

# Upgrade pip
pip install --upgrade pip

# Install dependencies
pip install -r requirements.txt
```
## How to Use the Software

### Run the Web Application

Start the Streamlit app:

```bash
streamlit run app.py
```

## Demo

demo-video
## Technologies Used

### Data Science & Model Training
- matplotlib, seaborn: Data visualization.
- scikit-learn: Preprocessing, model selection, metrics, and pipelines.
- imbalanced-learn: Advanced resampling (SMOTE, RandomUnderSampler) for class balancing.
- XGBoost: High-performance gradient boosting for classification.
- Optuna: Automated hyperparameter optimization.
- SHAP: Model explainability and feature importance analysis.
### Deployment
- Streamlit: Rapid web app development for interactive model inference.
- joblib: Model serialization for deployment.
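The save/load round trip is a one-liner each way with joblib. A sketch with a small stand-in model (the notebook serializes the tuned XGBoost; the filename below is illustrative):

```python
import os
import tempfile

import joblib
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Train a small stand-in model on synthetic data.
X, y = make_classification(n_samples=100, n_features=5, random_state=0)
model = RandomForestClassifier(random_state=0).fit(X, y)

# Serialize, then reload the way app.py would at startup.
path = os.path.join(tempfile.mkdtemp(), "model.joblib")
joblib.dump(model, path)
loaded = joblib.load(path)

print(loaded.predict(X[:3]))  # reloaded model serves predictions
```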
## License
This project is licensed under the MIT License.
See the LICENSE file for details.
