---
title: StrokeLine - Stroke Prediction Using Machine Learning
emoji: 🤖
colorFrom: indigo
colorTo: blue
sdk: streamlit
sdk_version: 1.30.0
app_file: app.py
pinned: false
license: mit
---

# Stroke Prediction Using Machine Learning

## About the Project

This project provides a complete machine learning pipeline for predicting an individual's risk of stroke from clinical and demographic features. The goal is to identify high-risk patients early, helping healthcare professionals make informed decisions and potentially reducing stroke-related morbidity and mortality.

The project covers the full data science workflow: data exploration, preprocessing, feature engineering, model selection, hyperparameter optimization, evaluation, explainability, and deployment. The final solution includes a trained model and a Streamlit web application for real-time inference.

---

## About the Dataset

The dataset used is the [Stroke Prediction Dataset](https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset) from Kaggle. It contains 5110 records and 12 columns: an `id`, 10 clinical and demographic features, and a binary target variable (`stroke`).
The columns are:

- **id**: Unique identifier (not used for modeling)
- **gender**: Patient gender (`Male`, `Female`, `Other`)
- **age**: Age in years
- **hypertension**: Hypertension status (`0`: No, `1`: Yes)
- **heart_disease**: Heart disease status (`0`: No, `1`: Yes)
- **ever_married**: Marital status (`Yes`, `No`)
- **work_type**: Type of work (`children`, `Govt_job`, `Never_worked`, `Private`, `Self-employed`)
- **Residence_type**: Living area (`Urban`, `Rural`)
- **avg_glucose_level**: Average glucose level
- **bmi**: Body mass index (may contain missing values)
- **smoking_status**: Smoking behavior (`formerly smoked`, `never smoked`, `smokes`, `Unknown`)
- **stroke**: Target variable (`1`: Stroke occurred, `0`: No stroke)

The dataset is imbalanced, with far fewer positive stroke cases than negatives, and contains missing values in the `bmi` column.

---

## Notebook Summary

The notebook documents the entire process:

1. **Problem Definition**: Outlines the clinical motivation, dataset, and challenges.
2. **EDA**: Visualizes distributions, checks for missing values, and explores feature-target relationships.
3. **Feature Engineering**: Handles missing data, encodes categorical variables, and examines feature correlations.
4. **Data Balancing**: Uses RandomUnderSampler and SMOTE to address class imbalance.
5. **Model Selection**: Compares Random Forest, SVM, and XGBoost classifiers.
6. **Hyperparameter Tuning**: Uses Optuna to optimize XGBoost automatically.
7. **Evaluation**: Reports the F1 score, confusion matrix, and classification report.
8. **Explainability**: Applies SHAP for model interpretation.
9. **Model Export**: Saves the trained model for deployment.

---

## Model Results

### Preprocessing

- **Missing Values**: Imputed missing `bmi` values with the mean.
- **Categorical Encoding**: Used `OrdinalEncoder` to convert categorical features to numeric values.
- **Feature Selection**: Dropped the `id` column and checked for highly correlated features.
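The preprocessing steps above can be sketched as a single function. This is a minimal sketch assuming pandas and scikit-learn, not the notebook's exact code: the function name `preprocess`, the `CATEGORICAL_COLS` list, and returning the fitted encoder are illustrative choices, and the correlation check is omitted.

```python
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

# Categorical columns from the dataset description (assumed to match the notebook).
CATEGORICAL_COLS = ["gender", "ever_married", "work_type", "Residence_type", "smoking_status"]


def preprocess(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series, OrdinalEncoder]:
    """Hypothetical helper: drop `id`, mean-impute `bmi`, ordinal-encode categoricals."""
    # `id` is a row identifier with no predictive value; drop() returns a new frame.
    df = df.drop(columns=["id"])

    # Fill missing BMI values with the column mean (mean() ignores NaNs).
    df["bmi"] = df["bmi"].fillna(df["bmi"].mean())

    # OrdinalEncoder maps each category to an integer code; categories are
    # sorted per column, so e.g. gender "Female" comes before "Male".
    encoder = OrdinalEncoder()
    df[CATEGORICAL_COLS] = encoder.fit_transform(df[CATEGORICAL_COLS])

    # Split into the feature matrix and the binary target.
    X = df.drop(columns=["stroke"])
    y = df["stroke"]
    return X, y, encoder
```

With the Kaggle download, usage might look like `X, y, encoder = preprocess(pd.read_csv("healthcare-dataset-stroke-data.csv"))`, where the CSV filename is an assumption about the downloaded archive. The sketch fits the mean and the encoder on the whole frame, mirroring the description above; fitting them on the training split only would keep test-set statistics out of preprocessing.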
### Data Balancing

- **RandomUnderSampler**: Reduced the majority class to 10% of its original size.
- **SMOTE**: Oversampled the minority class to achieve a 1:1 ratio.

### Training

- **Train-Test Split**: Used a stratified split to preserve the class distribution.
- **Model Comparison**: Evaluated Random Forest, SVM, and XGBoost on the balanced data.
- **Best Model**: XGBoost achieved the highest F1 score.

### Hyperparameter Tuning

- **Optuna**: Ran 50 trials to optimize XGBoost hyperparameters (e.g., `n_estimators`, `max_depth`, `learning_rate`, `gamma`) using 5-fold cross-validation, with the F1 score as the objective.

### Evaluation

- **F1 Score**: Achieved an F1 score of ~90% on the balanced test set.
- **Confusion Matrix**: Showed balanced sensitivity and specificity.
- **Classification Report**: Provided per-class precision, recall, and F1 score.
- **Explainability**: SHAP analysis identified the most influential features and provided both local and global interpretability.

---

## How to Install

Follow these steps to set up the project using a virtual environment:

```bash
# Clone or download the repository
git clone https://github.com/DeepActionPotential/StrokeLineAI
cd StrokeLineAI

# Create a virtual environment
python -m venv venv

# Activate the virtual environment
# On Windows:
venv\Scripts\activate
# On macOS/Linux:
source venv/bin/activate

# Upgrade pip
pip install --upgrade pip

# Install dependencies
pip install -r requirements.txt
```

---

## How to Use the Software

1. **Run the Web Application**

   Start the Streamlit app:

   ```bash
   streamlit run app.py
   ```

2. **Demo**

   [demo-video](demo/strokeline_demo.mp4)

   ![demo-screenshot](demo/strokeline_demo.jpeg)

---

## Technologies Used

### Data Science & Model Training

- **matplotlib, seaborn**: Data visualization.
- **scikit-learn**: Preprocessing, model selection, metrics, and pipelines.
- **imbalanced-learn**: Advanced resampling (SMOTE, RandomUnderSampler) for class balancing.
- **XGBoost**: High-performance gradient boosting for classification.
- **Optuna**: Automated hyperparameter optimization.
- **SHAP**: Model explainability and feature importance analysis.

### Deployment

- **Streamlit**: Rapid web app development for interactive model inference.
- **joblib**: Model serialization for deployment.

---

## License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.