Martinacap02 committed on
Commit a7ce724 · 1 Parent(s): b296845

Initial HF Space Docker deployment

Files changed (42)
  1. .dockerignore +2 -0
  2. Dockerfile +34 -0
  3. README.md +235 -7
  4. data/README.md +176 -0
  5. data/interim/preprocess_artifacts/scaler.joblib +3 -0
  6. models/README.md +110 -0
  7. models/nosex/random_forest.joblib +3 -0
  8. predicting_outcomes_in_heart_failure/__init__.py +1 -0
  9. predicting_outcomes_in_heart_failure/__pycache__/__init__.cpython-311.pyc +0 -0
  10. predicting_outcomes_in_heart_failure/__pycache__/config.cpython-311.pyc +0 -0
  11. predicting_outcomes_in_heart_failure/app/__init__.py +0 -0
  12. predicting_outcomes_in_heart_failure/app/__pycache__/__init__.cpython-311.pyc +0 -0
  13. predicting_outcomes_in_heart_failure/app/__pycache__/main.cpython-311.pyc +0 -0
  14. predicting_outcomes_in_heart_failure/app/__pycache__/schema.cpython-311.pyc +0 -0
  15. predicting_outcomes_in_heart_failure/app/__pycache__/utils.cpython-311.pyc +0 -0
  16. predicting_outcomes_in_heart_failure/app/__pycache__/wrapper.cpython-311.pyc +0 -0
  17. predicting_outcomes_in_heart_failure/app/main.py +151 -0
  18. predicting_outcomes_in_heart_failure/app/routers/__pycache__/cards.cpython-311.pyc +0 -0
  19. predicting_outcomes_in_heart_failure/app/routers/__pycache__/general.cpython-311.pyc +0 -0
  20. predicting_outcomes_in_heart_failure/app/routers/__pycache__/model_info.cpython-311.pyc +0 -0
  21. predicting_outcomes_in_heart_failure/app/routers/__pycache__/prediction.cpython-311.pyc +0 -0
  22. predicting_outcomes_in_heart_failure/app/routers/cards.py +52 -0
  23. predicting_outcomes_in_heart_failure/app/routers/general.py +19 -0
  24. predicting_outcomes_in_heart_failure/app/routers/model_info.py +95 -0
  25. predicting_outcomes_in_heart_failure/app/routers/prediction.py +135 -0
  26. predicting_outcomes_in_heart_failure/app/schema.py +28 -0
  27. predicting_outcomes_in_heart_failure/app/utils.py +41 -0
  28. predicting_outcomes_in_heart_failure/app/wrapper.py +135 -0
  29. predicting_outcomes_in_heart_failure/config.py +129 -0
  30. predicting_outcomes_in_heart_failure/data/dataset.py +22 -0
  31. predicting_outcomes_in_heart_failure/data/preprocess.py +116 -0
  32. predicting_outcomes_in_heart_failure/data/split_data.py +114 -0
  33. predicting_outcomes_in_heart_failure/modeling/__init__.py +0 -0
  34. predicting_outcomes_in_heart_failure/modeling/__pycache__/__init__.cpython-311.pyc +0 -0
  35. predicting_outcomes_in_heart_failure/modeling/__pycache__/explainability.cpython-311.pyc +0 -0
  36. predicting_outcomes_in_heart_failure/modeling/__pycache__/predict.cpython-311.pyc +0 -0
  37. predicting_outcomes_in_heart_failure/modeling/evaluate.py +182 -0
  38. predicting_outcomes_in_heart_failure/modeling/explainability.py +202 -0
  39. predicting_outcomes_in_heart_failure/modeling/predict.py +135 -0
  40. predicting_outcomes_in_heart_failure/modeling/train.py +261 -0
  41. reports/figures/.gitkeep +0 -0
  42. reports/nosex/random_forest/cv_parameters.json +40 -0
.dockerignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ venv/
Dockerfile ADDED
@@ -0,0 +1,34 @@
+ FROM python:3.11.9-slim-bookworm
+
+ ENV PYTHONUNBUFFERED=1
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PIP_NO_CACHE_DIR=1
+ ENV PIP_DISABLE_PIP_VERSION_CHECK=1
+
+ # Install system packages as root before dropping to the non-root user
+ # (apt-get would fail once USER is switched).
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     curl ca-certificates \
+     && rm -rf /var/lib/apt/lists/*
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /cardioTrack
+
+ RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+ #ENV PATH="/root/.local/bin:$PATH"
+
+ COPY --chown=user pyproject.toml uv.lock ./
+
+ RUN uv sync --locked --no-install-project
+
+ COPY --chown=user predicting_outcomes_in_heart_failure ./predicting_outcomes_in_heart_failure
+ COPY --chown=user models/nosex/random_forest.joblib ./models/nosex/random_forest.joblib
+ COPY --chown=user reports/nosex/random_forest/cv_parameters.json ./reports/nosex/random_forest/cv_parameters.json
+ COPY --chown=user data/interim/preprocess_artifacts/scaler.joblib ./data/interim/preprocess_artifacts/scaler.joblib
+ COPY --chown=user metrics/test/nosex/random_forest.json ./metrics/test/nosex/random_forest.json
+ COPY --chown=user README.md ./README.md
+
+ EXPOSE 7860
+
+ CMD ["uv", "run", "uvicorn", "predicting_outcomes_in_heart_failure.app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,239 @@
  ---
- title: CardioTrack
- emoji: 😻
- colorFrom: pink
- colorTo: yellow
  sdk: docker
- pinned: false
- short_description: Predicting Outcomes in Heart Failure
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: CardioTrack API
+ emoji: ❤️
+ colorFrom: purple
+ colorTo: gray
  sdk: docker
+ app_port: 7860
  ---
+ # Predicting Outcomes in Heart Failure
+
+ ## Table of Contents
+ 1. [Project Overview](#project-overview)
+ 2. [Project Organization](#project-organization)
+ 3. [DVC Pipeline Defined](#dvc-pipeline-defined)
+ 4. [Milestones Summary](#milestones-summary)
+    - [Milestone 1 - Inception](#milestone-1---inception)
+    - [Milestone 2 - Reproducibility](#milestone-2---reproducibility)
+    - [Milestone 3 - Quality Assurance](#milestone-3---quality-assurance)
+    - [Milestone 4 - API Integration](#milestone-4---api-integration)
+
+ ## Project Overview
+ <a target="_blank" href="https://cookiecutter-data-science.drivendata.org/">
+     <img src="https://img.shields.io/badge/CCDS-Project%20template-328F97?logo=cookiecutter" />
+ </a>
+
+ This project develops a predictive pipeline for patient outcome prediction in heart failure, using a publicly available dataset of clinical records. The goal is to design and evaluate machine learning models within a reproducible workflow that can be integrated into larger systems for clinical decision support. The workflow addresses data heterogeneity, defines consistent preprocessing and feature engineering strategies, and explores alternative modeling approaches with systematic evaluation using clinically relevant metrics. It also emphasizes model transparency and auditability, ensuring that the resulting pipeline can be deployed as a reliable, adaptable software component in healthcare applications. The project aims not only to improve baseline predictive performance but also to demonstrate how data-driven models can be effectively integrated into end-to-end AI-enabled healthcare systems.
+
+ ## Project Organization
+
+ ```
+ ├── LICENSE            <- Open-source license if one is chosen
+ ├── Makefile           <- Makefile with convenience commands like `make data` or `make train`
+ ├── README.md          <- The top-level README for developers using this project.
+ ├── data
+ │   ├── external       <- Data from third party sources.
+ │   ├── interim        <- Intermediate data that has been transformed.
+ │   ├── processed      <- The final, canonical data sets for modeling.
+ │   └── raw            <- The original, immutable data dump.
+ │
+ ├── docs               <- A default mkdocs project; see www.mkdocs.org for details
+ │
+ ├── models             <- Trained and serialized models, model predictions, or model summaries
+ │
+ ├── notebooks          <- Jupyter notebooks. Naming convention is a number (for ordering),
+ │                         the creator's initials, and a short `-` delimited description, e.g.
+ │                         `1.0-jqp-initial-data-exploration`.
+ │
+ ├── pyproject.toml     <- Project configuration file with package metadata for
+ │                         predicting_outcomes_in_heart_failure and configuration for tools like black
+ │
+ ├── references         <- Data dictionaries, manuals, and all other explanatory materials.
+ │
+ ├── reports            <- Generated analysis as HTML, PDF, LaTeX, etc.
+ │   └── figures        <- Generated graphics and figures to be used in reporting
+ │
+ ├── requirements.txt   <- The requirements file for reproducing the analysis environment, e.g.
+ │                         generated with `pip freeze > requirements.txt`
+ │
+ ├── setup.cfg          <- Configuration file for flake8
+ │
+ └── predicting_outcomes_in_heart_failure   <- Source code for use in this project.
+     │
+     ├── __init__.py           <- Makes predicting_outcomes_in_heart_failure a Python module
+     │
+     ├── config.py             <- Store useful variables and configuration
+     │
+     ├── data
+     │   ├── __init__.py
+     │   ├── dataset.py        <- Scripts to download or generate data
+     │   ├── preprocess.py     <- Data preprocessing code
+     │   └── split_data.py     <- Split dataset into train and test code
+     │
+     ├── features.py           <- Code to create features for modeling
+     │
+     ├── modeling
+     │   ├── __init__.py
+     │   ├── predict.py        <- Code to run model inference with trained models
+     │   └── train.py          <- Code to train models
+     │
+     └── plots.py              <- Code to create visualizations
+ ```
+
+ ## DVC Pipeline Defined
+ ```
+ +---------------+
+ | download_data |
+ +---------------+
+         *
+         *
+         *
+ +---------------+
+ | preprocessing |
+ +---------------+
+         *
+         *
+         *
+   +------------+
+   | split_data |
+   +------------+
+     ***      ***
+    *            *
+  **              ***
+ +----------+        *
+ | training |       ***
+ +----------+         *
+     ***           ***
+        *         *
+         **     **
+      +------------+
+      | evaluation |
+      +------------+
+ ```
+
+ ## Milestones Summary
+
+ ### Milestone 1 - Inception
+ During this milestone, the **CCDS Project Template** was used as the foundation for organizing the project.
+ The main conceptual and structural components of the system were defined, following the template guidelines to ensure consistency and traceability.
+
+ Additionally, a **Machine Learning Canvas** has been added in the [`docs/`](./docs) folder.
+ It outlines the model objectives, the data to be used, and the key methodological aspects planned for the next phases of the project.
+
+ ### Milestone 2 - Reproducibility
+ Milestone 2 introduces **reproducibility**, from **data management** to **model training and evaluation**. This includes a fully automated pipeline, experiment tracking, and model registry integration, ensuring every step can be consistently reproduced and monitored.
+
+ #### Exploratory Data Analysis (EDA)
+ As part of the early steps, we added and refined an **Exploratory Data Analysis** to better understand the dataset, its distribution, and the relationships between variables. This helped define the preprocessing and modeling strategies used later.
+
+ #### DVC Initialization and Pipeline Setup
+ We initialized **DVC** and configured a full pipeline to automate the main steps of the ML workflow:
+ - Automatic data **download**
+ - **Preprocessing**
+ - **Data splitting**
+ - **Training** and **evaluation**
+
+ The pipeline is fully reproducible and version-controlled through DVC.
+
+ #### Model Training and Experiment Tracking
+ We implemented the **training scripts** and integrated **MLflow** for experiment tracking.
+ Three models are trained and evaluated within this workflow:
+ - Decision Tree
+ - Random Forest
+ - Logistic Regression
+
+ Each experiment is logged to MLflow.
+
+ #### Model Registry and Thresholds
+ Models that reach or exceed the predefined **performance thresholds** (as defined in the ML Canvas) are automatically **saved to the model registry**.
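+
+ A minimal sketch of what such a gate can look like with the MLflow API (illustrative only; the metric name, threshold value, and registry name below are assumptions, not the project's actual settings):
+
+ ```python
+ import mlflow
+ import mlflow.sklearn
+
+ F1_THRESHOLD = 0.85  # hypothetical threshold from the ML Canvas
+
+ with mlflow.start_run():
+     mlflow.log_metric("f1", f1)  # f1 assumed computed during validation
+     if f1 >= F1_THRESHOLD:
+         # Register the model only when it clears the threshold.
+         mlflow.sklearn.log_model(
+             model,
+             artifact_path="model",
+             registered_model_name="heart-failure-random-forest",  # hypothetical name
+         )
+ ```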
+
+ ### Milestone 3 - Quality Assurance
+
+ In this milestone, we introduced a **Quality Assurance** layer to the system.
+
+ #### Static Linters
+ Two static linters were added to improve code style and consistency:
+
+ - **Ruff** for Python files in the `predicting_outcomes_in_heart_failure` and `tests` folders.
+   It checks formatting, syntax, and common anti-patterns, and is integrated into the GitHub workflow via an *action*.
+ - **Pynblint** for Jupyter notebooks, also integrated into the GitHub workflow through a dedicated *action*.
+
+ #### Data Quality
+ We implemented **data quality checks** on both raw and processed data using **Great Expectations**.
+ These validations help to:
+
+ - detect anomalies or invalid values at the data source
+ - prevent the propagation of data issues into downstream processes
+
+ #### Code Quality
+ We added automated **unit and integration tests** using **pytest**, covering the main modules and functionalities of the system.
+
+ #### ML Pipeline Enhancements
+ We applied the following enhancements to the ML pipeline:
+
+ - Refactored preprocessing with gender-based dataset variants.
+ - Added validation (e.g., error on single-row datasets).
+ - Saved the StandardScaler as a preprocessing artifact.
+ - Updated the split logic and the DVC pipeline.
+ - Training now creates variant-specific MLflow experiments.
+ - Added RandomOverSampler to address class imbalance.
+ - Updated evaluation and inference to align with the new structure.
+
+ #### Explainability
+ We added an explainability module (sketched below):
+
+ - Added a SHAP explainability module.
+ - Added tests for the explainability functionality.
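+
+ As a rough sketch of what the module computes for a single prediction (illustrative; the actual logic lives in `modeling/explainability.py`):
+
+ ```python
+ import shap
+
+ # `model` is the trained random forest; `X` a preprocessed single-row DataFrame.
+ explainer = shap.TreeExplainer(model)
+ shap_values = explainer(X)
+
+ # For a binary classifier the explanation carries one slice per class;
+ # plot the feature contributions toward the positive (heart disease) class.
+ shap.plots.waterfall(shap_values[0, :, 1])
+ ```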
+
+ #### Risk Classification
+ We added a **Risk Classification** analysis for the system in accordance with **IMDRF** and **AI Act** regulations.
+ The documentation is available in the [`docs/`](./docs) folder.
+
+ ### Milestone 4 - API Integration
+
+ During Milestone 4, we implemented a fully functional API, together with a Dataset Card and a Model Card for the champion model and the dataset it uses.
+ The API is structured into four main routers:
+
+ #### **General Router**
+ - **GET /**
+   Returns a welcome message and confirms that the API is running.
+
+ #### **Prediction Router**
+ - **POST /predictions**
+   Generates a binary prediction (0/1) for a single patient sample.
+
+ - **POST /batch-predictions**
+   Accepts a list of patient samples and returns a prediction for each element in the batch.
+
+ - **POST /explanations**
+   Produces SHAP-based explanations for a single input and returns the URL of the generated SHAP waterfall plot.
+
+ #### **Model Info Router**
+ - **GET /model/hyperparameters**
+   Returns the hyperparameters and cross-validation results of the model defined in `MODEL_PATH`.
+
+ - **GET /model/metrics**
+   Returns the test-set metrics stored during the model evaluation stage.
+
+ #### **Cards Router**
+ - **GET /cards/{card_type}**
+   Returns the content of a “card” file (dataset card or model card).
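+
+ As a quick illustration, the single-prediction endpoint can be called as follows (a minimal sketch: the base URL assumes a local run on the Space port, and the field values are made up):
+
+ ```python
+ import httpx
+
+ API_URL = "http://localhost:7860"  # hypothetical local deployment
+
+ sample = {
+     "Age": 54, "ChestPainType": "ASY", "RestingBP": 140, "Cholesterol": 239,
+     "FastingBS": 0, "RestingECG": "Normal", "MaxHR": 160,
+     "ExerciseAngina": "N", "Oldpeak": 1.2, "ST_Slope": "Flat",
+ }
+
+ response = httpx.post(f"{API_URL}/predictions", json=sample)
+ response.raise_for_status()
+ print(response.json()["data"]["prediction"])  # 1 = heart disease, 0 = no disease
+ ```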
+
+ ### **Cards**
+
+ During this milestone, we also created:
+
+ - a **dataset card** describing the dataset used by the champion model
+ - a **model card** documenting the champion model itself
data/README.md ADDED
@@ -0,0 +1,176 @@
+ # Dataset Card
+
+ ## Table of Contents
+ - [Dataset Description](#dataset-description)
+   - [Dataset Summary](#dataset-summary)
+   - [Supported Tasks](#supported-tasks)
+   - [Languages](#languages)
+ - [Dataset Structure](#dataset-structure)
+   - [Data Instances](#data-instances)
+   - [Data Fields](#data-fields)
+ - [Dataset Creation](#dataset-creation)
+   - [Source Data](#source-data)
+   - [Annotations](#annotations)
+   - [Personal and Sensitive Information](#personal-and-sensitive-information)
+ - [Considerations for Using the Data](#considerations-for-using-the-data)
+   - [Social Impact of Dataset](#social-impact-of-dataset)
+   - [Discussion of Biases](#discussion-of-biases)
+ - [Additional Information](#additional-information)
+   - [Dataset Curators](#dataset-curators)
+   - [Citation Information](#citation-information)
+
+ ## Dataset Description
+
+ - **Homepage:** https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction
+
+ ### Dataset Summary
+
+ This dataset contains anonymized clinical data used to predict the risk of heart failure.
+ It includes **918 patient records**, **11 clinical features**, and **one target variable**.
+ The original dataset was downloaded from Kaggle and was created by merging five well-known cardiology datasets.
+
+ The version used in this project underwent additional preprocessing steps, including standardization, normalization, categorical encoding, and removal of the Sex feature. The resulting dataset is used for experimentation and model development.
+
+ ### Supported Tasks
+
+ This dataset can be used for a variety of machine learning tasks, including:
+
+ - **Binary Classification**
+   Predicting whether a patient has heart disease.
+ - **Risk Scoring / Clinical Risk Stratification**
+   Estimating cardiac risk based on clinical variables.
+ - **Explainable AI (XAI)**
+   Useful for feature-importance analysis and interpretability.
+
+ ### Languages
+
+ English **(en)**
+
+ ## Dataset Structure
+
+ ### Data Instances
+
+ Each instance represents one patient. Example:
+
+ | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease |
+ |-----|-----|---------------|-----------|-------------|-----------|------------|-------|----------------|---------|----------|--------------|
+ | 54  | M   | ASY           | 140       | 239         | 0         | Normal     | 160   | N              | 1.2     | Flat     | 1            |
+
+ ### Data Fields
+
+ | Field          | Type     | Description                                        |
+ |----------------|----------|----------------------------------------------------|
+ | Age            | int      | Patient age in years                               |
+ | Sex            | binary   | Patient sex (M = male, F = female)                 |
+ | ChestPainType  | category | Chest pain type (TA, ATA, NAP, ASY)                |
+ | RestingBP      | int      | Resting blood pressure (mm Hg)                     |
+ | Cholesterol    | int      | Serum cholesterol (mg/dL)                          |
+ | FastingBS      | binary   | Fasting blood sugar (1 if >120 mg/dL, 0 otherwise) |
+ | RestingECG     | category | Resting ECG results (Normal, ST, LVH)              |
+ | MaxHR          | int      | Maximum heart rate achieved                        |
+ | ExerciseAngina | binary   | Exercise-induced angina (Y/N)                      |
+ | Oldpeak        | float    | ST depression relative to rest                     |
+ | ST_Slope       | category | Slope of the ST segment (Up, Flat, Down)           |
+ | HeartDisease   | binary   | Target variable (1 = disease, 0 = no disease)      |
+
88
+
89
+
90
+ ## Dataset Creation
91
+
92
+ ### Source Data
93
+
94
+ The preprocessed dataset used in this project originates from the Kaggle dataset *“Heart Failure Prediction Dataset”*.
95
+
96
+ The raw dataset was created by merging five widely-used cardiology datasets:
97
+
98
+ - Cleveland (303 samples)
99
+ - Hungarian (294 samples)
100
+ - Switzerland (123 samples)
101
+ - Long Beach VA (200 samples)
102
+ - Stalog (270 samples)
103
+
104
+ The Kaggle author selected the 11 common features and merged the datasets into a unified collection of **1,190 records**, then removed **272 duplicates**, resulting in **918 unique samples**.
105
+
106
+ All initial merging and normalization steps were performed by the dataset author on Kaggle.
107
+
108
+
109
+
110
+ ### Annotations
111
+
112
+ No manual annotations were added.
113
+ The target variable `HeartDisease` is already included in the original dataset.
114
+
115
+
116
+
117
+ ### Personal and Sensitive Information
118
+
119
+ Although the dataset contains clinical information (sensitive under GDPR), it is fully anonymized:
120
+
121
+ - No personal identifiers (name, address, contact details, IDs).
122
+ - All sources were already anonymized before publication.
123
+ - No biometric or genetic data are included.
124
+
125
+ Thus, while clinically sensitive, the dataset does **not** pose identifiable privacy risks.
126
+
127
+
128
+
129
+ ## Considerations for Using the Data
130
+
131
+ ### Social Impact of Dataset
132
+
133
+ The dataset can support research and development of models for cardiac risk prediction and early detection.
134
+
135
+ However:
136
+
137
+ - Models trained on this dataset **must not be used as standalone diagnostic tools**.
138
+ - They should **not** be the sole basis for clinical decisions.
139
+ - Misuse in healthcare contexts may lead to incorrect risk assessment.
140
+
141
+
142
+
143
+ ### Discussion of Biases
144
+
145
+
146
+ This dataset may contain several sources of bias that can affect model performance and fairness:
147
+
148
+ - The data comes from multiple hospitals and countries, each with different patient profiles and clinical protocols. Some groups may be underrepresented.
149
+ - Source datasets used different diagnostic practices and measurement standards, which may introduce noise or inconsistency in labels and clinical values.
150
+ - Only 11 features are included, omitting other relevant clinical variables. This can cause proxy bias or oversimplification of cardiac risk.
151
+ - Some datasets are older and may not reflect current medical practices or population characteristics.
152
+
153
+
154
+
155
+ ## Additional Information
156
+
157
+ ### Dataset Curators
158
+
159
+ The original dataset was created and published by **[fedesoriano](https://www.kaggle.com/fedesoriano)** on Kaggle.
160
+
161
+ The preprocessed dataset was curated by the **CardioTrack** team:
162
+
163
+ - [Fabrizio Rosmarino](https://github.com/Fabrizio250)
164
+ - [Martina Capone](https://github.com/Martycap)
165
+ - [Donato Boccuzzi](https://github.com/donatooooooo)
166
+
167
+ Work carried out as part of the *Software Engineering for AI-Enabled Systems* program at the University of Bari.
168
+
169
+ ### Citation Information
170
+
171
+ If you use this datasets, please cite:
172
+
173
+ **Original Dataset**
174
+ Soriano, F. (2021). *Heart Failure Prediction Dataset*. Kaggle.
175
+ https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction
176
+
data/interim/preprocess_artifacts/scaler.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97f621e75c15a9059080f44de0985cb0cf22c889f09bc928e1eed2646168c9d1
+ size 1023
models/README.md ADDED
@@ -0,0 +1,110 @@
+ # Model Card
+
+ ## Table of Contents
+
+ - [Model Details](#model-details)
+   - [Training Information](#training-information)
+ - [Intended Use](#intended-use)
+   - [Primary Intended Uses](#primary-intended-uses)
+   - [Primary Intended Users](#primary-intended-users)
+   - [Out-of-scope Use Cases](#out-of-scope-use-cases)
+ - [Factors](#factors)
+   - [Relevant Factors](#relevant-factors)
+   - [Evaluation Factors](#evaluation-factors)
+ - [Metrics](#metrics)
+   - [Model Performance](#model-performance)
+   - [Variation Approaches](#variation-approaches)
+ - [Evaluation Data](#evaluation-data)
+   - [Datasets](#datasets)
+   - [Motivation](#motivation)
+   - [Preprocessing](#preprocessing)
+ - [Training Data](#training-data)
+   - [Datasets](#datasets-1)
+   - [Preprocessing](#preprocessing-1)
+ - [Ethical Considerations](#ethical-considerations)
+ - [Caveats and Recommendations](#caveats-and-recommendations)
+
+ ## Model Details
+ - Developed by: D. Boccuzzi, M. Capone, F. Rosmarino
+ - Model Date: November 11th, 2025
+ - Model Version: 6 - nosex
+ - Model Type: RandomForestClassifier
+
+ ### Training information
+ - Best hyperparameters tuned with a 5-fold cross-validation:
+   - `max_depth` 12
+   - `n_estimators` 800
+   - `class_weight` balanced
+ - Applied approaches (sketched below):
+   During training, an oversampling technique was applied to balance the dataset and reduce bias toward the majority class. This ensured that the model learned equally from positive and negative cases, improving prediction performance for the minority class.
+ - Training started at: 11:26:59 2025-11-12
+ - Training ended at: 11:34:29 2025-11-12
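+
+ For illustration, this configuration corresponds to roughly the following sketch (variable names are assumed; the actual pipeline lives in `modeling/train.py`):
+
+ ```python
+ from imblearn.over_sampling import RandomOverSampler
+ from sklearn.ensemble import RandomForestClassifier
+
+ # Balance the classes before fitting, as described above.
+ ros = RandomOverSampler(random_state=42)
+ X_bal, y_bal = ros.fit_resample(X_train, y_train)  # X_train / y_train assumed given
+
+ # Champion configuration selected by the 5-fold cross-validation.
+ model = RandomForestClassifier(
+     n_estimators=800,
+     max_depth=12,
+     class_weight="balanced",
+     random_state=42,
+ )
+ model.fit(X_bal, y_bal)
+ ```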
+
+ ## Intended Use
+ ### Primary intended uses
+ The CardioTrack ML system is designed to support early detection of heart failure by analyzing clinical features and identifying patients who may be at risk. Its purpose is to assist cardiologists in deciding when further diagnostic tests, monitoring, or preventive treatments are needed. The system is also intended for local public health authorities, who can use aggregated predictions to plan healthcare resources and implement prevention strategies within the population.
+ ### Primary intended users
+ The primary users of the model are cardiologists and other qualified medical professionals who rely on clinical decision support tools. They are responsible for interpreting the model’s predictions in conjunction with the patient’s medical history and additional clinical information. Public health authorities may also use aggregated, non-individual results to support long-term planning and policy development.
+ ### Out-of-scope use cases
+ The model should not be used without access to complete and reliable clinical features, and it is not suitable for real-time emergency triage or for predictive tasks not directly related to heart failure.
+
+ ## Factors
+ ### Relevant factors
+ Model performance may vary depending on patient characteristics that influence heart disease risk, as reflected in the contributions of individual clinical features. Age remains a relevant factor because it strongly correlates with cardiovascular conditions. In addition, features such as ST_Slope, ChestPainType, MaxHR, and ExerciseAngina have the largest impact on individual predictions, as highlighted by the SHAP module for XAI. These features capture meaningful physiological and clinical differences among patients and explain why the model predicts higher or lower risk for specific individuals. Instrumentation and environmental factors are not relevant because the model operates on structured clinical data rather than on signals or images affected by measurement devices or environmental conditions.
+ ### Evaluation Factors
+ The evaluation focuses on the key clinical features that the model relies on most heavily. The relevant factors were chosen because they are both present in the dataset and have the largest impact on the model’s outputs, allowing clear interpretation of how predictions are made.
+
+ ## Metrics
+ ### Model Performance
+ On the held-out test set, the champion model achieves (a reproduction sketch follows the list):
+ - `F1 Score` 0.8990
+ - `Recall` 0.9019
+ - `Accuracy` 0.8876
+ - `ROC-AUC` 0.9399
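+
+ These scores can be recomputed from the held-out predictions with scikit-learn (a sketch; `model`, `X_test`, and `y_test` are assumed to come from the evaluation stage):
+
+ ```python
+ from sklearn.metrics import accuracy_score, f1_score, recall_score, roc_auc_score
+
+ y_pred = model.predict(X_test)
+ y_proba = model.predict_proba(X_test)[:, 1]  # probability of the positive class
+
+ print(f"F1:       {f1_score(y_test, y_pred):.4f}")
+ print(f"Recall:   {recall_score(y_test, y_pred):.4f}")
+ print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
+ print(f"ROC-AUC:  {roc_auc_score(y_test, y_proba):.4f}")
+ ```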
+
+ ### Variation approaches
+ The reported metrics were computed using the best model selected during cross-validation for hyperparameter tuning, and evaluated on a completely independent test set. This setup was chosen because it provides a cleaner estimate of real-world performance, reduces the risk of overfitting to validation folds, and ensures that the results reflect the model’s generalization ability.
+
+ ## Evaluation Data
+ ### Datasets
+ The evaluation was performed using 276 of the 918 observations (30%) in Kaggle's [Heart Failure Prediction Dataset](https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction), which contains clinical data from both healthy individuals and patients diagnosed with heart failure.
+ ### Motivation
+ This dataset was chosen because it provides a comprehensive set of relevant clinical features that capture key cardiovascular risk factors, enabling the model to perform early detection of heart failure in individual patients. Its publicly available nature ensures transparency.
+ ### Preprocessing
+ Before evaluation, the data was preprocessed as follows (see the sketch after this list):
+
+ - **Cleaning of invalid values**
+   Rows with impossible clinical values (e.g., `RestingBP = 0`) were removed.
+   Zero cholesterol values were treated as missing and replaced using a central-tendency statistic.
+
+ - **Encoding of categorical variables**
+   Binary categories were converted to numerical format, while multi-class fields (`ChestPainType`, `RestingECG`, `ST_Slope`) were one-hot encoded.
+
+ - **Scaling of numerical features**
+   Continuous variables were standardized to have mean 0 and unit variance.
+
+ - **Removal of the `Sex` feature**
+   The Sex feature was removed to reduce potential fairness concerns and because it was not required for the planned experiments.
+
+ The processed dataset is versioned on DagsHub at the following [link](https://dagshub.com/se4ai2526-uniba/CardioTrack).
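+
+ A condensed sketch of these steps (illustrative; the real implementation is in `data/preprocess.py`, the central-tendency fill is shown here as a median, and the fitted scaler is saved as `scaler.joblib`):
+
+ ```python
+ import pandas as pd
+ from sklearn.preprocessing import StandardScaler
+
+ df = df.drop(columns=["Sex"])   # remove the Sex feature
+ df = df[df["RestingBP"] > 0]    # drop rows with impossible clinical values
+
+ # Zero cholesterol treated as missing (median assumed as the fill statistic).
+ median_chol = df.loc[df["Cholesterol"] > 0, "Cholesterol"].median()
+ df.loc[df["Cholesterol"] == 0, "Cholesterol"] = median_chol
+
+ # Binary and one-hot encoding of categorical fields.
+ df["ExerciseAngina"] = (df["ExerciseAngina"] == "Y").astype(int)
+ df = pd.get_dummies(df, columns=["ChestPainType", "RestingECG", "ST_Slope"])
+
+ # Standardize continuous variables to mean 0 and unit variance.
+ num_cols = ["Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak"]
+ scaler = StandardScaler()
+ df[num_cols] = scaler.fit_transform(df[num_cols])
+ ```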
+
+ ## Training Data
+ ### Datasets
+ The training data comes from the same Kaggle Heart Failure Prediction Dataset, using the remaining 642 of the 918 observations (70%).
+ ### Preprocessing
+ The training data underwent the same preprocessing steps as the evaluation data. Additionally, the RandomOverSampler technique was applied to balance the classes, ensuring that the model learned equally from positive (heart failure) and negative cases.
+
+ ## Ethical Considerations
+ The CardioTrack ML system is intended to support clinical decision-making, not to replace professional judgment. Ethical considerations include:
+ - **Privacy and security**: All patient data is processed on-premises, in accordance with hospital IT protocols, protecting sensitive health information.
+ - **Transparency**: Feature importance with SHAP visualizations allows clinicians to interpret predictions.
+ - **Clinical responsibility**: Diagnosis must be combined with patient history, exams, and expert judgment. Misuse in isolation could lead to incorrect interventions.
+
+ ## Caveats and Recommendations
+ - **Inference time**: The model’s inference time is about 0.2 seconds, but it can vary with the computing power of the machine where inference runs.
+ - **Limitations**: The model is trained on a specific public dataset and may not capture rare cardiovascular conditions or population-specific variations.
+ - **Data quality**: Accurate predictions depend on complete and correctly measured clinical features. Erroneous data can reduce performance. Missing data is not allowed.
+ - **Not for emergency triage**: Predictions are intended for early detection and planning, not for immediate emergency decision-making.
+ - **Periodic retraining**: To maintain accuracy, the model should be updated with newly collected clinical data to account for shifts in patient population or disease prevalence.
models/nosex/random_forest.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f6b1b4f1a0485bfc97ad4ddde264cd17279e08becafd4510907accdd14473435
+ size 13471369
predicting_outcomes_in_heart_failure/__init__.py ADDED
@@ -0,0 +1 @@
+ from predicting_outcomes_in_heart_failure import config  # noqa: F401
predicting_outcomes_in_heart_failure/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (293 Bytes)
predicting_outcomes_in_heart_failure/__pycache__/config.cpython-311.pyc ADDED
Binary file (3.59 kB)
predicting_outcomes_in_heart_failure/app/__init__.py ADDED
File without changes
predicting_outcomes_in_heart_failure/app/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (216 Bytes)
predicting_outcomes_in_heart_failure/app/__pycache__/main.cpython-311.pyc ADDED
Binary file (8.76 kB)
predicting_outcomes_in_heart_failure/app/__pycache__/schema.cpython-311.pyc ADDED
Binary file (1.88 kB)
predicting_outcomes_in_heart_failure/app/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.11 kB)
predicting_outcomes_in_heart_failure/app/__pycache__/wrapper.cpython-311.pyc ADDED
Binary file (8.61 kB)
predicting_outcomes_in_heart_failure/app/main.py ADDED
@@ -0,0 +1,151 @@
+ from contextlib import asynccontextmanager
+
+ from fastapi import FastAPI
+ from fastapi.staticfiles import StaticFiles
+ import gradio as gr
+ import joblib
+ from loguru import logger
+
+ from predicting_outcomes_in_heart_failure.app.routers import cards, general, model_info, prediction
+ from predicting_outcomes_in_heart_failure.app.utils import load_page
+ from predicting_outcomes_in_heart_failure.app.wrapper import Wrapper
+ from predicting_outcomes_in_heart_failure.config import FIGURES_DIR, MODEL_PATH
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Context manager to handle application lifespan events."""
+     if not MODEL_PATH.exists():
+         logger.error(f"Model file not found at: {MODEL_PATH}")
+         raise FileNotFoundError(f"Model file not found at: {MODEL_PATH}")
+
+     logger.info(f"Loading default model from {MODEL_PATH} ...")
+     app.state.model = joblib.load(MODEL_PATH)
+     logger.success(f"Default model loaded from {MODEL_PATH}")
+
+     try:
+         yield
+     finally:
+         app.state.model = None
+         logger.info("Default model cleared on application shutdown")
+
+
+ app = FastAPI(
+     title="CardioTrack's Model Space - Heart Failure Prediction",
+     version="0.01",
+     lifespan=lifespan,
+ )
+
+ app.mount("/figures", StaticFiles(directory=str(FIGURES_DIR)), name="figures")
+ app.include_router(general.router)
+ app.include_router(prediction.router)
+ app.include_router(model_info.router)
+ app.include_router(cards.router)
+
+
+ # UI Definition
+ with gr.Blocks(title="CardioTrack") as io:
+     gr.Markdown(
+         """
+         # 🫀 CardioTrack's Model Space - Heart Failure Prediction
+         Choose an area to access the platform's features.
+         """
+     )
+
+     with gr.Tabs():
+         with gr.TabItem("Single Prediction"):
+             gr.Markdown("### Enter patient data for prediction")
+
+             with gr.Row():
+                 with gr.Column():
+                     age = gr.Slider(minimum=20, maximum=100, step=1, label="Age", value=60)
+                     resting_bp = gr.Slider(
+                         minimum=80,
+                         maximum=200,
+                         step=1,
+                         label="Resting Blood Pressure (mm Hg)",
+                         value=120,
+                     )
+                     cholesterol = gr.Slider(
+                         minimum=0, maximum=600, step=1, label="Cholesterol (mg/dL)", value=200
+                     )
+                     max_hr = gr.Slider(
+                         minimum=60, maximum=220, step=1, label="Max Heart Rate", value=150
+                     )
+                     oldpeak = gr.Slider(
+                         minimum=-3.0,
+                         maximum=7.0,
+                         step=0.1,
+                         label="Oldpeak (ST Depression)",
+                         value=1.0,
+                     )
+
+                 with gr.Column():
+                     chest_pain_type = gr.Dropdown(
+                         choices=["TA", "ATA", "NAP", "ASY"], label="Chest Pain Type", value="ASY"
+                     )
+                     fasting_bs = gr.Dropdown(
+                         choices=[0, 1],
+                         label="Fasting Blood Sugar (0: <=120 mg/dL, 1: >120 mg/dL)",
+                         value=0,
+                     )
+                     resting_ecg = gr.Dropdown(
+                         choices=["Normal", "ST", "LVH"], label="Resting ECG", value="Normal"
+                     )
+                     exercise_angina = gr.Dropdown(
+                         choices=["Y", "N"], label="Exercise Angina", value="N"
+                     )
+                     st_slope = gr.Dropdown(
+                         choices=["Up", "Flat", "Down"], label="ST Slope", value="Flat"
+                     )
+
+             predict_btn = gr.Button("Predict", variant="primary")
+             single_output = gr.Markdown(label="Prediction Result")
+
+             predict_btn.click(
+                 fn=Wrapper.prediction,
+                 inputs=[
+                     age,
+                     chest_pain_type,
+                     resting_bp,
+                     cholesterol,
+                     fasting_bs,
+                     resting_ecg,
+                     max_hr,
+                     exercise_angina,
+                     oldpeak,
+                     st_slope,
+                 ],
+                 outputs=single_output,
+             )
+
+         with gr.TabItem("Batch Prediction"):
+             gr.Markdown("### Upload a CSV file for batch predictions")
+             gr.Markdown(
+                 "The CSV should contain columns: Age, ChestPainType, RestingBP, Cholesterol, "
+                 + "FastingBS, RestingECG, MaxHR, ExerciseAngina, Oldpeak, ST_Slope"
+             )
+
+             file_input = gr.File(label="Upload CSV", file_types=[".csv"])
+             batch_predict_btn = gr.Button("Predict Batch", variant="primary")
+             batch_output = gr.Dataframe(label="Batch Prediction Results")
+
+             batch_predict_btn.click(
+                 fn=Wrapper.batch_prediction, inputs=file_input, outputs=batch_output
+             )
+
+         with gr.TabItem("ModelCard"):
+             io = load_page(io, Wrapper.get_model_card)
+
+         with gr.TabItem("DatasetCard"):
+             io = load_page(io, Wrapper.get_dataset_card)
+
+         with gr.TabItem("Hyperparameters"):
+             gr.Markdown("## Model Hyperparameters")
+             io = load_page(io, Wrapper.get_hyperparameters)
+
+         with gr.TabItem("Evaluation Metrics"):
+             gr.Markdown("## Model Performance Metrics")
+             io = load_page(io, Wrapper.get_metrics)
+
+ app = gr.mount_gradio_app(app, io, path="/ui")
predicting_outcomes_in_heart_failure/app/routers/__pycache__/cards.cpython-311.pyc ADDED
Binary file (3 kB)
predicting_outcomes_in_heart_failure/app/routers/__pycache__/general.cpython-311.pyc ADDED
Binary file (1.13 kB)
predicting_outcomes_in_heart_failure/app/routers/__pycache__/model_info.cpython-311.pyc ADDED
Binary file (4.21 kB)
predicting_outcomes_in_heart_failure/app/routers/__pycache__/prediction.cpython-311.pyc ADDED
Binary file (6.46 kB)
predicting_outcomes_in_heart_failure/app/routers/cards.py ADDED
@@ -0,0 +1,52 @@
+ from http import HTTPStatus
+
+ from fastapi import APIRouter, HTTPException, Request
+ from loguru import logger
+ from predicting_outcomes_in_heart_failure.app.utils import construct_response
+ from predicting_outcomes_in_heart_failure.config import CARD_PATHS
+
+ router = APIRouter(tags=["Cards"])
+
+
+ @router.get("/cards/{card_type}")
+ @construct_response
+ def card(request: Request, card_type: str):
+     """Return card information.
+     card_type = dataset_card / model_card
+     """
+     logger.info(f"Received /cards/{card_type} request")
+
+     # Normalize the card_type to handle possible variants
+     card_type = card_type.lower().replace("-", "_")
+
+     path = CARD_PATHS.get(card_type)
+     if path is None:
+         logger.warning(f"Unsupported card_type requested: {card_type}")
+         raise HTTPException(
+             status_code=HTTPStatus.NOT_FOUND,
+             detail=f"Card type '{card_type}' not supported."
+             + f" Valid types: {', '.join(CARD_PATHS.keys())}",
+         )
+
+     try:
+         with open(path, encoding="utf-8") as f:
+             card_content = f.read()
+
+         logger.success(f"{path} loaded successfully")
+
+         return {
+             "message": HTTPStatus.OK.phrase,
+             "status-code": HTTPStatus.OK.value,
+             "data": {
+                 "card_type": card_type,
+                 "path": str(path),
+                 "card_lines": card_content.split("\n"),
+             },
+         }
+
+     except Exception as e:
+         logger.exception(f"Failed to load card content from {path}: {e}")
+         raise HTTPException(
+             status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+             detail=f"Error reading card file: {e}",
+         ) from e
predicting_outcomes_in_heart_failure/app/routers/general.py ADDED
@@ -0,0 +1,19 @@
+ from http import HTTPStatus
+
+ from fastapi import APIRouter, Request
+ from loguru import logger
+ from predicting_outcomes_in_heart_failure.app.utils import construct_response
+
+ router = APIRouter(tags=["General"])
+
+
+ @router.get("/")
+ @construct_response
+ def index(request: Request):
+     """Root endpoint."""
+     logger.info("General requested")
+     return {
+         "message": HTTPStatus.OK.phrase,
+         "status-code": HTTPStatus.OK,
+         "data": {"message": "Welcome to Heart Failure Predictor!"},
+     }
predicting_outcomes_in_heart_failure/app/routers/model_info.py ADDED
@@ -0,0 +1,95 @@
+ from http import HTTPStatus
+ import json
+ from typing import Any
+
+ from fastapi import APIRouter, Request
+ from loguru import logger
+ from predicting_outcomes_in_heart_failure.app.utils import construct_response
+ from predicting_outcomes_in_heart_failure.config import (
+     MODEL_PATH,
+     REPORTS_DIR,
+     TEST_METRICS_DIR,
+ )
+
+ router = APIRouter(tags=["Model"])
+
+
+ @router.get("/model/hyperparameters")
+ @construct_response
+ def get_model_hyperparameters(request: Request):
+     variant = MODEL_PATH.parent.name
+     model_name = MODEL_PATH.stem
+     hyperparams_path = REPORTS_DIR / variant / model_name / "cv_parameters.json"
+     logger.info(
+         f"Looking for hyperparameters file at {hyperparams_path} "
+         f"(model={model_name}, variant={variant})"
+     )
+
+     if not hyperparams_path.exists():
+         logger.warning("Hyperparameters file not found")
+         return {
+             "message": HTTPStatus.NOT_FOUND.phrase,
+             "status-code": HTTPStatus.NOT_FOUND,
+             "data": {
+                 "detail": "Hyperparameters file not found. Run the training pipeline.",
+                 "model_name": model_name,
+                 "variant": variant,
+                 "expected_path": str(hyperparams_path),
+             },
+         }
+
+     with hyperparams_path.open("r", encoding="utf-8") as f:
+         hyperparams_data = json.load(f)
+
+     data: dict[str, Any] = {
+         "model_path": str(MODEL_PATH),
+         "hyperparameters": hyperparams_data,
+     }
+
+     return {
+         "message": HTTPStatus.OK.phrase,
+         "status-code": HTTPStatus.OK,
+         "data": data,
+     }
+
+
+ @router.get("/model/metrics")
+ @construct_response
+ def get_model_metrics(request: Request):
+     variant = MODEL_PATH.parent.name
+     model_name = MODEL_PATH.stem
+     metrics_path = TEST_METRICS_DIR / variant / f"{model_name}.json"
+     logger.info(
+         f"Looking for metrics file at {metrics_path} (model={model_name}, variant={variant})"
+     )
+
+     if not metrics_path.exists():
+         logger.warning("Metrics file not found")
+         return {
+             "message": HTTPStatus.NOT_FOUND.phrase,
+             "status-code": HTTPStatus.NOT_FOUND,
+             "data": {
+                 "detail": (
+                     "Metrics file not found. Run the evaluation pipeline for this model first."
+                 ),
+                 "model_name": model_name,
+                 "variant": variant,
+                 "expected_path": str(metrics_path),
+             },
+         }
+
+     with metrics_path.open("r", encoding="utf-8") as f:
+         metrics_data = json.load(f)
+
+     data: dict[str, Any] = {
+         "model_path": str(MODEL_PATH),
+         "model_name": model_name,
+         "variant": variant,
+         "metrics": metrics_data.get("metrics", metrics_data),
+     }
+
+     return {
+         "message": HTTPStatus.OK.phrase,
+         "status-code": HTTPStatus.OK,
+         "data": data,
+     }
predicting_outcomes_in_heart_failure/app/routers/prediction.py ADDED
@@ -0,0 +1,135 @@
+ from http import HTTPStatus
+ from typing import Any
+
+ from fastapi import APIRouter, Request
+ from loguru import logger
+ import pandas as pd
+ from predicting_outcomes_in_heart_failure.app.schema import HeartSample
+ from predicting_outcomes_in_heart_failure.app.utils import (
+     construct_response,
+     get_model_from_state,
+ )
+ from predicting_outcomes_in_heart_failure.config import FIGURES_DIR, MODEL_PATH
+ from predicting_outcomes_in_heart_failure.modeling.explainability import (
+     explain_prediction,
+     save_shap_waterfall_plot,
+ )
+ from predicting_outcomes_in_heart_failure.modeling.predict import preprocessing
+
+ router = APIRouter()
+
+
+ @router.post("/predictions", tags=["Prediction"])
+ @construct_response
+ def predict(request: Request, payload: HeartSample):
+     model = get_model_from_state(request)
+     if model is None:
+         return {
+             "message": HTTPStatus.SERVICE_UNAVAILABLE.phrase,
+             "status-code": HTTPStatus.SERVICE_UNAVAILABLE,
+             "data": {"detail": "Model is not loaded."},
+         }
+
+     X_raw = payload.to_dataframe()
+     X = preprocessing(X_raw)
+     y_pred = int(model.predict(X)[0])
+
+     data: dict[str, Any] = {
+         "input": payload.model_dump(),
+         "prediction": y_pred,
+     }
+
+     logger.success("Prediction completed successfully for /predictions")
+     return {
+         "message": HTTPStatus.OK.phrase,
+         "status-code": HTTPStatus.OK,
+         "data": data,
+     }
+
+
+ @router.post("/batch-predictions", tags=["Prediction"])
+ @construct_response
+ def predict_batch(request: Request, payload: list[HeartSample]):
+     model = get_model_from_state(request)
+     if model is None:
+         return {
+             "message": HTTPStatus.SERVICE_UNAVAILABLE.phrase,
+             "status-code": HTTPStatus.SERVICE_UNAVAILABLE,
+             "data": {"detail": "Model is not loaded."},
+         }
+
+     X_raw_list = [sample.to_dataframe() for sample in payload]
+     X_raw = pd.concat(X_raw_list, ignore_index=True)
+     X = preprocessing(X_raw)
+
+     y_pred = [int(y) for y in model.predict(X)]
+
+     results: list[dict[str, Any]] = []
+     for idx, (sample, pred) in enumerate(zip(payload, y_pred, strict=True)):
+         results.append(
+             {
+                 "index": idx,
+                 "input": sample.model_dump(),
+                 "prediction": pred,
+             }
+         )
+
+     data: dict[str, Any] = {
+         "results": results,
+         "batch_size": len(results),
+     }
+
+     return {
+         "message": HTTPStatus.OK.phrase,
+         "status-code": HTTPStatus.OK,
+         "data": data,
+     }
+
+
+ @router.post("/explanations", tags=["Explainability"])
+ @construct_response
+ def explain(request: Request, payload: HeartSample):
+     model = get_model_from_state(request)
+     if model is None:
+         return {
+             "message": HTTPStatus.SERVICE_UNAVAILABLE.phrase,
+             "status-code": HTTPStatus.SERVICE_UNAVAILABLE,
+             "data": {"detail": "Model is not loaded."},
+         }
+
+     X_raw = payload.to_dataframe()
+     X = preprocessing(X_raw)
+
+     data: dict[str, Any] = {"input": payload.model_dump()}
+     model_type = MODEL_PATH.stem
+
+     try:
+         logger.info("Computing explanation for default model prediction...")
+         explanations = explain_prediction(model, X, model_type=model_type, top_k=5)
+         if explanations:
+             data["explanations"] = explanations
+             logger.success("Explanation computed successfully for default model.")
+         else:
+             logger.warning("No explanation available for default model.")
+     except Exception as e:
+         logger.exception(f"Failed to compute explanation: {e}")
+
+     try:
+         plot_path = FIGURES_DIR / f"shap_waterfall_default_{model_type}.png"
+         saved_path = save_shap_waterfall_plot(
+             model=model,
+             X=X,
+             model_type=model_type,
+             output_path=plot_path,
+         )
+         if saved_path is not None:
+             data["explanation_plot_url"] = f"/figures/{saved_path.name}"
+     except Exception as e:
+         logger.exception(f"Failed to generate explanation plot: {e}")
+
+     logger.success("Explanation completed successfully for /explanations")
+     return {
+         "message": HTTPStatus.OK.phrase,
+         "status-code": HTTPStatus.OK,
+         "data": data,
+     }
predicting_outcomes_in_heart_failure/app/schema.py ADDED
@@ -0,0 +1,28 @@
+ from __future__ import annotations
+
+ from typing import Literal
+
+ import numpy as np
+ import pandas as pd
+ from pydantic import BaseModel, field_validator
+
+
+ class HeartSample(BaseModel):
+     Age: int
+     ChestPainType: Literal["TA", "ATA", "NAP", "ASY"]
+     RestingBP: int
+     Cholesterol: int
+     FastingBS: int
+     RestingECG: Literal["Normal", "ST", "LVH"]
+     MaxHR: int
+     ExerciseAngina: Literal["Y", "N"]
+     Oldpeak: float
+     ST_Slope: Literal["Up", "Flat", "Down"]
+
+     @field_validator("Oldpeak")
+     @classmethod
+     def round_oldpeak(cls, v: float) -> float:
+         return float(np.round(v, 2))
+
+     def to_dataframe(self) -> pd.DataFrame:
+         return pd.DataFrame([self.model_dump()])
predicting_outcomes_in_heart_failure/app/utils.py ADDED
@@ -0,0 +1,41 @@
+ from datetime import datetime
+ from functools import wraps
+
+ from fastapi import Request
+ import gradio as gr
+ from loguru import logger
+
+
+ def construct_response(f):
+     """Construct a JSON response for an endpoint's results."""
+
+     @wraps(f)
+     def wrap(request: Request, *args, **kwargs):
+         result = f(request, *args, **kwargs)
+         response = {
+             "message": result["message"],
+             "method": request.method,
+             "status-code": result["status-code"],
+             "timestamp": datetime.now().isoformat(),
+             "url": request.url._url,
+         }
+         if "data" in result:
+             response["data"] = result["data"]
+         return response
+
+     return wrap
+
+
+ def get_model_from_state(request: Request):
+     """Retrieve the model from the app state."""
+     model = getattr(request.app.state, "model", None)
+     if model is None:
+         logger.error("Model not loaded in app.state.model")
+     return model
+
+
+ def load_page(io, fn):
+     model_card_content = gr.Markdown("Loading...")
+
+     io.load(fn=fn, inputs=None, outputs=model_card_content)
+     return io
predicting_outcomes_in_heart_failure/app/wrapper.py ADDED
@@ -0,0 +1,135 @@
+ import httpx
+ from loguru import logger
+ import pandas as pd
+
+ from predicting_outcomes_in_heart_failure.config import API_URL
+
+
+ async def _fetch_api(endpoint: str):
+     async with httpx.AsyncClient() as client:
+         try:
+             response = await client.get(f"{API_URL}/{endpoint}")
+             response.raise_for_status()
+             return response.json()
+         except Exception as e:
+             logger.error(f"Error fetching {endpoint}: {e}")
+             return {"error": str(e)}
+
+
+ class Wrapper:
+     async def prediction(
+         age,
+         chest_pain_type,
+         resting_bp,
+         cholesterol,
+         fasting_bs,
+         resting_ecg,
+         max_hr,
+         exercise_angina,
+         oldpeak,
+         st_slope,
+     ):
+         async with httpx.AsyncClient() as client:
+             try:
+                 payload = {
+                     "Age": age,
+                     "ChestPainType": chest_pain_type,
+                     "RestingBP": resting_bp,
+                     "Cholesterol": cholesterol,
+                     "FastingBS": fasting_bs,
+                     "RestingECG": resting_ecg,
+                     "MaxHR": max_hr,
+                     "ExerciseAngina": exercise_angina,
+                     "Oldpeak": round(oldpeak, 2),
+                     "ST_Slope": st_slope,
+                 }
+                 response = await client.post(f"{API_URL}/predictions", json=payload)
+                 response.raise_for_status()
+                 result = response.json()
+
+                 prediction_value = result["data"]["prediction"]
+                 result = "🆘" if prediction_value == 1 else "✅"
+
+                 return f"# Patient's status: {result}"
+             except Exception as e:
+                 logger.error(f"Error making prediction: {e}")
+                 return f"Error: {str(e)}"
+
+     async def batch_prediction(file):
+         async with httpx.AsyncClient(timeout=30.0) as client:
+             try:
+                 df = pd.read_csv(file)
+
+                 payload = []
+                 for _, row in df.iterrows():
+                     sample = {
+                         "Age": int(row["Age"]),
+                         "ChestPainType": row["ChestPainType"],
+                         "RestingBP": int(row["RestingBP"]),
+                         "Cholesterol": int(row["Cholesterol"]),
+                         "FastingBS": int(row["FastingBS"]),
+                         "RestingECG": row["RestingECG"],
+                         "MaxHR": int(row["MaxHR"]),
+                         "ExerciseAngina": row["ExerciseAngina"],
+                         "Oldpeak": round(float(row["Oldpeak"]), 2),
+                         "ST_Slope": row["ST_Slope"],
+                     }
+                     payload.append(sample)
+
+                 response = await client.post(f"{API_URL}/batch-predictions", json=payload)
+                 response.raise_for_status()
+                 result = response.json()
+
+                 results = result["data"]["results"]
+                 df_results = pd.DataFrame(
+                     [
+                         {
+                             "Patient's index": r["index"],
+                             "Patient's status": "🆘" if r["prediction"] == 1 else "✅",
+                         }
+                         for r in results
+                     ]
+                 )
+
+                 return df_results
+             except Exception as e:
+                 logger.error(f"Error making batch prediction: {e}")
+                 return pd.DataFrame({"error": [str(e)]})
+
+     async def get_model_card():
+         data = await _fetch_api("cards/model_card")
+
+         card_lines = data.get("data").get("card_lines")
+         return "\n".join(card_lines)
+
+     async def get_dataset_card():
+         data = await _fetch_api("cards/dataset_card")
+
+         card_lines = data.get("data").get("card_lines")
+         return "\n".join(card_lines)
+
+     async def get_hyperparameters():
+         data = await _fetch_api("model/hyperparameters")
+         if "error" in data:
+             return f"## Error\n{data['error']}"
+
+         data = data.get("data", {}).get("hyperparameters", {}).get("cv", {})
+
+         md = ""
+         for key, value in data.items():
+             md += f"- **{key}**: {value}\n"
+         return md
+
+     async def get_metrics():
+         data = await _fetch_api("model/metrics")
+         if "error" in data:
+             return f"## Error\n{data['error']}"
+
+         metrics = data.get("data", {}).get("metrics", {})
+         if not metrics:
+             return "## No metrics found"
+
+         md = ""
+         for key, value in metrics.items():
+             md += f"- **{key}**: {value:.4f}\n"
+         return md
predicting_outcomes_in_heart_failure/config.py ADDED
@@ -0,0 +1,129 @@
+ from pathlib import Path
+
+ from dotenv import load_dotenv
+ from loguru import logger
+
+ load_dotenv()
+
+ # -------------------
+ # Experiment settings
+ # -------------------
+ VALID_VARIANTS = ["all", "female", "male", "nosex"]
+ VALID_MODELS = ["logreg", "random_forest", "decision_tree"]
+ EXPERIMENT_NAME = "Heart_Failure_Prediction"
+ DATASET_NAME = "fedesoriano/heart-failure-prediction"
+ TARGET_COL = "HeartDisease"
+ RANDOM_STATE = 42
+ TEST_SIZE = 0.30
+ N_SPLITS = 5
+ SCORING = {
+     "accuracy": "accuracy",
+     "f1": "f1",
+     "recall": "recall",
+     "roc_auc": "roc_auc",
+ }
+
+ NUM_COLS_DEFAULT = ["Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak"]
+ CAT_COLS_DEFAULT = [
+     "Sex",
+     "ChestPainType",
+     "FastingBS",
+     "RestingECG",
+     "ExerciseAngina",
+     "ST_Slope",
+ ]
+ MULTI_CAT = ["ChestPainType", "RestingECG", "ST_Slope"]
+
+ INPUT_COLUMNS = [
+     "Age",
+     "RestingBP",
+     "Cholesterol",
+     "FastingBS",
+     "MaxHR",
+     "ExerciseAngina",
+     "Oldpeak",
+     "ChestPainType_ASY",
+     "ChestPainType_ATA",
+     "ChestPainType_NAP",
+     "ChestPainType_TA",
+     "RestingECG_LVH",
+     "RestingECG_Normal",
+     "RestingECG_ST",
+     "ST_Slope_Down",
+     "ST_Slope_Flat",
+     "ST_Slope_Up",
+ ]
+ # ----------------------------
+ # Model hyperparameter configurations
+ # ----------------------------
+ CONFIG_RF = {
+     "n_estimators": [200, 400, 800],
+     "max_depth": [None, 6, 12],
+     "class_weight": [None, "balanced"],
+ }
+ CONFIG_DT = {
+     "criterion": ["gini", "entropy", "log_loss"],
+     "max_depth": [None, 3, 5, 7, 9, 12],
+     "min_samples_split": [2, 5, 10, 20],
+     "min_samples_leaf": [1, 2, 4, 8],
+     "max_features": [None, "sqrt", "log2"],
+     "class_weight": [None, "balanced"],
+     "ccp_alpha": [0.0, 0.001, 0.01],
+ }
+ CONFIG_LR = {"C": [0.01, 0.1, 1, 10], "penalty": ["l2"], "class_weight": [None, "balanced"]}
+
+ # ----------------------------
+ # Repository info
+ # ----------------------------
+ REPO_OWNER = "se4ai2526-uniba"
+ REPO_NAME = "CardioTrack"
+
+ # ----------------------------
+ # Great Expectations
+ # ----------------------------
+ SOURCE_NAME = "heart_data_source"
+ ASSET_NAME = "heart_failure"
+ SUITE_NAME = "heart_failure_data_quality"
+
+ # ----------------------------
+ # Paths
+ # ----------------------------
+ PROJ_ROOT = Path(__file__).resolve().parents[1]
+ logger.info(f"PROJ_ROOT path is: {PROJ_ROOT}")
+
+ DATA_DIR = PROJ_ROOT / "data"
+ INTERIM_DATA_DIR = DATA_DIR / "interim"
+ PROCESSED_DATA_DIR = DATA_DIR / "processed"
+ RAW_DATA_DIR = DATA_DIR / "raw"
+ EXTERNAL_DATA_DIR = DATA_DIR / "external"
+
+ RAW_PATH = RAW_DATA_DIR / "heart.csv"
+ PREPROCESSED_CSV = INTERIM_DATA_DIR / "preprocessed.csv"
+ TRAIN_CSV = PROCESSED_DATA_DIR / "train.csv"
+ TEST_CSV = PROCESSED_DATA_DIR / "test.csv"
+
+ MODELS_DIR = PROJ_ROOT / "models"
+ REPORTS_DIR = PROJ_ROOT / "reports"
+ FIGURES_DIR = REPORTS_DIR / "figures"
+
+ METRICS_DIR = PROJ_ROOT / "metrics"
+ TEST_METRICS_DIR = METRICS_DIR / "test"
+
+ NOSEX_CSV = INTERIM_DATA_DIR / "preprocessed_no_sex_column.csv"
+ MALE_CSV = INTERIM_DATA_DIR / "preprocessed_male_only.csv"
+ FEMALE_CSV = INTERIM_DATA_DIR / "preprocessed_female_only.csv"
+
+ PREPROCESS_ARTIFACTS_DIR = INTERIM_DATA_DIR / "preprocess_artifacts"
+ SCALER_PATH = PREPROCESS_ARTIFACTS_DIR / "scaler.joblib"
+
+ MODEL_PATH = Path("models/nosex/random_forest.joblib")
+
+ CARD_PATHS = {
+     "dataset_card": DATA_DIR / "README.md",
+     "model_card": MODELS_DIR / "README.md",
+ }
+
+ # ----------------------------
+ # API
+ # ----------------------------
+ API_URL = "http://localhost:8000"
predicting_outcomes_in_heart_failure/data/dataset.py ADDED
@@ -0,0 +1,22 @@
+ import os
+ import shutil
+
+ import kagglehub
+ from loguru import logger
+ from predicting_outcomes_in_heart_failure.config import DATASET_NAME, RAW_DATA_DIR
+ import typer
+
+ app = typer.Typer()
+
+
+ @app.command()
+ def main():
+     logger.info("Downloading dataset from Kaggle...")
+     os.makedirs(RAW_DATA_DIR, exist_ok=True)
+     path = kagglehub.dataset_download(DATASET_NAME)
+     shutil.copytree(path, RAW_DATA_DIR, dirs_exist_ok=True)
+     logger.success(f"Dataset downloaded successfully to {RAW_DATA_DIR}.")
+
+
+ if __name__ == "__main__":
+     app()
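Editor's note: a usage sketch, assuming the package is importable (for example, installed with `pip install -e .`) and kagglehub can authenticate with Kaggle:

```python
# Equivalent to: python -m predicting_outcomes_in_heart_failure.data.dataset
from predicting_outcomes_in_heart_failure.data.dataset import main

main()  # downloads the Kaggle dataset into data/raw/
```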
predicting_outcomes_in_heart_failure/data/preprocess.py ADDED
@@ -0,0 +1,116 @@
+ import joblib
+ from loguru import logger
+ import pandas as pd
+ from predicting_outcomes_in_heart_failure.config import (
+     FEMALE_CSV,
+     INTERIM_DATA_DIR,
+     MALE_CSV,
+     NOSEX_CSV,
+     NUM_COLS_DEFAULT,
+     PREPROCESS_ARTIFACTS_DIR,
+     PREPROCESSED_CSV,
+     RAW_PATH,
+     SCALER_PATH,
+     TARGET_COL,
+ )
+ from sklearn.preprocessing import StandardScaler
+
+
+ def save_scaler_artifact(scaler: StandardScaler):
+     """Save the fitted scaler used during preprocessing so inference can reuse it."""
+     PREPROCESS_ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
+     joblib.dump(scaler, SCALER_PATH)
+     logger.success(f"Saved StandardScaler to {SCALER_PATH}")
+
+
+ def generate_gender_splits(df: pd.DataFrame):
+     """Create and save gender-based CSV splits (female, male, nosex)."""
+     if "Sex" in df.columns:
+         df_female = df[df["Sex"] == 0].copy()
+         df_female.to_csv(FEMALE_CSV, index=False)
+         logger.success(f"Saved female-only dataset: {FEMALE_CSV} (rows={len(df_female)})")
+
+         df_male = df[df["Sex"] == 1].copy()
+         df_male.to_csv(MALE_CSV, index=False)
+         logger.success(f"Saved male-only dataset: {MALE_CSV} (rows={len(df_male)})")
+
+     df_nosex = df.drop(columns=["Sex"], errors="ignore").copy()
+     df_nosex.to_csv(NOSEX_CSV, index=False)
+     logger.success(f"Saved dataset without 'Sex': {NOSEX_CSV} (rows={len(df_nosex)})")
+
+
+ def preprocessing():
+     """Run the full preprocessing pipeline on the raw heart dataset."""
+     logger.info("Starting preprocessing pipeline...")
+
+     if not RAW_PATH.exists():
+         logger.error(f"Missing {RAW_PATH}. Put heart.csv under data/raw/ or adjust RAW_PATH.")
+         raise FileNotFoundError(f"Missing {RAW_PATH}.")
+
+     df = pd.read_csv(RAW_PATH)
+     logger.info(f"Loaded dataset: {RAW_PATH} (rows={len(df)}, cols={df.shape[1]})")
+
+     if len(df) < 2:
+         raise ValueError(f"Preprocessing requires at least 2 rows, got {len(df)}.")
+
+     # Ensure target is integer
+     df[TARGET_COL] = df[TARGET_COL].astype(int)
+
+     # Remove invalid RestingBP rows
+     if "RestingBP" in df.columns:
+         before = len(df)
+         df = df[df["RestingBP"] != 0].reset_index(drop=True)
+         removed = before - len(df)
+         if removed > 0:
+             logger.warning(f"Removed {removed} rows with RestingBP == 0")
+
+     # Impute missing/zero Cholesterol
+     if "Cholesterol" in df.columns:
+         zero_mask = df["Cholesterol"] == 0
+         if zero_mask.any():
+             median_chol = df.loc[~zero_mask, "Cholesterol"].median()
+             df.loc[zero_mask, "Cholesterol"] = median_chol
+             logger.info(f"Imputed {zero_mask.sum()} Cholesterol==0 with median={median_chol}")
+
+     # Encode binary categorical features
+     if "Sex" in df.columns:
+         df["Sex"] = df["Sex"].map({"M": 1, "F": 0}).astype(int)
+         logger.debug("Encoded 'Sex' as binary.")
+
+     if "ExerciseAngina" in df.columns:
+         df["ExerciseAngina"] = df["ExerciseAngina"].map({"Y": 1, "N": 0}).astype(int)
+         logger.debug("Encoded 'ExerciseAngina' as binary.")
+
+     # One-hot encode multi-category features
+     multi_cat = [c for c in ["ChestPainType", "RestingECG", "ST_Slope"] if c in df.columns]
+     df = pd.get_dummies(df, columns=multi_cat, drop_first=False)
+     logger.debug(f"One-hot encoded columns: {multi_cat}")
+
+     # Scale numerical columns
+     num_cols = [c for c in NUM_COLS_DEFAULT if c in df.columns and c != TARGET_COL]
+     scaler = StandardScaler()
+     df[num_cols] = scaler.fit_transform(df[num_cols])
+     logger.info(f"Scaled numerical features: {num_cols}")
+
+     # Save processed dataset
+     df.to_csv(PREPROCESSED_CSV, index=False)
+     logger.success(
+         f"Saved preprocessed dataset: {PREPROCESSED_CSV} (rows={len(df)}, cols={df.shape[1]})"
+     )
+
+     # Log class distribution
+     count_0 = (df[TARGET_COL] == 0).sum()
+     count_1 = (df[TARGET_COL] == 1).sum()
+     logger.info(f"Target balance — 0: {count_0} | 1: {count_1}")
+
+     save_scaler_artifact(scaler)
+
+     logger.success("Preprocessing completed successfully.")
+     return df
+
+
+ if __name__ == "__main__":
+     INTERIM_DATA_DIR.mkdir(parents=True, exist_ok=True)
+     df_processed = preprocessing()
+     generate_gender_splits(df_processed)
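Editor's note: the scaler is fitted here and must be reused verbatim at inference time rather than refit (predict.py later in this commit does exactly that). A minimal sketch with made-up patient values, keeping the NUM_COLS_DEFAULT column order the scaler was fitted on:

```python
import joblib
import pandas as pd

from predicting_outcomes_in_heart_failure.config import NUM_COLS_DEFAULT, SCALER_PATH

scaler = joblib.load(SCALER_PATH)
row = pd.DataFrame(
    [{"Age": 54, "RestingBP": 140, "Cholesterol": 239, "MaxHR": 160, "Oldpeak": 0.0}]
)
# Apply the training-time transform; never fit_transform on new data
row[NUM_COLS_DEFAULT] = scaler.transform(row[NUM_COLS_DEFAULT])
```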
predicting_outcomes_in_heart_failure/data/split_data.py ADDED
@@ -0,0 +1,114 @@
+ import argparse
+ from pathlib import Path
+
+ from loguru import logger
+ import pandas as pd
+ from predicting_outcomes_in_heart_failure.config import (
+     FEMALE_CSV,
+     MALE_CSV,
+     NOSEX_CSV,
+     PREPROCESSED_CSV,
+     PROCESSED_DATA_DIR,
+     RANDOM_STATE,
+     TARGET_COL,
+     TEST_SIZE,
+ )
+ from sklearn.model_selection import train_test_split
+
+ VARIANTS = {
+     "all": PREPROCESSED_CSV,
+     "female": FEMALE_CSV,
+     "male": MALE_CSV,
+     "nosex": NOSEX_CSV,
+ }
+
+
+ def _safe_train_test_split(X, y, test_size, random_state):
+     """Perform a stratified train/test split with fallback if not possible."""
+     stratify_y = y if y.nunique() > 1 else None
+     try:
+         X_tr, X_te, y_tr, y_te = train_test_split(
+             X,
+             y,
+             test_size=test_size,
+             stratify=stratify_y,
+             random_state=random_state,
+             shuffle=True,
+         )
+         if stratify_y is None:
+             logger.warning("Target has only one class — performing non-stratified split.")
+         else:
+             logger.debug("Stratified split executed successfully.")
+         return X_tr, X_te, y_tr, y_te
+     except ValueError as e:
+         logger.warning(f"Stratified split failed ({e}). Falling back to non-stratified split.")
+         return train_test_split(
+             X,
+             y,
+             test_size=test_size,
+             stratify=None,
+             random_state=random_state,
+             shuffle=True,
+         )
+
+
+ def split_one(csv_path: Path, variant: str):
+     """Split a specific variant (all/female/male/nosex) into train/test sets."""
+     if not csv_path.exists():
+         logger.warning(f"[{variant}] Missing CSV file: {csv_path} — skipping.")
+         return
+
+     df = pd.read_csv(csv_path)
+     logger.info(f"[{variant}] Loaded {csv_path} (rows={len(df)}, cols={df.shape[1]})")
+
+     if TARGET_COL not in df.columns:
+         raise ValueError(f"[{variant}] Target column '{TARGET_COL}' not found in {csv_path}")
+
+     X = df.drop(columns=[TARGET_COL])
+     y = df[TARGET_COL].astype(int)
+
+     X_train, X_test, y_train, y_test = _safe_train_test_split(X, y, TEST_SIZE, RANDOM_STATE)
+
+     train_df = X_train.copy()
+     train_df[TARGET_COL] = y_train.values
+     test_df = X_test.copy()
+     test_df[TARGET_COL] = y_test.values
+
+     out_dir = PROCESSED_DATA_DIR / variant
+     out_dir.mkdir(parents=True, exist_ok=True)
+     train_p = out_dir / "train.csv"
+     test_p = out_dir / "test.csv"
+
+     train_df.to_csv(train_p, index=False)
+     test_df.to_csv(test_p, index=False)
+
+     logger.success(f"[{variant}] Saved TRAIN -> {train_p} (rows={len(train_df)})")
+     logger.success(f"[{variant}] Saved TEST -> {test_p} (rows={len(test_df)})")
+
+     train_counts = train_df[TARGET_COL].value_counts().to_dict()
+     test_counts = test_df[TARGET_COL].value_counts().to_dict()
+     logger.info(f"[{variant}] Class distribution — TRAIN: {train_counts} | TEST: {test_counts}")
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--variant",
+         type=str,
+         choices=list(VARIANTS.keys()),
+         required=True,
+         help="Data variant to split: all, female, male, or nosex.",
+     )
+     args = parser.parse_args()
+
+     variant = args.variant
+     csv_path = VARIANTS[variant]
+
+     logger.info(f"Starting splitting for variant='{variant}'")
+     PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)
+     split_one(csv_path, variant)
+     logger.success(f"Splitting completed for variant='{variant}'")
+
+
+ if __name__ == "__main__":
+     main()
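Editor's note: the CLI splits one variant per invocation; a sketch for splitting every variant in one pass, reusing the module's own VARIANTS mapping (split_one already skips variants whose CSV is missing):

```python
from predicting_outcomes_in_heart_failure.config import PROCESSED_DATA_DIR
from predicting_outcomes_in_heart_failure.data.split_data import VARIANTS, split_one

PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)
for variant, csv_path in VARIANTS.items():
    split_one(csv_path, variant)  # writes data/processed/<variant>/{train,test}.csv
```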
predicting_outcomes_in_heart_failure/modeling/__init__.py ADDED
File without changes
predicting_outcomes_in_heart_failure/modeling/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (221 Bytes). View file
 
predicting_outcomes_in_heart_failure/modeling/__pycache__/explainability.cpython-311.pyc ADDED
Binary file (9.73 kB). View file
 
predicting_outcomes_in_heart_failure/modeling/__pycache__/predict.cpython-311.pyc ADDED
Binary file (7.49 kB). View file
 
predicting_outcomes_in_heart_failure/modeling/evaluate.py ADDED
@@ -0,0 +1,182 @@
+ import argparse
+ import json
+ import os
+
+ import dagshub
+ import joblib
+ from loguru import logger
+ import mlflow
+ from mlflow.models.signature import infer_signature
+ from sklearn.metrics import accuracy_score, f1_score, recall_score, roc_auc_score
+
+ from predicting_outcomes_in_heart_failure.config import (
+     DATASET_NAME,
+     EXPERIMENT_NAME,
+     MODELS_DIR,
+     PROCESSED_DATA_DIR,
+     REPO_NAME,
+     REPO_OWNER,
+     TARGET_COL,
+     TEST_METRICS_DIR,
+     VALID_MODELS,
+     VALID_VARIANTS,
+ )
+ from predicting_outcomes_in_heart_failure.modeling.train import load_split
+
+
+ def compute_metrics(model, X_test, y_test) -> tuple:
+     """Compute evaluation metrics (F1, recall, accuracy, ROC-AUC); returns (metrics, y_pred)."""
+     y_pred = model.predict(X_test)
+     results = {
+         "test_f1": f1_score(y_test, y_pred, zero_division=0),
+         "test_recall": recall_score(y_test, y_pred, zero_division=0),
+         "test_accuracy": accuracy_score(y_test, y_pred),
+     }
+     if hasattr(model, "predict_proba"):
+         try:
+             y_prob = model.predict_proba(X_test)[:, 1]
+             results["test_roc_auc"] = roc_auc_score(y_test, y_prob)
+         except Exception as e:
+             logger.warning(f"ROC AUC not computed: {e}")
+     return results, y_pred
+
+
+ def evaluate_variant(variant: str, model_name: str | None = None):
+     """Evaluate trained models for a given variant, optionally by model."""
+     logger.info(f"=== Evaluation started (variant={variant}, model={model_name or 'ALL'}) ===")
+
+     test_path = PROCESSED_DATA_DIR / variant / "test.csv"
+     test_df = load_split(test_path)
+
+     X_test = test_df.drop(columns=[TARGET_COL])
+     y_test = test_df[TARGET_COL].astype(int)
+
+     models_dir_variant = MODELS_DIR / variant
+     if not models_dir_variant.exists():
+         logger.warning(
+             f"[{variant}] Models directory does not exist: {models_dir_variant} — skipping."
+         )
+         return
+
+     experiment_name = f"{EXPERIMENT_NAME}_{variant}"
+     experiment = mlflow.get_experiment_by_name(experiment_name)
+     if experiment is None:
+         logger.error(f"Experiment '{experiment_name}' not found.")
+         return
+
+     if model_name is not None:
+         model_files = [f"{model_name}.joblib"]
+     else:
+         model_files = [f for f in os.listdir(models_dir_variant) if f.endswith(".joblib")]
+
+     for file in model_files:
+         current_model_name = file.split(".joblib")[0]
+         run_name = f"{current_model_name}_{variant}"
+         logger.info(
+             f"[{variant} | {current_model_name}] Looking for training run '{run_name}' in MLflow."
+         )
+
+         runs = mlflow.search_runs(
+             experiment_ids=[experiment.experiment_id],
+             filter_string=f"tags.mlflow.runName = '{run_name}'",
+             order_by=["start_time DESC"],
+             max_results=1,
+         )
+
+         if runs.empty:
+             logger.warning(
+                 f"[{variant} | {current_model_name}] No matching MLflow run found — skipping."
+             )
+             continue
+
+         tracked_id = runs.loc[0, "run_id"]
+
+         with mlflow.start_run(run_id=tracked_id):
+             rawdata = mlflow.data.from_pandas(test_df, name=f"{DATASET_NAME}_{variant}_test")
+             mlflow.log_input(rawdata, context="testing")
+
+             model_path = models_dir_variant / file
+             model = joblib.load(model_path)
+
+             metrics, _ = compute_metrics(model, X_test, y_test)
+             mlflow.log_metrics(metrics)
+
+             logger.info(f"[{variant} | {current_model_name}] Test set metrics:")
+             for k in ["test_f1", "test_recall", "test_accuracy", "test_roc_auc"]:
+                 if k in metrics:
+                     logger.info(f" - {k}: {metrics[k]:.4f}")
+
+             metrics_dir = TEST_METRICS_DIR / variant
+             metrics_dir.mkdir(parents=True, exist_ok=True)
+
+             metrics_path = metrics_dir / f"{current_model_name}.json"
+
+             to_save = {
+                 "variant": variant,
+                 "model_name": current_model_name,
+                 "metrics": metrics,
+             }
+
+             with open(metrics_path, "w", encoding="utf-8") as f:
+                 json.dump(to_save, f, indent=4)
+
+             logger.info(
+                 f"[{variant} | {current_model_name}] Saved test metrics locally → {metrics_path}"
+             )
+
+             if (
+                 metrics.get("test_f1", 0.0) >= 0.80
+                 and metrics.get("test_recall", 0.0) >= 0.80
+                 and metrics.get("test_accuracy", 0.0) >= 0.80
+                 and metrics.get("test_roc_auc", 0.0) >= 0.85
+             ):
+                 signature = infer_signature(X_test, model.predict(X_test))
+                 registered_name = f"{current_model_name}_{variant}"
+                 mlflow.sklearn.log_model(
+                     sk_model=model,
+                     artifact_path="Model_Info",
+                     signature=signature,
+                     input_example=X_test,
+                     registered_model_name=registered_name,
+                 )
+                 logger.success(
+                     f"[{variant} | {current_model_name}] "
+                     f"Model promoted and registered as '{registered_name}'."
+                 )
+
+     logger.success(
+         f"=== Evaluation completed (variant={variant}, model={model_name or 'ALL'}) ==="
+     )
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--variant",
+         type=str,
+         choices=VALID_VARIANTS,
+         required=True,
+         help="Data variant to use: all, female, male, or nosex.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         choices=VALID_MODELS,
+         required=False,
+         help=(
+             "Specific model to evaluate (logreg, random_forest, decision_tree)."
+             " If omitted, evaluate all models."
+         ),
+     )
+     args = parser.parse_args()
+
+     dagshub.init(repo_owner=REPO_OWNER, repo_name=REPO_NAME, mlflow=True)
+     evaluate_variant(args.variant, args.model)
+
+
+ if __name__ == "__main__":
+     main()
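Editor's note: the model-registration gate in evaluate_variant is easier to audit when restated as data; a small sketch with the thresholds copied from the conditional above:

```python
# Promotion thresholds as used by the registration gate in evaluate_variant
PROMOTION_THRESHOLDS = {
    "test_f1": 0.80,
    "test_recall": 0.80,
    "test_accuracy": 0.80,
    "test_roc_auc": 0.85,
}


def should_promote(metrics: dict) -> bool:
    """True when every metric meets its minimum (missing metrics count as 0.0)."""
    return all(
        metrics.get(name, 0.0) >= floor for name, floor in PROMOTION_THRESHOLDS.items()
    )
```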
predicting_outcomes_in_heart_failure/modeling/explainability.py ADDED
@@ -0,0 +1,202 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
+ import matplotlib
+
+ matplotlib.use("Agg")
+
+ from loguru import logger
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import shap
+
+
+ def explain_prediction(
+     model: Any,
+     X: pd.DataFrame,
+     model_type: str,
+     top_k: int = 5,
+ ):
+     """
+     Build an explanation for a single sample.
+     """
+     if X.empty:
+         logger.warning("Received empty DataFrame for explanation; returning empty list.")
+         return []
+
+     model_type = model_type.lower()
+     x = X.iloc[[0]]
+     feature_names = x.columns.tolist()
+
+     # ---------------------------------------------------------------------
+     # 1) Logistic Regression → use coefficients
+     # ---------------------------------------------------------------------
+     if model_type in ("logreg", "logistic_regression"):
+         logger.info("Using coefficient-based explanation for Logistic Regression.")
+
+         if not hasattr(model, "coef_"):
+             logger.error(
+                 "Model has no coef_ attribute; cannot build coefficient-based explanation."
+             )
+             return []
+
+         coef = np.asarray(model.coef_[0]).reshape(-1)
+         if coef.shape[0] != len(feature_names):
+             logger.warning(
+                 f"Coefficient vector length ({coef.shape[0]}) does not match "
+                 f"number of features ({len(feature_names)}). "
+                 "Truncating to minimum length."
+             )
+
+         n = min(len(feature_names), coef.shape[0])
+         explanations = [
+             {
+                 "feature": feature_names[i],
+                 "value": float(coef[i]),
+                 "abs_value": float(abs(coef[i])),
+             }
+             for i in range(n)
+         ]
+
+         explanations = sorted(explanations, key=lambda d: d["abs_value"], reverse=True)[:top_k]
+         logger.info(
+             f"Built coefficient-based explanation. Returning top {len(explanations)} features."
+         )
+         return explanations
+
+     # ---------------------------------------------------------------------
+     # 2) Tree-based models → SHAP TreeExplainer
+     # ---------------------------------------------------------------------
+     if model_type in ("random_forest", "decision_tree"):
+         logger.info("Using SHAP TreeExplainer for tree-based model.")
+
+         try:
+             explainer = shap.TreeExplainer(model)
+             shap_exp = explainer(x)
+             values = np.asarray(shap_exp.values)
+             logger.debug(f"Raw SHAP values shape: {values.shape!r}")
+         except Exception as e:
+             logger.error(f"SHAP TreeExplainer failed: {e}")
+             logger.warning("SHAP explanation not available for this model.")
+             return []
+
+         if values.ndim == 2:
+             shap_vec = values[0]
+
+         elif values.ndim == 3:
+             n_samples, dim2, dim3 = values.shape
+
+             if dim2 == x.shape[1]:
+                 n_outputs = dim3
+                 class_index = 1 if n_outputs > 1 else 0
+                 shap_vec = values[0, :, class_index]
+
+             elif dim3 == x.shape[1]:
+                 n_outputs = dim2
+                 class_index = 1 if n_outputs > 1 else 0
+                 shap_vec = values[0, class_index, :]
+
+             else:
+                 logger.error(f"Unexpected SHAP shape {values.shape} for {x.shape[1]} features.")
+                 return []
+
+         else:
+             logger.error(f"Unexpected SHAP values dimension: {values.ndim}")
+             return []
+
+         shap_vec = np.asarray(shap_vec).reshape(-1)
+
+         if shap_vec.shape[0] != len(feature_names):
+             logger.warning(
+                 f"SHAP vector length ({shap_vec.shape[0]}) "
+                 f"!= number of features ({len(feature_names)}). "
+                 "Truncating to minimum length."
+             )
+
+         n = min(len(feature_names), shap_vec.shape[0])
+         explanations = [
+             {
+                 "feature": feature_names[i],
+                 "value": float(shap_vec[i]),
+                 "abs_value": float(abs(shap_vec[i])),
+             }
+             for i in range(n)
+         ]
+
+         explanations = sorted(explanations, key=lambda d: d["abs_value"], reverse=True)[:top_k]
+
+         logger.info(f"Built SHAP-based explanation. Returning top {len(explanations)} features.")
+         return explanations
+
+
+ def save_shap_waterfall_plot(
+     model: Any,
+     X: pd.DataFrame,
+     model_type: str,
+     output_path: Path,
+ ) -> Path | None:
+     """
+     Save a SHAP waterfall plot for a single sample to the given output path.
+     """
+     model_type = model_type.lower()
+
+     if model_type not in ("random_forest", "decision_tree"):
+         logger.warning(
+             f"Waterfall plot is only supported for tree-based models. "
+             f"Got model_type='{model_type}'. Skipping plot generation."
+         )
+         return None
+
+     if X.empty:
+         logger.warning("Received empty DataFrame for SHAP plot; skipping.")
+         return None
+
+     x = X.iloc[[0]]
+     logger.info(f"Generating SHAP waterfall plot for model_type='{model_type}'.")
+
+     try:
+         explainer = shap.TreeExplainer(model)
+         shap_exp = explainer(x)
+     except Exception as e:
+         logger.error(f"Failed to build SHAP explainer for plot: {e}")
+         return None
+
+     try:
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+         shap_to_plot = shap_exp
+         if np.asarray(shap_exp.values).ndim == 3:
+             vals = np.asarray(shap_exp.values)
+             if vals.shape[1] == x.shape[1]:
+                 shap_to_plot = shap_exp[..., 1]
+             elif vals.shape[2] == x.shape[1]:
+                 shap_to_plot = shap_exp[:, 1, :]
+             else:
+                 logger.warning(
+                     f"Unexpected shape for SHAP values in plot: {vals.shape}. "
+                     "Falling back to the raw explanation object."
+                 )
+                 shap_to_plot = shap_exp
+
+         plt.figure()
+         shap.plots.waterfall(shap_to_plot[0], show=False)
+         plt.tight_layout()
+         plt.savefig(output_path, bbox_inches="tight")
+         plt.close()
+
+         logger.success(f"SHAP waterfall plot saved to {output_path}")
+         return output_path
+     except Exception as e:
+         logger.error(f"Failed to save SHAP waterfall plot: {e}")
+         return None
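Editor's note: a self-contained sketch of explain_prediction on a toy forest, using synthetic data rather than the project's trained model, to show the shape of the returned explanation items:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from predicting_outcomes_in_heart_failure.modeling.explainability import explain_prediction

# Toy binary-classification data with named columns
X, y = make_classification(n_samples=200, n_features=6, random_state=42)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(6)])
model = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y)

top = explain_prediction(model, X.head(1), model_type="random_forest", top_k=3)
for item in top:
    print(item["feature"], round(item["value"], 4))  # signed SHAP contribution
```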
predicting_outcomes_in_heart_failure/modeling/predict.py ADDED
@@ -0,0 +1,135 @@
+ from __future__ import annotations
+
+ import time
+
+ import joblib
+ from loguru import logger
+ import numpy as np
+ import pandas as pd
+
+ from predicting_outcomes_in_heart_failure.app.schema import HeartSample
+ from predicting_outcomes_in_heart_failure.config import (
+     FIGURES_DIR,
+     INPUT_COLUMNS,
+     MODEL_PATH,
+     MULTI_CAT,
+     NUM_COLS_DEFAULT,
+     SCALER_PATH,
+ )
+ from predicting_outcomes_in_heart_failure.modeling.explainability import (
+     explain_prediction,
+     save_shap_waterfall_plot,
+ )
+
+
+ def preprocessing(sample_df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Apply the same preprocessing used during training:
+     binary encoding, one-hot encoding, column alignment, and scaling.
+     """
+     logger.info("Applying preprocessing pipeline for inference...")
+
+     if not (SCALER_PATH.exists() and MODEL_PATH.exists()):
+         raise FileNotFoundError("Preprocessing artifacts missing.")
+
+     scaler = joblib.load(SCALER_PATH)
+     input_columns = INPUT_COLUMNS
+     multi_cat = MULTI_CAT
+     num_cols = NUM_COLS_DEFAULT
+
+     logger.debug(f"Loaded scaler from {SCALER_PATH}")
+     logger.debug(f"Using {len(input_columns)} input columns")
+
+     if "Sex" in sample_df.columns and "Sex" not in input_columns:
+         logger.debug("Dropping column 'Sex' since it's not used by the current model variant.")
+         sample_df = sample_df.drop(columns=["Sex"])
+
+     if "Sex" in sample_df.columns and "Sex" in input_columns:
+         sample_df["Sex"] = sample_df["Sex"].map({"M": 1, "F": 0}).astype(int)
+         logger.debug("Mapped 'Sex' to binary values (M=1, F=0).")
+
+     if "ExerciseAngina" in sample_df.columns and "ExerciseAngina" in input_columns:
+         sample_df["ExerciseAngina"] = sample_df["ExerciseAngina"].map({"Y": 1, "N": 0}).astype(int)
+         logger.debug("Mapped 'ExerciseAngina' to binary values (Y=1, N=0).")
+
+     present_multi = [c for c in multi_cat if c in sample_df.columns]
+     if present_multi:
+         logger.debug(f"Performing one-hot encoding on: {present_multi}")
+         sample_df = pd.get_dummies(sample_df, columns=present_multi, drop_first=False)
+
+     for col in input_columns:
+         if col not in sample_df.columns:
+             sample_df[col] = 0
+     sample_df = sample_df.reindex(columns=input_columns, fill_value=0)
+     logger.debug("Aligned input columns with training feature order.")
+
+     cols_to_scale = [c for c in num_cols if c in sample_df.columns]
+     sample_df[cols_to_scale] = scaler.transform(sample_df[cols_to_scale])
+     logger.debug(f"Scaled numerical columns: {cols_to_scale}")
+
+     logger.success("Preprocessing completed successfully.")
+     return sample_df
+
+
+ def main():
+     logger.info("Starting static inference...")
+
+     sample = HeartSample(
+         Age=54,
+         ChestPainType="ASY",
+         RestingBP=140,
+         Cholesterol=239,
+         FastingBS=0,
+         RestingECG="Normal",
+         MaxHR=160,
+         ExerciseAngina="N",
+         Oldpeak=0.0,
+         ST_Slope="Up",
+     )
+     logger.info("Sample created successfully.")
+
+     X_raw = sample.to_dataframe()
+     logger.debug(f"Raw input features:\n{X_raw}")
+     X = preprocessing(X_raw)
+
+     if not MODEL_PATH.exists():
+         raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
+     model = joblib.load(MODEL_PATH)
+     logger.success(f"Loaded model from {MODEL_PATH}")
+
+     # Perform prediction
+     t0 = time.perf_counter()
+     y_pred = model.predict(X)[0]
+     inference_time = time.perf_counter() - t0
+     y_pred = int(y_pred) if np.issubdtype(type(y_pred), np.integer) else y_pred
+     result = {
+         "prediction": y_pred,
+         "inference_time_seconds": inference_time,
+     }
+
+     # Explainability (reuse the model loaded above)
+     model_type = MODEL_PATH.stem
+     try:
+         logger.info("Computing explanation for the prediction...")
+         explanations = explain_prediction(model, X, model_type=model_type, top_k=5)
+         result["explanations"] = explanations
+         logger.success("Explanation computed successfully.")
+     except Exception as e:
+         logger.error(f"Failed to compute explanation: {e}")
+
+     try:
+         shap_path = FIGURES_DIR / f"shap_waterfall_{model_type}.png"
+         saved = save_shap_waterfall_plot(model, X, model_type=model_type, output_path=shap_path)
+         if saved is not None:
+             result["explanation_plot"] = str(saved)
+     except Exception as e:
+         logger.error(f"Failed to generate SHAP waterfall plot: {e}")
+
+     logger.info("Inference completed.")
+     logger.success(f"Prediction result: {result}")
+
+     return result
+
+
+ if __name__ == "__main__":
+     main()
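Editor's note: `preprocessing` accepts any number of rows, so batch scoring is a small extension of the single-sample flow above. A sketch, assuming a hypothetical `patients.csv` whose columns match the raw schema (Age, ChestPainType, RestingBP, ...):

```python
import joblib
import pandas as pd

from predicting_outcomes_in_heart_failure.config import MODEL_PATH
from predicting_outcomes_in_heart_failure.modeling.predict import preprocessing

batch = pd.read_csv("patients.csv")  # hypothetical input file
X = preprocessing(batch.copy())      # same encoding/alignment/scaling as above
model = joblib.load(MODEL_PATH)
batch["HeartDisease_pred"] = model.predict(X)
```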
predicting_outcomes_in_heart_failure/modeling/train.py ADDED
@@ -0,0 +1,261 @@
+ from __future__ import annotations
+
+ import argparse
+ import json
+ from pathlib import Path
+
+ import dagshub
+ from imblearn.over_sampling import RandomOverSampler
+ import joblib
+ from loguru import logger
+ import mlflow
+ import pandas as pd
+ from sklearn.model_selection import GridSearchCV, StratifiedKFold
+
+ from predicting_outcomes_in_heart_failure.config import (
+     CONFIG_DT,
+     CONFIG_LR,
+     CONFIG_RF,
+     DATASET_NAME,
+     EXPERIMENT_NAME,
+     MODELS_DIR,
+     N_SPLITS,
+     PROCESSED_DATA_DIR,
+     RANDOM_STATE,
+     REPO_NAME,
+     REPO_OWNER,
+     REPORTS_DIR,
+     SCORING,
+     TARGET_COL,
+     VALID_MODELS,
+     VALID_VARIANTS,
+ )
+
+ REFIT = "f1"
+
+
+ def load_split(path: Path) -> pd.DataFrame:
+     if not path.exists():
+         logger.error(f"Missing split file: {path}. Run split_data.py first.")
+         raise FileNotFoundError(path)
+     df = pd.read_csv(path)
+     logger.info(f"Loaded {path} (rows={len(df)}, cols={df.shape[1]})")
+     return df
+
+
+ def apply_random_oversampling(
+     X: pd.DataFrame,
+     y: pd.Series,
+     model_name: str,
+     variant: str,
+ ):
+     """Apply RandomOverSampler to balance classes in the training set."""
+     logger.info(f"[{variant} | {model_name}] Applying RandomOverSampler on training data...")
+
+     # Log original class distribution
+     orig_counts = y.value_counts().to_dict()
+     logger.info(f"[{variant} | {model_name}] Original class distribution: {orig_counts}")
+
+     ros = RandomOverSampler(random_state=RANDOM_STATE)
+     X_res, y_res = ros.fit_resample(X, y)
+
+     # Log resampled class distribution
+     res_counts = y_res.value_counts().to_dict()
+     logger.info(f"[{variant} | {model_name}] Resampled class distribution: {res_counts}")
+
+     logger.success(f"[{variant} | {model_name}] RandomOverSampler applied successfully.")
+     return X_res, y_res
+
+
+ def get_model_and_grid(model_name: str):
+     """Return estimator and parameter grid for the selected model."""
+     if model_name == "decision_tree":
+         from sklearn.tree import DecisionTreeClassifier
+
+         estimator = DecisionTreeClassifier(random_state=RANDOM_STATE)
+         param_grid = CONFIG_DT
+         return estimator, param_grid
+
+     elif model_name == "logreg":
+         from sklearn.linear_model import LogisticRegression
+
+         estimator = LogisticRegression(max_iter=500, random_state=RANDOM_STATE)
+         param_grid = CONFIG_LR
+         return estimator, param_grid
+
+     elif model_name == "random_forest":
+         from sklearn.ensemble import RandomForestClassifier
+
+         estimator = RandomForestClassifier(random_state=RANDOM_STATE)
+         param_grid = CONFIG_RF
+         return estimator, param_grid
+
+     else:
+         raise ValueError(f"Unknown model_name: {model_name}")
+
+
+ def run_grid_search(
+     estimator,
+     param_grid,
+     X_train,
+     y_train,
+     model_name: str,
+     variant: str,
+     reports_dir: Path,
+ ):
+     """Run GridSearchCV for the specified model and log CV results."""
+     cv = StratifiedKFold(
+         n_splits=N_SPLITS,
+         shuffle=True,
+         random_state=RANDOM_STATE,
+     )
+     grid = GridSearchCV(
+         estimator=estimator,
+         param_grid=param_grid,
+         scoring=SCORING,
+         refit=REFIT,
+         cv=cv,
+         n_jobs=-1,
+         verbose=1,
+         return_train_score=True,
+     )
+
+     logger.info(f"[{variant} | {model_name}] Starting GridSearchCV …")
+     grid.fit(X_train, y_train)
+
+     logger.success(f"[{variant} | {model_name}] GridSearchCV completed.")
+     logger.info(f"[{variant} | {model_name}] Best params ({REFIT}): {grid.best_params_}")
+     logger.info(f"[{variant} | {model_name}] Best CV {REFIT}: {grid.best_score_:.4f}")
+
+     cv_results_path = reports_dir / "cv_results.csv"
+     df = pd.DataFrame(grid.cv_results_)
+     df.to_csv(cv_results_path, index=False)
+
+     mlflow.log_artifact(str(cv_results_path))
+     return grid.best_estimator_, grid, grid.best_params_
+
+
+ def save_artifacts(
+     model,
+     grid,
+     X_train,
+     model_name: str,
+     variant: str,
+     model_dir: Path,
+     reports_dir: Path,
+ ) -> None:
+     """Save model, parameters, and metadata to disk and MLflow."""
+     model_dir.mkdir(parents=True, exist_ok=True)
+     reports_dir.mkdir(parents=True, exist_ok=True)
+
+     model_path = model_dir / f"{model_name}.joblib"
+     joblib.dump(model, model_path)
+     logger.success(f"[{variant} | {model_name}] Saved model → {model_path}")
+
+     out = {
+         "model_name": model_name,
+         "data_variant": variant,
+         "cv": {
+             "refit": REFIT,
+             "best_score": getattr(grid, "best_score_", None),
+             "best_params": getattr(grid, "best_params_", None),
+             "scoring": list(SCORING.keys()),
+             "n_splits": N_SPLITS,
+             "random_state": RANDOM_STATE,
+         },
+         "features": list(X_train.columns),
+     }
+
+     cv_params_path = reports_dir / "cv_parameters.json"
+     with open(cv_params_path, "w", encoding="utf-8") as f:
+         json.dump(out, f, indent=4)
+
+     mlflow.log_artifact(str(cv_params_path))
+     logger.success(f"[{variant} | {model_name}] Saved artifacts.")
+
+
+ def train(model_name: str, variant: str):
+     """Train a model for a specific dataset variant and log results to MLflow."""
+     experiment_name = f"{EXPERIMENT_NAME}_{variant}"
+     if not mlflow.get_experiment_by_name(experiment_name):
+         mlflow.create_experiment(experiment_name)
+     mlflow.set_experiment(experiment_name)
+
+     train_path = PROCESSED_DATA_DIR / variant / "train.csv"
+     run_name = f"{model_name}_{variant}"
+
+     logger.info(f"=== Training started (model={model_name}, variant={variant}) ===")
+
+     with mlflow.start_run(run_name=run_name):
+         train_df = load_split(train_path)
+
+         rawdata = mlflow.data.from_pandas(train_df, name=f"{DATASET_NAME}_{variant}")
+         mlflow.log_input(rawdata, context="training")
+
+         X_train = train_df.drop(columns=[TARGET_COL])
+         y_train = train_df[TARGET_COL].astype(int)
+
+         X_train, y_train = apply_random_oversampling(
+             X_train,
+             y_train,
+             model_name=model_name,
+             variant=variant,
+         )
+
+         estimator, param_grid = get_model_and_grid(model_name)
+         mlflow.set_tag("estimator_name", estimator.__class__.__name__)
+         mlflow.set_tag("data_variant", variant)
+         mlflow.log_param("data_variant", variant)
+
+         model_dir = MODELS_DIR / variant
+         reports_dir = REPORTS_DIR / variant / model_name
+         reports_dir.mkdir(parents=True, exist_ok=True)
+
+         best_model, grid, params = run_grid_search(
+             estimator,
+             param_grid,
+             X_train,
+             y_train,
+             model_name=model_name,
+             variant=variant,
+             reports_dir=reports_dir,
+         )
+         mlflow.log_params(params)
+
+         save_artifacts(
+             best_model,
+             grid,
+             X_train,
+             model_name=model_name,
+             variant=variant,
+             model_dir=model_dir,
+             reports_dir=reports_dir,
+         )
+
+     logger.success(f"=== Training completed (model={model_name}, variant={variant}) ===")
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--variant",
+         type=str,
+         choices=VALID_VARIANTS,
+         required=True,
+         help="Data variant to use: all, female, male, or nosex.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         choices=VALID_MODELS,
+         required=True,
+         help="Model to train: logreg, random_forest, or decision_tree.",
+     )
+     args = parser.parse_args()
+
+     dagshub.init(repo_owner=REPO_OWNER, repo_name=REPO_NAME, mlflow=True)
+     train(args.model, args.variant)
+
+
+ if __name__ == "__main__":
+     main()
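Editor's note: the CLI trains one (model, variant) pair per run; a sketch for sweeping the full matrix, assuming DagsHub credentials are already configured for MLflow tracking:

```python
import dagshub

from predicting_outcomes_in_heart_failure.config import (
    REPO_NAME,
    REPO_OWNER,
    VALID_MODELS,
    VALID_VARIANTS,
)
from predicting_outcomes_in_heart_failure.modeling.train import train

dagshub.init(repo_owner=REPO_OWNER, repo_name=REPO_NAME, mlflow=True)
for variant in VALID_VARIANTS:
    for model_name in VALID_MODELS:
        train(model_name, variant)  # 4 variants x 3 models = 12 runs
```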
reports/figures/.gitkeep ADDED
File without changes
reports/nosex/random_forest/cv_parameters.json ADDED
@@ -0,0 +1,40 @@
+ {
+     "model_name": "random_forest",
+     "data_variant": "nosex",
+     "cv": {
+         "refit": "f1",
+         "best_score": 0.8616030330699311,
+         "best_params": {
+             "class_weight": "balanced",
+             "max_depth": 12,
+             "n_estimators": 800
+         },
+         "scoring": [
+             "accuracy",
+             "f1",
+             "recall",
+             "roc_auc"
+         ],
+         "n_splits": 5,
+         "random_state": 42
+     },
+     "features": [
+         "Age",
+         "RestingBP",
+         "Cholesterol",
+         "FastingBS",
+         "MaxHR",
+         "ExerciseAngina",
+         "Oldpeak",
+         "ChestPainType_ASY",
+         "ChestPainType_ATA",
+         "ChestPainType_NAP",
+         "ChestPainType_TA",
+         "RestingECG_LVH",
+         "RestingECG_Normal",
+         "RestingECG_ST",
+         "ST_Slope_Down",
+         "ST_Slope_Flat",
+         "ST_Slope_Up"
+     ]
+ }
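Editor's note: the `best_params` recorded above can be replayed into an estimator directly; a sketch reconstructing the winning configuration (random_state taken from the report):

```python
from sklearn.ensemble import RandomForestClassifier

# Values copied from cv_parameters.json above
best_params = {"class_weight": "balanced", "max_depth": 12, "n_estimators": 800}
model = RandomForestClassifier(random_state=42, **best_params)
# model.fit(X_train, y_train) would reproduce the selected configuration
```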