Commit ·
000de75
0
Parent(s):
Inital Upload
Browse files- .gitignore +27 -0
- MODEL_CARD.md +263 -0
- README.md +263 -0
- lookups/Training_20221125_labs_lookup.csv +23 -0
- training/check_index_date_dist.py +149 -0
- training/combine_features.py +179 -0
- training/config.yaml +45 -0
- training/create_sh_lookup_table.py +48 -0
- training/cross_val_final_models.py +673 -0
- training/cross_val_first_models.py +460 -0
- training/encode_and_impute.py +286 -0
- training/encoding.py +281 -0
- training/imputation.py +255 -0
- training/model_h.py +2061 -0
- training/perform_forward_validation.py +296 -0
- training/perform_hyper_param_tuning.py +215 -0
- training/process_comorbidities.py +109 -0
- training/process_demographics.py +75 -0
- training/process_exacerbation_history.py +297 -0
- training/process_labs.py +228 -0
- training/process_pros.py +1031 -0
- training/process_spirometry.py +116 -0
- training/pros_multiple_time_windows.py +618 -0
- training/setup_labels_forward_val.py +643 -0
- training/setup_labels_hosp_comm.py +935 -0
- training/setup_labels_only_hosp.py +338 -0
- training/split_train_test_val.py +108 -0
- training/splitting.py +292 -0
.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Folders for model cohort data, training data plots and logs
|
| 2 |
+
data/
|
| 3 |
+
training/logging/
|
| 4 |
+
plots/
|
| 5 |
+
tmp/
|
| 6 |
+
training/explore.ipynb
|
| 7 |
+
|
| 8 |
+
# Byte-compiled / optimized / DLL files
|
| 9 |
+
training/__pycache__/
|
| 10 |
+
|
| 11 |
+
# Environments
|
| 12 |
+
.venv
|
| 13 |
+
.venvdowhy
|
| 14 |
+
|
| 15 |
+
# VS Code
|
| 16 |
+
.vscode
|
| 17 |
+
|
| 18 |
+
#MLFlow
|
| 19 |
+
mlruns/
|
| 20 |
+
mlruns.db
|
| 21 |
+
|
| 22 |
+
#Catboost
|
| 23 |
+
catboost_info/
|
| 24 |
+
|
| 25 |
+
#Dowhy
|
| 26 |
+
training/dowhy_2.py
|
| 27 |
+
training/dowhy_example.py
|
MODEL_CARD.md
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
tags:
|
| 5 |
+
- healthcare
|
| 6 |
+
- ehr
|
| 7 |
+
- copd
|
| 8 |
+
- clinical-risk
|
| 9 |
+
- tabular
|
| 10 |
+
- scikit-learn
|
| 11 |
+
- xgboost
|
| 12 |
+
- lightgbm
|
| 13 |
+
- catboost
|
| 14 |
+
- patient-reported-outcomes
|
| 15 |
+
pipeline_tag: tabular-classification
|
| 16 |
+
library_name: sklearn
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# COPD Open Models — Model H (90-Day Exacerbation Prediction)
|
| 20 |
+
|
| 21 |
+
## Model Details
|
| 22 |
+
|
| 23 |
+
Model H predicts the risk of a COPD exacerbation within **90 days** using features derived from NHS EHR datasets and patient-reported outcomes (PROs). It serves as both a production-grade prediction pipeline and a **template project** for building new COPD prediction models, featuring a reusable 2,000+ line core library (`model_h.py`) with end-to-end training, calibration, evaluation, and SHAP explainability.
|
| 24 |
+
|
| 25 |
+
### Key Characteristics
|
| 26 |
+
|
| 27 |
+
- **Comprehensive PRO integration** — the most detailed PRO feature engineering in the portfolio, covering four instruments: EQ-5D (monthly), MRC dyspnoea (weekly), CAT (daily), and Symptom Diary (daily), each with engagement metrics, score differences, and multi-window aggregations.
|
| 28 |
+
- **13 algorithms screened** in the first phase, narrowed to top 3 with Bayesian hyperparameter tuning.
|
| 29 |
+
- **Forward validation** on 9 months of prospective data (May 2023 – February 2024).
|
| 30 |
+
- **Reusable core library** (`model_h.py`) — 40+ functions for label setup, feature engineering, model evaluation, calibration, and SHAP explainability.
|
| 31 |
+
- Training code is fully decoupled from cloud infrastructure — runs locally with no Azure dependencies.
|
| 32 |
+
|
| 33 |
+
> **Note:** This repository contains no real patient-level data. All included data files are synthetic or example data for pipeline validation.
|
| 34 |
+
|
| 35 |
+
### Model Type
|
| 36 |
+
|
| 37 |
+
Traditional tabular ML classifiers (multiple candidate estimators; see "Training Procedure").
|
| 38 |
+
|
| 39 |
+
### Release Notes
|
| 40 |
+
|
| 41 |
+
- **Phase 1 (current):** Models C, E, H published as the initial "COPD Open Models" collection.
|
| 42 |
+
- **Phase 2 (planned):** Additional models may follow after codebase sanitisation.
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## Intended Use
|
| 47 |
+
|
| 48 |
+
This model and code are published as **reference implementations** for research, education, and benchmarking on COPD prediction tasks.
|
| 49 |
+
|
| 50 |
+
### Intended Users
|
| 51 |
+
|
| 52 |
+
- ML practitioners exploring tabular healthcare ML pipelines
|
| 53 |
+
- Researchers comparing feature engineering and evaluation approaches
|
| 54 |
+
- Developers building internal prototypes (non-clinical)
|
| 55 |
+
|
| 56 |
+
### Out-of-Scope Uses
|
| 57 |
+
|
| 58 |
+
- **Not** for clinical decision-making, triage, diagnosis, or treatment planning.
|
| 59 |
+
- **Not** a substitute for clinical judgement or validated clinical tools.
|
| 60 |
+
- Do **not** deploy in healthcare settings without an appropriate regulatory, clinical safety, and information governance framework.
|
| 61 |
+
|
| 62 |
+
### Regulatory Considerations (SaMD)
|
| 63 |
+
|
| 64 |
+
Regulatory status for software depends on the intended purpose expressed in documentation, labelling, and promotional materials. Downstream users integrating or deploying this model should determine whether their implementation qualifies as Software as a Medical Device (SaMD) and identify the legal "manufacturer" responsible for compliance and post-market obligations.
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Training Data
|
| 69 |
+
|
| 70 |
+
- **Source:** NHS EHR-derived datasets and Lenus COPD Service PRO data (training performed on controlled datasets; not distributed here).
|
| 71 |
+
- **Data available in this repo:** Synthetic/example datasets only.
|
| 72 |
+
- **Cohort:** COPD patients with RECEIVER and Scale-Up cohort membership.
|
| 73 |
+
- **Target:** Binary — `ExacWithin3Months` (hospital + community exacerbations) or `HospExacWithin3Months` (hospital only).
|
| 74 |
+
- **Configuration:** 90-day prediction window, 180-day lookback, 5-fold cross-validation.
|
| 75 |
+
|
| 76 |
+
### Features
|
| 77 |
+
|
| 78 |
+
| Category | Features |
|
| 79 |
+
|----------|----------|
|
| 80 |
+
| **Demographics** | Age (binned: <50/50-59/60-69/70-79/80+), Sex_F |
|
| 81 |
+
| **Comorbidities** | AsthmaOverlap (binary), Comorbidities count (binned: None/1-2/3+) |
|
| 82 |
+
| **Exacerbation History** | Hospital and community exac counts in lookback, days since last exac, recency-weighted counts |
|
| 83 |
+
| **Spirometry** | FEV1, FVC, FEV1/FVC ratio — max, min, and latest values |
|
| 84 |
+
| **Laboratory (20+ tests)** | MaxLifetime, MinLifetime, Max1Year, Min1Year, Latest values with recency weighting (decay_rate=0.001) — WBC, RBC, haemoglobin, haematocrit, platelets, sodium, potassium, creatinine, albumin, glucose, ALT, AST, GGT, bilirubin, ALP, cholesterol, triglycerides, TSH, and more |
|
| 85 |
+
| **EQ-5D (monthly)** | Q1–Q5, total score, latest values, engagement rates, score changes |
|
| 86 |
+
| **MRC Dyspnoea (weekly)** | MRC Score (1–5), latest value, engagement, variations |
|
| 87 |
+
| **CAT (daily)** | Q1–Q8, total score (0–40), latest values, engagement, score differences |
|
| 88 |
+
| **Symptom Diary (daily)** | Q5 rescue medication (binary), weekly aggregates, engagement rates |
|
| 89 |
+
|
| 90 |
+
### Data Preprocessing
|
| 91 |
+
|
| 92 |
+
1. **Target encoding** — K-fold encoding with smoothing for categorical features (Age, Comorbidities, FEV1 severity, smoking status, etc.).
|
| 93 |
+
2. **Imputation** — median/mean/mode imputation strategies, applied per-fold.
|
| 94 |
+
3. **Scaling** — MinMaxScaler to [0, 1], fit on training fold only.
|
| 95 |
+
4. **PRO LOGIC filtering** — 14-day minimum between exacerbation episodes, 2 consecutive negative Q5 responses required for borderline events (14–35 days apart).
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Training Procedure
|
| 100 |
+
|
| 101 |
+
### Training Framework
|
| 102 |
+
|
| 103 |
+
- pandas, scikit-learn, imbalanced-learn, xgboost, lightgbm, catboost
|
| 104 |
+
- Hyperparameter tuning: scikit-optimize (BayesSearchCV)
|
| 105 |
+
- Explainability: SHAP (TreeExplainer)
|
| 106 |
+
- Experiment tracking: MLflow
|
| 107 |
+
|
| 108 |
+
### Algorithms Evaluated
|
| 109 |
+
|
| 110 |
+
**First Phase (13 model types):**
|
| 111 |
+
|
| 112 |
+
| Algorithm | Library |
|
| 113 |
+
|-----------|---------|
|
| 114 |
+
| DummyClassifier (baseline) | sklearn |
|
| 115 |
+
| Logistic Regression | sklearn |
|
| 116 |
+
| Logistic Regression (balanced) | sklearn |
|
| 117 |
+
| Random Forest | sklearn |
|
| 118 |
+
| Random Forest (balanced) | sklearn |
|
| 119 |
+
| Balanced Random Forest | imblearn |
|
| 120 |
+
| Balanced Bagging | imblearn |
|
| 121 |
+
| XGBoost (7 variants) | xgboost |
|
| 122 |
+
| LightGBM (2 variants) | lightgbm |
|
| 123 |
+
| CatBoost | catboost |
|
| 124 |
+
|
| 125 |
+
**Hyperparameter Tuning Search Spaces:**
|
| 126 |
+
|
| 127 |
+
| Algorithm | Parameters Tuned |
|
| 128 |
+
|-----------|-----------------|
|
| 129 |
+
| Logistic Regression | penalty, class_weight, max_iter (50–300), C (0.001–10) |
|
| 130 |
+
| Random Forest | max_depth (4–10), n_estimators (70–850), min_samples_split (2–10), class_weight |
|
| 131 |
+
| XGBoost | max_depth (4–10), n_estimators (70–850) |
|
| 132 |
+
|
| 133 |
+
**Final Phase:** Top 3 models (Balanced Random Forest, XGBoost, Random Forest) retrained with tuned hyperparameters.
|
| 134 |
+
|
| 135 |
+
### Evaluation Design
|
| 136 |
+
|
| 137 |
+
- **5-fold** cross-validation with per-fold preprocessing.
|
| 138 |
+
- Metrics evaluated at threshold 0.5 and at best-F1 threshold.
|
| 139 |
+
- Event-type breakdown: hospital vs. community exacerbations evaluated separately.
|
| 140 |
+
- **Forward validation:** 9 months of prospective data (May 2023 – February 2024), assessed with KS test and Wasserstein distance for distribution shift.
|
| 141 |
+
|
| 142 |
+
### Calibration
|
| 143 |
+
|
| 144 |
+
- **Sigmoid** (Platt scaling)
|
| 145 |
+
- **Isotonic** regression
|
| 146 |
+
- Applied via CalibratedClassifierCV with per-fold calibration.
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## Evaluation Results
|
| 151 |
+
|
| 152 |
+
> Replace this section with measured results from your training run.
|
| 153 |
+
|
| 154 |
+
| Metric | Value | Notes |
|
| 155 |
+
|--------|-------|-------|
|
| 156 |
+
| ROC-AUC | TBD | Cross-validation mean (± std) |
|
| 157 |
+
| AUC-PR | TBD | Primary metric for imbalanced outcome |
|
| 158 |
+
| F1 Score (@ 0.5) | TBD | Default threshold |
|
| 159 |
+
| Best F1 Score | TBD | At optimal threshold |
|
| 160 |
+
| Balanced Accuracy | TBD | Cross-validation mean |
|
| 161 |
+
| Brier Score | TBD | Probability calibration quality |
|
| 162 |
+
|
| 163 |
+
### Caveats on Metrics
|
| 164 |
+
|
| 165 |
+
- Performance depends heavily on cohort definition, PRO engagement rates, and label construction.
|
| 166 |
+
- Forward validation results may differ from cross-validation due to temporal shifts in data availability and coding practices.
|
| 167 |
+
- Reported metrics from controlled datasets may not transfer to other settings without recalibration and validation.
|
| 168 |
+
|
| 169 |
+
---
|
| 170 |
+
|
| 171 |
+
## Bias, Risks, and Limitations
|
| 172 |
+
|
| 173 |
+
- **Dataset shift:** EHR coding practices, PRO engagement, and population characteristics vary across sites and time periods.
|
| 174 |
+
- **PRO engagement bias:** Patients who engage more with digital health tools may differ systematically from non-engagers.
|
| 175 |
+
- **Label uncertainty:** Exacerbation events are constructed via PRO LOGIC — different definitions produce different results.
|
| 176 |
+
- **Fairness:** Outcomes and feature availability may vary by age, sex, deprivation, comorbidity burden, or service access.
|
| 177 |
+
- **Misuse risk:** Using predictions to drive clinical action without clinical safety processes can cause harm through false positives and negatives.
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## How to Use
|
| 182 |
+
|
| 183 |
+
### Pipeline Execution Order
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
# 1. Install dependencies
|
| 187 |
+
pip install pandas numpy scikit-learn imbalanced-learn xgboost lightgbm catboost scikit-optimize shap mlflow matplotlib seaborn pyyaml joblib scipy
|
| 188 |
+
|
| 189 |
+
# 2. Set up labels (choose one)
|
| 190 |
+
python training/setup_labels_hosp_comm.py # hospital + community exacerbations
|
| 191 |
+
python training/setup_labels_only_hosp.py # hospital only
|
| 192 |
+
python training/setup_labels_forward_val.py # forward validation set
|
| 193 |
+
|
| 194 |
+
# 3. Split data
|
| 195 |
+
python training/split_train_test_val.py
|
| 196 |
+
|
| 197 |
+
# 4. Feature engineering (run in sequence)
|
| 198 |
+
python training/process_demographics.py
|
| 199 |
+
python training/process_comorbidities.py
|
| 200 |
+
python training/process_exacerbation_history.py
|
| 201 |
+
python training/process_spirometry.py
|
| 202 |
+
python training/process_labs.py
|
| 203 |
+
python training/process_pros.py
|
| 204 |
+
|
| 205 |
+
# 5. Combine features
|
| 206 |
+
python training/combine_features.py
|
| 207 |
+
|
| 208 |
+
# 6. Encode and impute
|
| 209 |
+
python training/encode_and_impute.py
|
| 210 |
+
|
| 211 |
+
# 7. Screen algorithms
|
| 212 |
+
python training/cross_val_first_models.py
|
| 213 |
+
|
| 214 |
+
# 8. Hyperparameter tuning
|
| 215 |
+
python training/perform_hyper_param_tuning.py
|
| 216 |
+
|
| 217 |
+
# 9. Final cross-validation with best models
|
| 218 |
+
python training/cross_val_final_models.py
|
| 219 |
+
|
| 220 |
+
# 10. Forward validation (optional)
|
| 221 |
+
python training/perform_forward_validation.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
### Configuration
|
| 225 |
+
|
| 226 |
+
Edit `config.yaml` to adjust:
|
| 227 |
+
- `prediction_window` (default: 90 days)
|
| 228 |
+
- `lookback_period` (default: 180 days)
|
| 229 |
+
- `model_type` ('hosp_comm' or 'only_hosp')
|
| 230 |
+
- `num_folds` (default: 5)
|
| 231 |
+
- Input/output data paths
|
| 232 |
+
|
| 233 |
+
### Core Library
|
| 234 |
+
|
| 235 |
+
`model_h.py` provides 40+ reusable functions for:
|
| 236 |
+
- PRO LOGIC exacerbation validation
|
| 237 |
+
- Recency-weighted feature engineering
|
| 238 |
+
- Model evaluation (F1, PR-AUC, ROC-AUC, Brier, calibration curves)
|
| 239 |
+
- SHAP explainability (summary, local, interaction, decision plots)
|
| 240 |
+
- Calibration (sigmoid, isotonic, spline)
|
| 241 |
+
|
| 242 |
+
---
|
| 243 |
+
|
| 244 |
+
## Environmental Impact
|
| 245 |
+
|
| 246 |
+
Training computational requirements are minimal — all models are traditional tabular ML classifiers running on CPU. A full pipeline run (feature engineering through cross-validation) completes in minutes on a standard laptop.
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
## Citation
|
| 251 |
+
|
| 252 |
+
If you use this model or code, please cite:
|
| 253 |
+
|
| 254 |
+
- This repository: *(add citation format / Zenodo DOI if minted)*
|
| 255 |
+
- Associated publications: *(clinical trial results paper — forthcoming)*
|
| 256 |
+
|
| 257 |
+
## Authors and Contributors
|
| 258 |
+
|
| 259 |
+
- **Storm ID** (maintainers)
|
| 260 |
+
|
| 261 |
+
## License
|
| 262 |
+
|
| 263 |
+
This model and code are released under the **Apache 2.0** license.
|
README.md
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
tags:
|
| 5 |
+
- healthcare
|
| 6 |
+
- ehr
|
| 7 |
+
- copd
|
| 8 |
+
- clinical-risk
|
| 9 |
+
- tabular
|
| 10 |
+
- scikit-learn
|
| 11 |
+
- xgboost
|
| 12 |
+
- lightgbm
|
| 13 |
+
- catboost
|
| 14 |
+
- patient-reported-outcomes
|
| 15 |
+
pipeline_tag: tabular-classification
|
| 16 |
+
library_name: sklearn
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# COPD Open Models — Model H (90-Day Exacerbation Prediction)
|
| 20 |
+
|
| 21 |
+
## Model Details
|
| 22 |
+
|
| 23 |
+
Model H predicts the risk of a COPD exacerbation within **90 days** using features derived from NHS EHR datasets and patient-reported outcomes (PROs). It serves as both a production-grade prediction pipeline and a **template project** for building new COPD prediction models, featuring a reusable 2,000+ line core library (`model_h.py`) with end-to-end training, calibration, evaluation, and SHAP explainability.
|
| 24 |
+
|
| 25 |
+
### Key Characteristics
|
| 26 |
+
|
| 27 |
+
- **Comprehensive PRO integration** — the most detailed PRO feature engineering in the portfolio, covering four instruments: EQ-5D (monthly), MRC dyspnoea (weekly), CAT (daily), and Symptom Diary (daily), each with engagement metrics, score differences, and multi-window aggregations.
|
| 28 |
+
- **13 algorithms screened** in the first phase, narrowed to top 3 with Bayesian hyperparameter tuning.
|
| 29 |
+
- **Forward validation** on 9 months of prospective data (May 2023 – February 2024).
|
| 30 |
+
- **Reusable core library** (`model_h.py`) — 40+ functions for label setup, feature engineering, model evaluation, calibration, and SHAP explainability.
|
| 31 |
+
- Training code is fully decoupled from cloud infrastructure — runs locally with no Azure dependencies.
|
| 32 |
+
|
| 33 |
+
> **Note:** This repository contains no real patient-level data. All included data files are synthetic or example data for pipeline validation.
|
| 34 |
+
|
| 35 |
+
### Model Type
|
| 36 |
+
|
| 37 |
+
Traditional tabular ML classifiers (multiple candidate estimators; see "Training Procedure").
|
| 38 |
+
|
| 39 |
+
### Release Notes
|
| 40 |
+
|
| 41 |
+
- **Phase 1 (current):** Models C, E, H published as the initial "COPD Open Models" collection.
|
| 42 |
+
- **Phase 2 (planned):** Additional models may follow after codebase sanitisation.
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## Intended Use
|
| 47 |
+
|
| 48 |
+
This model and code are published as **reference implementations** for research, education, and benchmarking on COPD prediction tasks.
|
| 49 |
+
|
| 50 |
+
### Intended Users
|
| 51 |
+
|
| 52 |
+
- ML practitioners exploring tabular healthcare ML pipelines
|
| 53 |
+
- Researchers comparing feature engineering and evaluation approaches
|
| 54 |
+
- Developers building internal prototypes (non-clinical)
|
| 55 |
+
|
| 56 |
+
### Out-of-Scope Uses
|
| 57 |
+
|
| 58 |
+
- **Not** for clinical decision-making, triage, diagnosis, or treatment planning.
|
| 59 |
+
- **Not** a substitute for clinical judgement or validated clinical tools.
|
| 60 |
+
- Do **not** deploy in healthcare settings without an appropriate regulatory, clinical safety, and information governance framework.
|
| 61 |
+
|
| 62 |
+
### Regulatory Considerations (SaMD)
|
| 63 |
+
|
| 64 |
+
Regulatory status for software depends on the intended purpose expressed in documentation, labelling, and promotional materials. Downstream users integrating or deploying this model should determine whether their implementation qualifies as Software as a Medical Device (SaMD) and identify the legal "manufacturer" responsible for compliance and post-market obligations.
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Training Data
|
| 69 |
+
|
| 70 |
+
- **Source:** NHS EHR-derived datasets and Lenus COPD Service PRO data (training performed on controlled datasets; not distributed here).
|
| 71 |
+
- **Data available in this repo:** Synthetic/example datasets only.
|
| 72 |
+
- **Cohort:** COPD patients with RECEIVER and Scale-Up cohort membership.
|
| 73 |
+
- **Target:** Binary — `ExacWithin3Months` (hospital + community exacerbations) or `HospExacWithin3Months` (hospital only).
|
| 74 |
+
- **Configuration:** 90-day prediction window, 180-day lookback, 5-fold cross-validation.
|
| 75 |
+
|
| 76 |
+
### Features
|
| 77 |
+
|
| 78 |
+
| Category | Features |
|
| 79 |
+
|----------|----------|
|
| 80 |
+
| **Demographics** | Age (binned: <50/50-59/60-69/70-79/80+), Sex_F |
|
| 81 |
+
| **Comorbidities** | AsthmaOverlap (binary), Comorbidities count (binned: None/1-2/3+) |
|
| 82 |
+
| **Exacerbation History** | Hospital and community exac counts in lookback, days since last exac, recency-weighted counts |
|
| 83 |
+
| **Spirometry** | FEV1, FVC, FEV1/FVC ratio — max, min, and latest values |
|
| 84 |
+
| **Laboratory (20+ tests)** | MaxLifetime, MinLifetime, Max1Year, Min1Year, Latest values with recency weighting (decay_rate=0.001) — WBC, RBC, haemoglobin, haematocrit, platelets, sodium, potassium, creatinine, albumin, glucose, ALT, AST, GGT, bilirubin, ALP, cholesterol, triglycerides, TSH, and more |
|
| 85 |
+
| **EQ-5D (monthly)** | Q1–Q5, total score, latest values, engagement rates, score changes |
|
| 86 |
+
| **MRC Dyspnoea (weekly)** | MRC Score (1–5), latest value, engagement, variations |
|
| 87 |
+
| **CAT (daily)** | Q1–Q8, total score (0–40), latest values, engagement, score differences |
|
| 88 |
+
| **Symptom Diary (daily)** | Q5 rescue medication (binary), weekly aggregates, engagement rates |
|
| 89 |
+
|
| 90 |
+
### Data Preprocessing
|
| 91 |
+
|
| 92 |
+
1. **Target encoding** — K-fold encoding with smoothing for categorical features (Age, Comorbidities, FEV1 severity, smoking status, etc.).
|
| 93 |
+
2. **Imputation** — median/mean/mode imputation strategies, applied per-fold.
|
| 94 |
+
3. **Scaling** — MinMaxScaler to [0, 1], fit on training fold only.
|
| 95 |
+
4. **PRO LOGIC filtering** — 14-day minimum between exacerbation episodes, 2 consecutive negative Q5 responses required for borderline events (14–35 days apart).
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Training Procedure
|
| 100 |
+
|
| 101 |
+
### Training Framework
|
| 102 |
+
|
| 103 |
+
- pandas, scikit-learn, imbalanced-learn, xgboost, lightgbm, catboost
|
| 104 |
+
- Hyperparameter tuning: scikit-optimize (BayesSearchCV)
|
| 105 |
+
- Explainability: SHAP (TreeExplainer)
|
| 106 |
+
- Experiment tracking: MLflow
|
| 107 |
+
|
| 108 |
+
### Algorithms Evaluated
|
| 109 |
+
|
| 110 |
+
**First Phase (13 model types):**
|
| 111 |
+
|
| 112 |
+
| Algorithm | Library |
|
| 113 |
+
|-----------|---------|
|
| 114 |
+
| DummyClassifier (baseline) | sklearn |
|
| 115 |
+
| Logistic Regression | sklearn |
|
| 116 |
+
| Logistic Regression (balanced) | sklearn |
|
| 117 |
+
| Random Forest | sklearn |
|
| 118 |
+
| Random Forest (balanced) | sklearn |
|
| 119 |
+
| Balanced Random Forest | imblearn |
|
| 120 |
+
| Balanced Bagging | imblearn |
|
| 121 |
+
| XGBoost (7 variants) | xgboost |
|
| 122 |
+
| LightGBM (2 variants) | lightgbm |
|
| 123 |
+
| CatBoost | catboost |
|
| 124 |
+
|
| 125 |
+
**Hyperparameter Tuning Search Spaces:**
|
| 126 |
+
|
| 127 |
+
| Algorithm | Parameters Tuned |
|
| 128 |
+
|-----------|-----------------|
|
| 129 |
+
| Logistic Regression | penalty, class_weight, max_iter (50–300), C (0.001–10) |
|
| 130 |
+
| Random Forest | max_depth (4–10), n_estimators (70–850), min_samples_split (2–10), class_weight |
|
| 131 |
+
| XGBoost | max_depth (4–10), n_estimators (70–850) |
|
| 132 |
+
|
| 133 |
+
**Final Phase:** Top 3 models (Balanced Random Forest, XGBoost, Random Forest) retrained with tuned hyperparameters.
|
| 134 |
+
|
| 135 |
+
### Evaluation Design
|
| 136 |
+
|
| 137 |
+
- **5-fold** cross-validation with per-fold preprocessing.
|
| 138 |
+
- Metrics evaluated at threshold 0.5 and at best-F1 threshold.
|
| 139 |
+
- Event-type breakdown: hospital vs. community exacerbations evaluated separately.
|
| 140 |
+
- **Forward validation:** 9 months of prospective data (May 2023 – February 2024), assessed with KS test and Wasserstein distance for distribution shift.
|
| 141 |
+
|
| 142 |
+
### Calibration
|
| 143 |
+
|
| 144 |
+
- **Sigmoid** (Platt scaling)
|
| 145 |
+
- **Isotonic** regression
|
| 146 |
+
- Applied via CalibratedClassifierCV with per-fold calibration.
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## Evaluation Results
|
| 151 |
+
|
| 152 |
+
> Replace this section with measured results from your training run.
|
| 153 |
+
|
| 154 |
+
| Metric | Value | Notes |
|
| 155 |
+
|--------|-------|-------|
|
| 156 |
+
| ROC-AUC | TBD | Cross-validation mean (± std) |
|
| 157 |
+
| AUC-PR | TBD | Primary metric for imbalanced outcome |
|
| 158 |
+
| F1 Score (@ 0.5) | TBD | Default threshold |
|
| 159 |
+
| Best F1 Score | TBD | At optimal threshold |
|
| 160 |
+
| Balanced Accuracy | TBD | Cross-validation mean |
|
| 161 |
+
| Brier Score | TBD | Probability calibration quality |
|
| 162 |
+
|
| 163 |
+
### Caveats on Metrics
|
| 164 |
+
|
| 165 |
+
- Performance depends heavily on cohort definition, PRO engagement rates, and label construction.
|
| 166 |
+
- Forward validation results may differ from cross-validation due to temporal shifts in data availability and coding practices.
|
| 167 |
+
- Reported metrics from controlled datasets may not transfer to other settings without recalibration and validation.
|
| 168 |
+
|
| 169 |
+
---
|
| 170 |
+
|
| 171 |
+
## Bias, Risks, and Limitations
|
| 172 |
+
|
| 173 |
+
- **Dataset shift:** EHR coding practices, PRO engagement, and population characteristics vary across sites and time periods.
|
| 174 |
+
- **PRO engagement bias:** Patients who engage more with digital health tools may differ systematically from non-engagers.
|
| 175 |
+
- **Label uncertainty:** Exacerbation events are constructed via PRO LOGIC — different definitions produce different results.
|
| 176 |
+
- **Fairness:** Outcomes and feature availability may vary by age, sex, deprivation, comorbidity burden, or service access.
|
| 177 |
+
- **Misuse risk:** Using predictions to drive clinical action without clinical safety processes can cause harm through false positives and negatives.
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## How to Use
|
| 182 |
+
|
| 183 |
+
### Pipeline Execution Order
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
# 1. Install dependencies
|
| 187 |
+
pip install pandas numpy scikit-learn imbalanced-learn xgboost lightgbm catboost scikit-optimize shap mlflow matplotlib seaborn pyyaml joblib scipy
|
| 188 |
+
|
| 189 |
+
# 2. Set up labels (choose one)
|
| 190 |
+
python training/setup_labels_hosp_comm.py # hospital + community exacerbations
|
| 191 |
+
python training/setup_labels_only_hosp.py # hospital only
|
| 192 |
+
python training/setup_labels_forward_val.py # forward validation set
|
| 193 |
+
|
| 194 |
+
# 3. Split data
|
| 195 |
+
python training/split_train_test_val.py
|
| 196 |
+
|
| 197 |
+
# 4. Feature engineering (run in sequence)
|
| 198 |
+
python training/process_demographics.py
|
| 199 |
+
python training/process_comorbidities.py
|
| 200 |
+
python training/process_exacerbation_history.py
|
| 201 |
+
python training/process_spirometry.py
|
| 202 |
+
python training/process_labs.py
|
| 203 |
+
python training/process_pros.py
|
| 204 |
+
|
| 205 |
+
# 5. Combine features
|
| 206 |
+
python training/combine_features.py
|
| 207 |
+
|
| 208 |
+
# 6. Encode and impute
|
| 209 |
+
python training/encode_and_impute.py
|
| 210 |
+
|
| 211 |
+
# 7. Screen algorithms
|
| 212 |
+
python training/cross_val_first_models.py
|
| 213 |
+
|
| 214 |
+
# 8. Hyperparameter tuning
|
| 215 |
+
python training/perform_hyper_param_tuning.py
|
| 216 |
+
|
| 217 |
+
# 9. Final cross-validation with best models
|
| 218 |
+
python training/cross_val_final_models.py
|
| 219 |
+
|
| 220 |
+
# 10. Forward validation (optional)
|
| 221 |
+
python training/perform_forward_validation.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
### Configuration
|
| 225 |
+
|
| 226 |
+
Edit `config.yaml` to adjust:
|
| 227 |
+
- `prediction_window` (default: 90 days)
|
| 228 |
+
- `lookback_period` (default: 180 days)
|
| 229 |
+
- `model_type` ('hosp_comm' or 'only_hosp')
|
| 230 |
+
- `num_folds` (default: 5)
|
| 231 |
+
- Input/output data paths
|
| 232 |
+
|
| 233 |
+
### Core Library
|
| 234 |
+
|
| 235 |
+
`model_h.py` provides 40+ reusable functions for:
|
| 236 |
+
- PRO LOGIC exacerbation validation
|
| 237 |
+
- Recency-weighted feature engineering
|
| 238 |
+
- Model evaluation (F1, PR-AUC, ROC-AUC, Brier, calibration curves)
|
| 239 |
+
- SHAP explainability (summary, local, interaction, decision plots)
|
| 240 |
+
- Calibration (sigmoid, isotonic, spline)
|
| 241 |
+
|
| 242 |
+
---
|
| 243 |
+
|
| 244 |
+
## Environmental Impact
|
| 245 |
+
|
| 246 |
+
Training computational requirements are minimal — all models are traditional tabular ML classifiers running on CPU. A full pipeline run (feature engineering through cross-validation) completes in minutes on a standard laptop.
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
## Citation
|
| 251 |
+
|
| 252 |
+
If you use this model or code, please cite:
|
| 253 |
+
|
| 254 |
+
- This repository: *(add citation format / Zenodo DOI if minted)*
|
| 255 |
+
- Associated publications: *(clinical trial results paper — forthcoming)*
|
| 256 |
+
|
| 257 |
+
## Authors and Contributors
|
| 258 |
+
|
| 259 |
+
- **Storm ID** (maintainers)
|
| 260 |
+
|
| 261 |
+
## License
|
| 262 |
+
|
| 263 |
+
This model and code are released under the **Apache 2.0** license.
|
lookups/Training_20221125_labs_lookup.csv
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ClinicalCodeDescription,RefUnit,DataQ10,DataQ50,DataQ90,DataMean,DataStd
|
| 2 |
+
ALT,u/l,8.0,17.0,45.0,30.243673995141023,113.66715870850466
|
| 3 |
+
AST,u/l,12.0,20.0,45.0,33.897970974873104,172.29514450554794
|
| 4 |
+
Albumin,g/l,23.0,33.0,40.0,32.143139689111685,6.696146405869818
|
| 5 |
+
Alkaline Phosphatase,u/l,61.0,93.0,172.0,115.04396818241,112.95974512695292
|
| 6 |
+
Basophils,10^9/l,0.0,0.01,0.1,0.04231942316043947,0.0731083879073512
|
| 7 |
+
C Reactive Protein,mg/l,3.0,24.0,153.0,54.79311137263134,72.88535292424446
|
| 8 |
+
Calcium,mmol/l,1.99,2.27,2.47,2.248002582795482,0.1994094643731489
|
| 9 |
+
Chloride,mmol/l,95.0,103.0,108.0,102.1250032055317,5.439721834421554
|
| 10 |
+
Cholesterol,mmol/l,3.3,4.5,6.3,4.702769260607771,1.2336072058201566
|
| 11 |
+
Eosinophils,10^9/l,0.0,0.14,0.4,0.1964911476952577,0.24370462126922973
|
| 12 |
+
Estimated GFR,ml/min,34.0,60.0,60.0,53.78249759753863,12.453638969866372
|
| 13 |
+
Glucose,mmol/l,4.6,6.1,10.8,7.236088370730932,3.9897001667960987
|
| 14 |
+
Haematocrit,l/l,0.298,0.381,0.455,0.3790491791553389,0.062262607997835374
|
| 15 |
+
Haemoglobin,g/l,94.0,123.0,149.0,122.11866599070132,21.51488967036221
|
| 16 |
+
Lymphocytes,10^9/l,0.7,1.5,2.8,1.7822859737658994,4.080028422225104
|
| 17 |
+
Mean Cell Volume,fl,84.7,93.1,102.2,93.31201943900349,7.3459980963707805
|
| 18 |
+
Monocytes,10^9/l,0.4,0.7,1.2,0.7565321844650618,0.5511385273529443
|
| 19 |
+
Neutrophils,10^9/l,3.1,5.8,11.5,6.777581745391032,4.192513421557571
|
| 20 |
+
Platelet Count,10^9/l,156.0,264.0,423.0,280.4851990035415,117.75695119538601
|
| 21 |
+
Sodium,mmol/l,132.0,138.0,142.0,137.67649793633612,4.375384346285572
|
| 22 |
+
Total Bilirubin,umol/l,4.0,8.0,19.0,11.707785816461053,18.33306195550475
|
| 23 |
+
White Blood Count,10^9/l,5.3,8.6,14.5,9.597243934380035,8.600370594737681
|
training/check_index_date_dist.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Using a master random seed, create random seeds to iterate through. These random seeds are
|
| 3 |
+
then used to assign a different random seed to every patient to avoid similar index dates
|
| 4 |
+
being generated amongst patients who were enrolled to the service at the same time.
|
| 5 |
+
Histograms were checked and the general random seed that provided the most uniform monthly
|
| 6 |
+
distribution of index dates was chosen for the final index date generation in
|
| 7 |
+
setup_labels_hosp_comm.py and setup_labels_only_hosp.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from datetime import timedelta
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
import random
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
# Open patient details file to calculate index dates
|
| 18 |
+
patient_details = pd.read_csv("./data/pat_details_to_calc_index_dt.csv")
|
| 19 |
+
exac_data = pd.read_pickle("./data/hosp_comm_exacs.pkl")
|
| 20 |
+
|
| 21 |
+
pat_details_date_cols = ["EarliestIndexDate", "EarliestIndexAfterGap"]
|
| 22 |
+
for col in pat_details_date_cols:
|
| 23 |
+
patient_details[col] = pd.to_datetime(patient_details[col])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Using a master seed, generate random seeds to use for iterations
|
| 27 |
+
master_seed = 42
|
| 28 |
+
random.seed(master_seed)
|
| 29 |
+
general_seeds = random.sample(range(0, 2**32), 200)
|
| 30 |
+
|
| 31 |
+
# For each iteration, use a different general seed to create random seeds for each patient
|
| 32 |
+
for general_seed in general_seeds:
|
| 33 |
+
random.seed(general_seed)
|
| 34 |
+
|
| 35 |
+
# Create different random seeds for each patient
|
| 36 |
+
patient_details["RandomSeed"] = random.sample(
|
| 37 |
+
range(0, 2**32), patient_details.shape[0]
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Create random index dates for each patient based on their random seed
|
| 41 |
+
full_index_date_df = pd.DataFrame()
|
| 42 |
+
for index, row in patient_details.iterrows():
|
| 43 |
+
excluded_index_dates_close_to_exac = False
|
| 44 |
+
iter_num = 0
|
| 45 |
+
while excluded_index_dates_close_to_exac is False:
|
| 46 |
+
np.random.seed(row["RandomSeed"] + iter_num)
|
| 47 |
+
rand_days_dict = {}
|
| 48 |
+
rand_date_dict = {}
|
| 49 |
+
|
| 50 |
+
# Generate random number of days
|
| 51 |
+
rand_days_dict[row["StudyId"]] = np.random.choice(
|
| 52 |
+
row["LengthInService"], size=row["NumRows"], replace=False
|
| 53 |
+
)
|
| 54 |
+
rand_date_dict[row["StudyId"]] = []
|
| 55 |
+
|
| 56 |
+
# Using the random days generated, calculate the date
|
| 57 |
+
for day in rand_days_dict[row["StudyId"]]:
|
| 58 |
+
if day <= row["NumDaysPossibleIndex"]:
|
| 59 |
+
rand_date_dict[row["StudyId"]].append(
|
| 60 |
+
row["EarliestIndexDate"] + timedelta(days=int(day))
|
| 61 |
+
)
|
| 62 |
+
else:
|
| 63 |
+
rand_date_dict[row["StudyId"]].append(
|
| 64 |
+
row["EarliestIndexAfterGap"]
|
| 65 |
+
+ timedelta(days=int(day - row["NumDaysPossibleIndex"]))
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Get exacerbation info to exclude any exacerbations that occurred within 14 days
|
| 69 |
+
# before an index date
|
| 70 |
+
exac_event_per_patient = exac_data[
|
| 71 |
+
(exac_data["StudyId"] == row["StudyId"]) & (exac_data["IsExac"] == 1)
|
| 72 |
+
][["StudyId", "DateOfEvent", "IsExac"]]
|
| 73 |
+
|
| 74 |
+
# Create df from dictionaries containing random index dates
|
| 75 |
+
index_date_df = pd.DataFrame.from_dict(
|
| 76 |
+
rand_date_dict, orient="index"
|
| 77 |
+
).reset_index()
|
| 78 |
+
index_date_df = index_date_df.rename(columns={"index": "StudyId"})
|
| 79 |
+
|
| 80 |
+
# Convert the multiple columns containing index dates to one column
|
| 81 |
+
index_date_df = (
|
| 82 |
+
pd.melt(index_date_df, id_vars=["StudyId"], value_name="IndexDate")
|
| 83 |
+
.drop(["variable"], axis=1)
|
| 84 |
+
.sort_values(by=["StudyId", "IndexDate"])
|
| 85 |
+
)
|
| 86 |
+
index_date_df = index_date_df.dropna()
|
| 87 |
+
index_date_df = index_date_df.reset_index(drop=True)
|
| 88 |
+
|
| 89 |
+
# Calculate the time to event from exac date (DateOfEvent) to index date
|
| 90 |
+
exac_event_per_patient = exac_event_per_patient.merge(
|
| 91 |
+
index_date_df, on="StudyId", how="outer"
|
| 92 |
+
)
|
| 93 |
+
exac_event_per_patient["IndexDate"] = pd.to_datetime(
|
| 94 |
+
exac_event_per_patient["IndexDate"], utc=True
|
| 95 |
+
)
|
| 96 |
+
exac_event_per_patient["TimeToEvent"] = (
|
| 97 |
+
exac_event_per_patient["DateOfEvent"]
|
| 98 |
+
- exac_event_per_patient["IndexDate"]
|
| 99 |
+
).dt.days
|
| 100 |
+
|
| 101 |
+
# End while loop if there are no index dates within 14 days before index date. If
|
| 102 |
+
# there are, continue loop until there isn't
|
| 103 |
+
if (
|
| 104 |
+
not exac_event_per_patient["TimeToEvent"]
|
| 105 |
+
.between(-14, 0, inclusive="both")
|
| 106 |
+
.any()
|
| 107 |
+
):
|
| 108 |
+
excluded_index_dates_close_to_exac = True
|
| 109 |
+
full_index_date_df = pd.concat([full_index_date_df, index_date_df])
|
| 110 |
+
else:
|
| 111 |
+
iter_num = iter_num + 1
|
| 112 |
+
|
| 113 |
+
# Check distribution of generated index dates
|
| 114 |
+
full_index_date_df["IndexYear"] = full_index_date_df["IndexDate"].dt.year
|
| 115 |
+
full_index_date_df["IndexMonth"] = full_index_date_df["IndexDate"].dt.month
|
| 116 |
+
full_index_date_df["IndexMonth"].hist(bins=10, density=True)
|
| 117 |
+
|
| 118 |
+
os.makedirs("./plots/index_date", exist_ok=True)
|
| 119 |
+
plt.savefig(
|
| 120 |
+
"./plots/index_date/seed_" + str(general_seed) + "_hist.png",
|
| 121 |
+
bbox_inches="tight",
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
dates_groupby = full_index_date_df.groupby(by=["IndexYear", "IndexMonth"]).count()[
|
| 125 |
+
["StudyId"]
|
| 126 |
+
]
|
| 127 |
+
month_groupby = (
|
| 128 |
+
full_index_date_df.groupby("IndexMonth").count()[["StudyId"]].reset_index()
|
| 129 |
+
)
|
| 130 |
+
year_groupby = (
|
| 131 |
+
full_index_date_df.groupby("IndexYear").count()[["StudyId"]].reset_index()
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
dates_groupby.plot.bar(y="StudyId")
|
| 135 |
+
plt.savefig(
|
| 136 |
+
"./plots/index_date/seed_" + str(general_seed) + "_month_year.png",
|
| 137 |
+
bbox_inches="tight",
|
| 138 |
+
)
|
| 139 |
+
month_groupby.plot.bar(x="IndexMonth", y="StudyId")
|
| 140 |
+
plt.savefig(
|
| 141 |
+
"./plots/index_date/seed_" + str(general_seed) + "_month.png",
|
| 142 |
+
bbox_inches="tight",
|
| 143 |
+
)
|
| 144 |
+
year_groupby.plot.bar(x="IndexYear", y="StudyId")
|
| 145 |
+
plt.savefig(
|
| 146 |
+
"./plots/index_date/seed_" + str(general_seed) + "_year.png",
|
| 147 |
+
bbox_inches="tight",
|
| 148 |
+
)
|
| 149 |
+
plt.close(fig="all")
|
training/combine_features.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Script that combines features, performs encoding of categorical features and imputation.
|
| 2 |
+
|
| 3 |
+
Demographics, exacerbation history, comorbidities, spirometry, labs, and pro datasets
|
| 4 |
+
combined. Splitting of dataset performed if the data_to_process specified in config.yaml is
|
| 5 |
+
not forward_val. Performs encoding of categorical features, and imputation of missing
|
| 6 |
+
values. Two versions of the data is saved: imputed and not imputed dataframes.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import yaml
|
| 14 |
+
import json
|
| 15 |
+
import joblib
|
| 16 |
+
import encoding
|
| 17 |
+
import imputation
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
with open("./training/config.yaml", "r") as config:
|
| 21 |
+
config = yaml.safe_load(config)
|
| 22 |
+
|
| 23 |
+
# Specify which model to generate features for
|
| 24 |
+
model_type = config["model_settings"]["model_type"]
|
| 25 |
+
|
| 26 |
+
# Setup log file
|
| 27 |
+
log = open("./training/logging/combine_features_" + model_type + ".log", "w")
|
| 28 |
+
sys.stdout = log
|
| 29 |
+
|
| 30 |
+
# Dataset to process - set through config file
|
| 31 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 32 |
+
|
| 33 |
+
############################################################################
|
| 34 |
+
# Combine features
|
| 35 |
+
############################################################################
|
| 36 |
+
|
| 37 |
+
# Load cohort data
|
| 38 |
+
if data_to_process == "forward_val":
|
| 39 |
+
demographics = pd.read_pickle(
|
| 40 |
+
os.path.join(
|
| 41 |
+
config["outputs"]["processed_data_dir"],
|
| 42 |
+
"demographics_forward_val_{}.pkl".format(model_type),
|
| 43 |
+
)
|
| 44 |
+
)
|
| 45 |
+
exac_history = pd.read_pickle(
|
| 46 |
+
os.path.join(
|
| 47 |
+
config["outputs"]["processed_data_dir"],
|
| 48 |
+
"exac_history_forward_val_{}.pkl".format(model_type),
|
| 49 |
+
)
|
| 50 |
+
)
|
| 51 |
+
comorbidities = pd.read_pickle(
|
| 52 |
+
os.path.join(
|
| 53 |
+
config["outputs"]["processed_data_dir"],
|
| 54 |
+
"comorbidities_forward_val_{}.pkl".format(model_type),
|
| 55 |
+
)
|
| 56 |
+
)
|
| 57 |
+
spirometry = pd.read_pickle(
|
| 58 |
+
os.path.join(
|
| 59 |
+
config["outputs"]["processed_data_dir"],
|
| 60 |
+
"spirometry_forward_val_{}.pkl".format(model_type),
|
| 61 |
+
)
|
| 62 |
+
)
|
| 63 |
+
labs = pd.read_pickle(
|
| 64 |
+
os.path.join(
|
| 65 |
+
config["outputs"]["processed_data_dir"],
|
| 66 |
+
"labs_forward_val_{}.pkl".format(model_type),
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
pros = pd.read_pickle(
|
| 70 |
+
os.path.join(
|
| 71 |
+
config["outputs"]["processed_data_dir"],
|
| 72 |
+
"pros_forward_val_{}.pkl".format(model_type),
|
| 73 |
+
)
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
demographics = pd.read_pickle(
|
| 77 |
+
os.path.join(
|
| 78 |
+
config["outputs"]["processed_data_dir"],
|
| 79 |
+
"demographics_{}.pkl".format(model_type),
|
| 80 |
+
)
|
| 81 |
+
)
|
| 82 |
+
exac_history = pd.read_pickle(
|
| 83 |
+
os.path.join(
|
| 84 |
+
config["outputs"]["processed_data_dir"],
|
| 85 |
+
"exac_history_{}.pkl".format(model_type),
|
| 86 |
+
)
|
| 87 |
+
)
|
| 88 |
+
comorbidities = pd.read_pickle(
|
| 89 |
+
os.path.join(
|
| 90 |
+
config["outputs"]["processed_data_dir"],
|
| 91 |
+
"comorbidities_{}.pkl".format(model_type),
|
| 92 |
+
)
|
| 93 |
+
)
|
| 94 |
+
spirometry = pd.read_pickle(
|
| 95 |
+
os.path.join(
|
| 96 |
+
config["outputs"]["processed_data_dir"],
|
| 97 |
+
"spirometry_{}.pkl".format(model_type),
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
labs = pd.read_pickle(
|
| 101 |
+
os.path.join(
|
| 102 |
+
config["outputs"]["processed_data_dir"], "labs_{}.pkl".format(model_type)
|
| 103 |
+
)
|
| 104 |
+
)
|
| 105 |
+
pros = pd.read_pickle(
|
| 106 |
+
os.path.join(
|
| 107 |
+
config["outputs"]["processed_data_dir"], "pros_{}.pkl".format(model_type)
|
| 108 |
+
)
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
data_combined = demographics.merge(
|
| 112 |
+
exac_history, on=["StudyId", "IndexDate"], how="left"
|
| 113 |
+
)
|
| 114 |
+
data_combined = data_combined.merge(
|
| 115 |
+
comorbidities, on=["StudyId", "IndexDate"], how="left"
|
| 116 |
+
)
|
| 117 |
+
data_combined = data_combined.merge(spirometry, on=["StudyId", "IndexDate"], how="left")
|
| 118 |
+
data_combined = data_combined.merge(labs, on=["StudyId", "IndexDate"], how="left")
|
| 119 |
+
data_combined = data_combined.merge(pros, on=["StudyId", "IndexDate"], how="left")
|
| 120 |
+
|
| 121 |
+
# Print dataset info
|
| 122 |
+
print(
|
| 123 |
+
"Data date range",
|
| 124 |
+
data_combined["IndexDate"].min(),
|
| 125 |
+
data_combined["IndexDate"].max(),
|
| 126 |
+
)
|
| 127 |
+
print("Mean age", data_combined["Age"].mean())
|
| 128 |
+
print("Sex Female:", data_combined["Sex_F"].value_counts())
|
| 129 |
+
|
| 130 |
+
if data_to_process != "forward_val":
|
| 131 |
+
# Load training and test ids
|
| 132 |
+
train_ids = pd.read_pickle(
|
| 133 |
+
os.path.join(
|
| 134 |
+
config["outputs"]["cohort_info_dir"], "train_ids_{}.pkl".format(model_type)
|
| 135 |
+
)
|
| 136 |
+
)
|
| 137 |
+
test_ids = pd.read_pickle(
|
| 138 |
+
os.path.join(
|
| 139 |
+
config["outputs"]["cohort_info_dir"], "test_ids_{}.pkl".format(model_type)
|
| 140 |
+
)
|
| 141 |
+
)
|
| 142 |
+
fold_patients = np.load(
|
| 143 |
+
os.path.join(
|
| 144 |
+
config["outputs"]["cohort_info_dir"],
|
| 145 |
+
"fold_patients_{}.npy".format(model_type),
|
| 146 |
+
),
|
| 147 |
+
allow_pickle=True,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Split data into training and test sets
|
| 151 |
+
train_data = data_combined[data_combined["StudyId"].isin(train_ids)]
|
| 152 |
+
test_data = data_combined[data_combined["StudyId"].isin(test_ids)]
|
| 153 |
+
train_data = train_data.sort_values(by=["StudyId", "IndexDate"]).reset_index(
|
| 154 |
+
drop=True
|
| 155 |
+
)
|
| 156 |
+
test_data = test_data.sort_values(by=["StudyId", "IndexDate"]).reset_index(
|
| 157 |
+
drop=True
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Save data
|
| 161 |
+
train_data.to_pickle(
|
| 162 |
+
os.path.join(
|
| 163 |
+
config["outputs"]["processed_data_dir"],
|
| 164 |
+
"train_combined_{}.pkl".format(model_type),
|
| 165 |
+
)
|
| 166 |
+
)
|
| 167 |
+
test_data.to_pickle(
|
| 168 |
+
os.path.join(
|
| 169 |
+
config["outputs"]["processed_data_dir"],
|
| 170 |
+
"test_combined_{}.pkl".format(model_type),
|
| 171 |
+
)
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
data_combined.to_pickle(
|
| 175 |
+
os.path.join(
|
| 176 |
+
config["outputs"]["processed_data_dir"],
|
| 177 |
+
"forward_val_combined_{}.pkl".format(model_type),
|
| 178 |
+
)
|
| 179 |
+
)
|
training/config.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_settings:
|
| 2 |
+
prediction_window: 90
|
| 3 |
+
lookback_period: 180
|
| 4 |
+
seed: 0
|
| 5 |
+
index_date_generation_master_seed: 2188398760
|
| 6 |
+
pro_logic_min_days_after_exac: 14
|
| 7 |
+
pro_logic_max_days_after_exac: 35
|
| 8 |
+
neg_consecutive_q5_replies: 2
|
| 9 |
+
model_type: 'hosp_comm'
|
| 10 |
+
latest_date_before_bug_break: "2022-03-08"
|
| 11 |
+
after_bug_fixed_start_date: "2022-07-07"
|
| 12 |
+
training_data_end_date: "2023-05-23"
|
| 13 |
+
pro_q5_change_date: "2021-04-22"
|
| 14 |
+
forward_validation_earliest_date: "2023-05-23"
|
| 15 |
+
forward_validation_latest_date: "2024-02-20"
|
| 16 |
+
one_row_per_days_in_service: 150
|
| 17 |
+
num_folds: 5
|
| 18 |
+
data_to_process: 'test'
|
| 19 |
+
inputs:
|
| 20 |
+
raw_data_paths:
|
| 21 |
+
receiver_cohort: "<YOUR_DATA_PATH>/EXAMPLE_STUDY_DATA/Cohort3Rand.csv"
|
| 22 |
+
scale_up_cohort: "<YOUR_DATA_PATH>/SU_IDs/Scale_Up_lookup.csv"
|
| 23 |
+
patient_details: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetPatientDetails.txt"
|
| 24 |
+
patient_events: "<YOUR_DATA_PATH>/copd-dataset/PatientEvents.txt"
|
| 25 |
+
comorbidities: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetCoMorbidityDetails.txt"
|
| 26 |
+
copd_status: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetCopdStatusDetails.txt"
|
| 27 |
+
inhalers: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetUsualTherapies.txt"
|
| 28 |
+
pro_symptom_diary: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProSymptomDiary.txt"
|
| 29 |
+
pro_eq5d: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProEQ5D.txt"
|
| 30 |
+
pro_mrc: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProMrc.txt"
|
| 31 |
+
pro_cat: "<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProCat.txt"
|
| 32 |
+
receiver_community_verified_events: "<YOUR_DATA_PATH>/LenusEvents/breakdown_of_com_exac.xlsx"
|
| 33 |
+
scale_up_community_verified_events: "<YOUR_DATA_PATH>/LenusEvents/Scale_Up_comm_exac_count.xlsx"
|
| 34 |
+
admissions: "<YOUR_DATA_PATH>/03_Training/SMR01.csv"
|
| 35 |
+
prescribing: "<YOUR_DATA_PATH>/03_Training/Pharmacy.csv"
|
| 36 |
+
labs: "<YOUR_DATA_PATH>/02_Training/SCI_Store.csv"
|
| 37 |
+
labs_lookup_table: "./lookups/Training_20221125_labs_lookup.csv"
|
| 38 |
+
sh_demographics: "<YOUR_DATA_PATH>/EXAMPLE_STUDY_DATA/Demographics_Cohort4.csv"
|
| 39 |
+
outputs:
|
| 40 |
+
output_data_dir: './data'
|
| 41 |
+
cohort_info_dir: "./data/cohort_info/"
|
| 42 |
+
logging_dir: './training/logging'
|
| 43 |
+
artifact_dir: './tmp'
|
| 44 |
+
processed_data_dir: './data/processed_data'
|
| 45 |
+
model_input_data_dir: './data/model_input_data'
|
training/create_sh_lookup_table.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import yaml
|
| 4 |
+
|
| 5 |
+
with open("./training/config.yaml", "r") as config:
|
| 6 |
+
config = yaml.safe_load(config)
|
| 7 |
+
|
| 8 |
+
# Read lookups for RECEIVER
|
| 9 |
+
receiver = pd.read_csv(config['inputs']['raw_data_paths']['receiver_cohort'])
|
| 10 |
+
receiver = receiver.rename(columns={"RNo": "StudyId"})
|
| 11 |
+
|
| 12 |
+
# Read lookups for Scale Up
|
| 13 |
+
scaleup = pd.read_csv(config['inputs']['raw_data_paths']['scale_up_cohort'])
|
| 14 |
+
scaleup = scaleup.rename(columns={"Study_Number": "StudyId"})
|
| 15 |
+
|
| 16 |
+
# Concatenate tables and drop missing SH IDs (some study patients not in data extract)
|
| 17 |
+
all_patients = pd.concat([receiver, scaleup]).dropna()
|
| 18 |
+
|
| 19 |
+
# Save final mapping between StudyId and SafeHavenID
|
| 20 |
+
all_patients.to_pickle(os.path.join(config['outputs']['output_data_dir'], "sh_to_studyid_mapping.pkl"))
|
| 21 |
+
|
| 22 |
+
# Check for matching age and sex between SafeHaven and Lenus data (mapping sanity check)
|
| 23 |
+
lenus_demographics = pd.read_csv(
|
| 24 |
+
config['inputs']['raw_data_paths']['patient_details'],
|
| 25 |
+
usecols=["StudyId", "DateOfBirth", "Sex"],
|
| 26 |
+
sep="|",
|
| 27 |
+
)
|
| 28 |
+
sh_demographics = pd.read_csv(
|
| 29 |
+
config['inputs']['raw_data_paths']['sh_demographics'],
|
| 30 |
+
usecols=["SafeHavenID", "SEX", "OBF_DOB"],
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
sh_demographics["OBF_DOB"] = pd.to_datetime(
|
| 34 |
+
sh_demographics["OBF_DOB"], utc=True
|
| 35 |
+
).dt.normalize()
|
| 36 |
+
|
| 37 |
+
mapping = all_patients.merge(sh_demographics, on="SafeHavenID", how="inner")
|
| 38 |
+
mapping = mapping.merge(lenus_demographics, on="StudyId", how="inner")
|
| 39 |
+
|
| 40 |
+
# Check patient sex matches
|
| 41 |
+
print(mapping[mapping.SEX != mapping.Sex])
|
| 42 |
+
# There is one mismatch
|
| 43 |
+
print(all_patients[all_patients.duplicated(subset="SafeHavenID", keep=False)])
|
| 44 |
+
|
| 45 |
+
# Check patient DOB matches
|
| 46 |
+
print(mapping[mapping.OBF_DOB != mapping.DateOfBirth])
|
| 47 |
+
|
| 48 |
+
print(mapping[mapping["StudyId"] == "SU126"])
|
training/cross_val_final_models.py
ADDED
|
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import model_h
|
| 6 |
+
import shutil
|
| 7 |
+
import pickle
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
# Plotting
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
|
| 13 |
+
# Model training and evaluation
|
| 14 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 15 |
+
from sklearn.model_selection import cross_validate, cross_val_predict
|
| 16 |
+
from sklearn.metrics import precision_recall_curve, auc
|
| 17 |
+
from sklearn.calibration import CalibratedClassifierCV
|
| 18 |
+
from imblearn.ensemble import BalancedRandomForestClassifier
|
| 19 |
+
import xgboost as xgb
|
| 20 |
+
import ml_insights as mli
|
| 21 |
+
import mlflow
|
| 22 |
+
|
| 23 |
+
# Explainability
|
| 24 |
+
from sklearn.inspection import permutation_importance
|
| 25 |
+
|
| 26 |
+
with open("./training/config.yaml", "r") as config:
|
| 27 |
+
config = yaml.safe_load(config)
|
| 28 |
+
|
| 29 |
+
model_type = config['model_settings']['model_type']
|
| 30 |
+
|
| 31 |
+
##############################################################
|
| 32 |
+
# Load data
|
| 33 |
+
##############################################################
|
| 34 |
+
# Setup log file
|
| 35 |
+
log = open(
|
| 36 |
+
os.path.join(config['outputs']['logging_dir'], "modelling_" + model_type + ".log"), "w")
|
| 37 |
+
sys.stdout = log
|
| 38 |
+
|
| 39 |
+
# Load CV folds
|
| 40 |
+
fold_patients = np.load(os.path.join(config['outputs']['cohort_info_dir'],
|
| 41 |
+
'fold_patients_' + model_type + '.npy'), allow_pickle=True)
|
| 42 |
+
|
| 43 |
+
# Load imputed crossval data
|
| 44 |
+
train_data_imp = model_h.load_data_for_modelling(os.path.join(
|
| 45 |
+
config["outputs"]["model_input_data_dir"],
|
| 46 |
+
"train_imputed_cv_{}.pkl".format(model_type),
|
| 47 |
+
))
|
| 48 |
+
|
| 49 |
+
# Load not imputed crossval data
|
| 50 |
+
train_data_no_imp = model_h.load_data_for_modelling(os.path.join(
|
| 51 |
+
config["outputs"]["model_input_data_dir"],
|
| 52 |
+
"train_not_imputed_cv_{}.pkl".format(model_type),
|
| 53 |
+
))
|
| 54 |
+
|
| 55 |
+
# Load imputed test data
|
| 56 |
+
test_data_imp = model_h.load_data_for_modelling(os.path.join(
|
| 57 |
+
config["outputs"]["model_input_data_dir"],
|
| 58 |
+
"test_imputed_{}.pkl".format(model_type),
|
| 59 |
+
))
|
| 60 |
+
|
| 61 |
+
# Load not imputed test data
|
| 62 |
+
test_data_no_imp = model_h.load_data_for_modelling(os.path.join(
|
| 63 |
+
config["outputs"]["model_input_data_dir"],
|
| 64 |
+
"test_not_imputed_{}.pkl".format(model_type),
|
| 65 |
+
))
|
| 66 |
+
|
| 67 |
+
# Load exac data
|
| 68 |
+
#train_exac_data = pd.read_pickle('./data/train_exac_data_' + model_type + '.pkl')
|
| 69 |
+
#test_exac_data = pd.read_pickle('./data/test_exac_data_' + model_type + '.pkl')
|
| 70 |
+
|
| 71 |
+
# Print date ranges for train and test set
|
| 72 |
+
print('Train date range',
|
| 73 |
+
train_data_imp['IndexDate'].min(), train_data_imp['IndexDate'].max())
|
| 74 |
+
print('Test date range',
|
| 75 |
+
test_data_imp['IndexDate'].min(), test_data_imp['IndexDate'].max())
|
| 76 |
+
|
| 77 |
+
# Set tags
|
| 78 |
+
tags = {"prediction_window": config['model_settings']['prediction_window'],
|
| 79 |
+
"lookback_period": config['model_settings']['lookback_period'],
|
| 80 |
+
"min_index_date": train_data_imp['IndexDate'].min(),
|
| 81 |
+
"max_index_date": train_data_imp['IndexDate'].max(),
|
| 82 |
+
"1_row_per_length_in_service_days": config['model_settings']['one_row_per_days_in_service'],
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
# Create a tuple with training and validation indicies for each fold. Can be done with
|
| 86 |
+
# either imputed or not imputed data as both have same patients
|
| 87 |
+
cross_val_fold_indices = []
|
| 88 |
+
for fold in fold_patients:
|
| 89 |
+
fold_val_ids = train_data_no_imp[train_data_no_imp.StudyId.isin(fold)]
|
| 90 |
+
fold_train_ids = train_data_no_imp[~(
|
| 91 |
+
train_data_no_imp.StudyId.isin(fold_val_ids.StudyId))]
|
| 92 |
+
|
| 93 |
+
# Get index of rows in val and train
|
| 94 |
+
fold_val_index = fold_val_ids.index
|
| 95 |
+
fold_train_index = fold_train_ids.index
|
| 96 |
+
|
| 97 |
+
# Append tuple of training and val indices
|
| 98 |
+
cross_val_fold_indices.append((fold_train_index, fold_val_index))
|
| 99 |
+
|
| 100 |
+
# Create list of model features
|
| 101 |
+
cols_to_drop = ['StudyId', 'ExacWithin3Months', 'IndexDate', 'HospExacWithin3Months',
|
| 102 |
+
'CommExacWithin3Months']
|
| 103 |
+
features_list = [col for col in train_data_no_imp.columns if col not in cols_to_drop]
|
| 104 |
+
|
| 105 |
+
### Train data ###
|
| 106 |
+
# Separate features from target for data with no imputation performed
|
| 107 |
+
train_features_no_imp = train_data_no_imp[features_list].astype('float')
|
| 108 |
+
train_target_no_imp = train_data_no_imp.ExacWithin3Months.astype('float')
|
| 109 |
+
# Separate features from target for data with no imputation performed
|
| 110 |
+
train_features_imp = train_data_imp[features_list].astype('float')
|
| 111 |
+
train_target_imp = train_data_imp.ExacWithin3Months.astype('float')
|
| 112 |
+
|
| 113 |
+
### Test data ###
|
| 114 |
+
# Separate features from target for data with no imputation performed
|
| 115 |
+
test_features_no_imp = test_data_no_imp[features_list].astype('float')
|
| 116 |
+
test_target_no_imp = test_data_no_imp.ExacWithin3Months.astype('float')
|
| 117 |
+
# Separate features from target for data with no imputation performed
|
| 118 |
+
test_features_imp = test_data_imp[features_list].astype('float')
|
| 119 |
+
test_target_imp = test_data_imp.ExacWithin3Months.astype('float')
|
| 120 |
+
|
| 121 |
+
# Check that the target in imputed and not imputed datasets are the same. If not,
|
| 122 |
+
# raise an error
|
| 123 |
+
if not train_target_no_imp.equals(train_target_imp):
|
| 124 |
+
raise ValueError(
|
| 125 |
+
'Target variable is not the same in imputed and non imputed datasets in the train set.')
|
| 126 |
+
if not test_target_no_imp.equals(test_target_imp):
|
| 127 |
+
raise ValueError(
|
| 128 |
+
'Target variable is not the same in imputed and non imputed datasets in the test set.')
|
| 129 |
+
train_target = train_target_no_imp
|
| 130 |
+
test_target = test_target_no_imp
|
| 131 |
+
|
| 132 |
+
# Make sure all features are numeric
|
| 133 |
+
for features in [train_features_no_imp, train_features_imp,
|
| 134 |
+
test_features_no_imp, test_features_imp]:
|
| 135 |
+
for col in features:
|
| 136 |
+
features[col] = pd.to_numeric(features[col], errors='coerce')
|
| 137 |
+
|
| 138 |
+
##############################################################
|
| 139 |
+
# Specify which models to evaluate
|
| 140 |
+
##############################################################
|
| 141 |
+
# Set up MLflow
|
| 142 |
+
mlflow.set_tracking_uri("sqlite:///mlruns.db")
|
| 143 |
+
mlflow.set_experiment('model_h_drop_1_' + model_type)
|
| 144 |
+
|
| 145 |
+
# Set CV scoring strategies and any model parameters
|
| 146 |
+
scoring = ['f1', 'balanced_accuracy', 'accuracy', 'precision', 'recall', 'roc_auc',
|
| 147 |
+
'average_precision', 'neg_brier_score']
|
| 148 |
+
|
| 149 |
+
# Set up models, each tuple contains 4 elements: model, model name, imputation status,
|
| 150 |
+
# type of model
|
| 151 |
+
models = []
|
| 152 |
+
# These models are run for both hospital exac model and hospital and community exac model
|
| 153 |
+
models.append((BalancedRandomForestClassifier(random_state=0),
|
| 154 |
+
'balanced_random_forest', 'imputed', 'tree'))
|
| 155 |
+
models.append((xgb.XGBClassifier(random_state=0, use_label_encoder=False,
|
| 156 |
+
eval_metric='logloss'),
|
| 157 |
+
'xgb', 'not_imputed', 'tree'))
|
| 158 |
+
models.append((RandomForestClassifier(),
|
| 159 |
+
'random_forest', 'imputed', 'tree'))
|
| 160 |
+
|
| 161 |
+
# Get the parent run where hyperparameter tuning was done
|
| 162 |
+
if model_type == 'only_hosp':
|
| 163 |
+
parent_run_id = 'ba2d7244654c4b84a815932a3167648f'
|
| 164 |
+
if model_type == 'hosp_comm':
|
| 165 |
+
parent_run_id = 'f71edd4c72f14c0692431dca297ec131'
|
| 166 |
+
|
| 167 |
+
##############################################################
|
| 168 |
+
# Run models
|
| 169 |
+
##############################################################
|
| 170 |
+
#In MLflow run, perform K-fold cross validation and capture mean score across folds.
|
| 171 |
+
with mlflow.start_run(run_name='hyperparameter_optimised_models_12'):
|
| 172 |
+
for model in models:
|
| 173 |
+
# Get parameters of best scoring models
|
| 174 |
+
best_params = model_h.get_mlflow_run_params(
|
| 175 |
+
model[1], parent_run_id, 'sqlite:///mlruns.db', model_type)
|
| 176 |
+
# Each model will have multiple best scores for different scoring metrics.
|
| 177 |
+
for n, scorer in enumerate(best_params):
|
| 178 |
+
params = best_params[scorer]
|
| 179 |
+
model[0].set_params(**params)
|
| 180 |
+
with mlflow.start_run(run_name=model[1] + '_tuning_scorer_' + scorer, nested=True):
|
| 181 |
+
print(model[1], scorer)
|
| 182 |
+
# Create the artifacts directory if it doesn't exist
|
| 183 |
+
os.makedirs(config['outputs']['artifact_dir'], exist_ok=True)
|
| 184 |
+
# Remove existing directory contents to not mix files between different runs
|
| 185 |
+
shutil.rmtree(config['outputs']['artifact_dir'])
|
| 186 |
+
|
| 187 |
+
# Select correct data based on whether model is using imputed or not imputed
|
| 188 |
+
# dataset
|
| 189 |
+
if model[2] == 'imputed':
|
| 190 |
+
train_features = train_features_imp
|
| 191 |
+
test_features = test_features_imp
|
| 192 |
+
train_data = train_data_imp
|
| 193 |
+
test_data = test_data_imp
|
| 194 |
+
else:
|
| 195 |
+
train_features = train_features_no_imp
|
| 196 |
+
test_features = test_features_no_imp
|
| 197 |
+
train_data = train_data_no_imp
|
| 198 |
+
test_data = test_data_no_imp
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
mlflow.set_tags(tags=tags)
|
| 202 |
+
|
| 203 |
+
# Perform K-fold cross validation with custom folds
|
| 204 |
+
crossval = cross_validate(model[0], train_features, train_target,
|
| 205 |
+
cv=cross_val_fold_indices,
|
| 206 |
+
return_estimator=True, scoring=scoring,
|
| 207 |
+
return_indices=True)
|
| 208 |
+
|
| 209 |
+
# Get the predicted probabilities from each models
|
| 210 |
+
probabilities_cv = cross_val_predict(model[0], train_features,
|
| 211 |
+
train_target,
|
| 212 |
+
cv=cross_val_fold_indices,
|
| 213 |
+
method='predict_proba')[:, 1]
|
| 214 |
+
|
| 215 |
+
# Evaluation for uncalibrated model - test set
|
| 216 |
+
for iter_num, estimator in enumerate(crossval['estimator']):
|
| 217 |
+
probs_test = estimator.predict_proba(test_features)[:,1]
|
| 218 |
+
preds_test = estimator.predict(test_features)
|
| 219 |
+
uncalib_metrics_test = model_h.calc_eval_metrics_for_model(
|
| 220 |
+
test_target, preds_test, probs_test, 'uncalib_test')
|
| 221 |
+
if iter_num == 0:
|
| 222 |
+
uncalib_metrics_test_df = pd.DataFrame(
|
| 223 |
+
uncalib_metrics_test, index=[iter_num])
|
| 224 |
+
else:
|
| 225 |
+
uncalib_metrics_test_df_iter = pd.DataFrame(
|
| 226 |
+
uncalib_metrics_test, index=[iter_num])
|
| 227 |
+
uncalib_metrics_test_df = pd.concat(
|
| 228 |
+
[uncalib_metrics_test_df, uncalib_metrics_test_df_iter])
|
| 229 |
+
uncalib_metrics_test_mean = uncalib_metrics_test_df.mean()
|
| 230 |
+
uncalib_metrics_test_mean = uncalib_metrics_test_mean.to_dict()
|
| 231 |
+
|
| 232 |
+
# Get threshold that gives best F1 score for uncalibrated model
|
| 233 |
+
best_thres_uncal, f1_bt, prec_bt, rec_bt = model_h.get_threshold_with_best_f1_score(
|
| 234 |
+
train_target, probabilities_cv)
|
| 235 |
+
# Save f1 score, precision and recall for the best threshold
|
| 236 |
+
mlflow.log_metric('best_thres_uncal', best_thres_uncal)
|
| 237 |
+
mlflow.log_metric('f1_best_thres', f1_bt)
|
| 238 |
+
mlflow.log_metric('precision_best_thres', prec_bt)
|
| 239 |
+
mlflow.log_metric('recall_best_thres', rec_bt)
|
| 240 |
+
|
| 241 |
+
#### Plot confusion matrix at different thresholds ####
|
| 242 |
+
model_h.plot_confusion_matrix(
|
| 243 |
+
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, best_thres_uncal], probabilities_cv,
|
| 244 |
+
train_target, model[1], model_type, 'uncalib')
|
| 245 |
+
|
| 246 |
+
#### Calculate AUC-PR score ####
|
| 247 |
+
precision, recall, thresholds = precision_recall_curve(
|
| 248 |
+
train_target, probabilities_cv)
|
| 249 |
+
auc_pr = auc(recall, precision)
|
| 250 |
+
mlflow.log_metric('auc_pr', auc_pr)
|
| 251 |
+
|
| 252 |
+
#### Generate calibration curves ####
|
| 253 |
+
if model[1] != 'dummy_classifier':
|
| 254 |
+
### Sigmoid calibration ###
|
| 255 |
+
# Perform calibration
|
| 256 |
+
model_sig = CalibratedClassifierCV(
|
| 257 |
+
model[0], method='sigmoid',cv=cross_val_fold_indices)
|
| 258 |
+
model_sig.fit(train_features, train_target)
|
| 259 |
+
probs_sig = model_sig.predict_proba(test_features)[:, 1]
|
| 260 |
+
probs_sig_2 = model_sig.predict_proba(test_features)
|
| 261 |
+
preds_sig = model_sig.predict(test_features)
|
| 262 |
+
# Generate metrics for calibrated model
|
| 263 |
+
calib_metrics_sig = model_h.calc_eval_metrics_for_model(
|
| 264 |
+
test_target, preds_sig, probs_sig, 'sig')
|
| 265 |
+
# Get threshold with best f1 score for calibrated model
|
| 266 |
+
best_thres_sig, _, _, _ = model_h.get_threshold_with_best_f1_score(
|
| 267 |
+
test_target, probs_sig)
|
| 268 |
+
mlflow.log_metric('best_thres_sig', best_thres_sig)
|
| 269 |
+
# Plot confusion matrices for calibrated model
|
| 270 |
+
model_h.plot_confusion_matrix(
|
| 271 |
+
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, best_thres_sig], probs_sig,
|
| 272 |
+
test_target, model[1], model_type, "sig")
|
| 273 |
+
# Plot score distribution for calibrated model
|
| 274 |
+
model_h.plot_score_distribution(
|
| 275 |
+
test_target, probs_sig, config['outputs']['artifact_dir'], model[1], model_type, 'sig')
|
| 276 |
+
# Calculate std of auc-pr between CV folds
|
| 277 |
+
model_h.calc_std_for_calibrated_classifiers(
|
| 278 |
+
model_sig, 'sig', test_features, test_target)
|
| 279 |
+
|
| 280 |
+
### Isotonic calibration ###
|
| 281 |
+
# Perform calibration
|
| 282 |
+
model_iso = CalibratedClassifierCV(
|
| 283 |
+
model[0], method='isotonic', cv=cross_val_fold_indices)
|
| 284 |
+
model_iso.fit(train_features, train_target)
|
| 285 |
+
probs_iso = model_iso.predict_proba(test_features)[:, 1]
|
| 286 |
+
preds_iso = model_iso.predict(test_features)
|
| 287 |
+
# Generate metrics for calibrated model
|
| 288 |
+
calib_metrics_iso = model_h.calc_eval_metrics_for_model(
|
| 289 |
+
test_target, preds_iso, probs_iso, 'iso')
|
| 290 |
+
# Get threshold with best f1 score for calibrated model
|
| 291 |
+
best_thres_iso, _, _, _ = model_h.get_threshold_with_best_f1_score(
|
| 292 |
+
test_target, probs_iso)
|
| 293 |
+
mlflow.log_metric('best_thres_iso', best_thres_iso)
|
| 294 |
+
# Plot confusion matrices for calibrated model
|
| 295 |
+
model_h.plot_confusion_matrix(
|
| 296 |
+
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, best_thres_iso], probs_iso,
|
| 297 |
+
test_target, model[1], model_type, "iso")
|
| 298 |
+
# Plot score distribution for calibrated model
|
| 299 |
+
model_h.plot_score_distribution(
|
| 300 |
+
test_target, probs_iso, config['outputs']['artifact_dir'], model[1], model_type, 'iso')
|
| 301 |
+
# Calculate std of auc-pr between CV folds
|
| 302 |
+
model_h.calc_std_for_calibrated_classifiers(
|
| 303 |
+
model_iso, 'iso', test_features, test_target)
|
| 304 |
+
|
| 305 |
+
### Spline calibration ###
|
| 306 |
+
# Perform calibration
|
| 307 |
+
spline_calib = mli.SplineCalib()
|
| 308 |
+
spline_calib.fit(probabilities_cv, train_target)
|
| 309 |
+
model[0].fit(train_features, train_target)
|
| 310 |
+
preds_test_uncalib = model[0].predict_proba(test_features)[:,1]
|
| 311 |
+
probs_spline = spline_calib.calibrate(preds_test_uncalib)
|
| 312 |
+
preds_spline = probs_spline > 0.5
|
| 313 |
+
preds_spline = preds_spline.astype(int)
|
| 314 |
+
# Generate metrics for calibrated model
|
| 315 |
+
calib_metrics_spline = model_h.calc_eval_metrics_for_model(
|
| 316 |
+
test_target, preds_spline, probs_spline, 'spline')
|
| 317 |
+
# Get threshold with best f1 score for calibrated model
|
| 318 |
+
best_thres_spline, _, _, _ = model_h.get_threshold_with_best_f1_score(
|
| 319 |
+
test_target, probs_spline)
|
| 320 |
+
mlflow.log_metric('best_thres_spline', best_thres_spline)
|
| 321 |
+
# Plot confusion matrices for calibrated model
|
| 322 |
+
model_h.plot_confusion_matrix(
|
| 323 |
+
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, best_thres_spline], probs_spline,
|
| 324 |
+
test_target, model[1], model_type, "spline")
|
| 325 |
+
# Plot score distribution for calibrated model
|
| 326 |
+
model_h.plot_score_distribution(
|
| 327 |
+
test_target, probs_spline, config['outputs']['artifact_dir'], model[1], model_type, 'spline')
|
| 328 |
+
|
| 329 |
+
### Plot calibration curves ###
|
| 330 |
+
# Plot calibration curves for equal width bins (each bin has same width)
|
| 331 |
+
# and equal frequency bins (each bin has same number of observations)
|
| 332 |
+
for strategy in ['uniform', 'quantile']:
|
| 333 |
+
for bins in [5, 6, 10]:
|
| 334 |
+
plt.figure(figsize=(8,8))
|
| 335 |
+
plt.plot([0, 1], [0, 1], linestyle='--')
|
| 336 |
+
model_h.plot_calibration_curve(
|
| 337 |
+
train_target, probabilities_cv, bins, strategy, 'Uncalibrated')
|
| 338 |
+
model_h.plot_calibration_curve(
|
| 339 |
+
test_target, probs_sig, bins, strategy,'Sigmoid')
|
| 340 |
+
model_h.plot_calibration_curve(
|
| 341 |
+
test_target, probs_iso, bins, strategy, 'Isotonic')
|
| 342 |
+
model_h.plot_calibration_curve(
|
| 343 |
+
test_target, probs_spline, bins, strategy, 'Spline')
|
| 344 |
+
plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
|
| 345 |
+
plt.title(model[1])
|
| 346 |
+
plt.tight_layout()
|
| 347 |
+
plt.savefig(
|
| 348 |
+
os.path.join(config['outputs']['artifact_dir'], model[1] +
|
| 349 |
+
'_' + strategy + '_bins' + str(bins) + '_' +
|
| 350 |
+
model_type + '.png'))
|
| 351 |
+
plt.close()
|
| 352 |
+
|
| 353 |
+
# Plot uncalibrated model calibration curve at different bins and
|
| 354 |
+
# strategies
|
| 355 |
+
fig, (ax1,ax2) = plt.subplots(ncols=2, sharex=True, figsize=(15,10))
|
| 356 |
+
#plt.figure(figsize=(8,8))
|
| 357 |
+
for ax in [ax1, ax2]:
|
| 358 |
+
ax.plot([0, 1], [0, 1], linestyle='--')
|
| 359 |
+
for bins in [5, 6, 7, 8, 9]:
|
| 360 |
+
model_h.plot_calibration_curve(
|
| 361 |
+
train_target, probabilities_cv, bins, 'quantile', 'Bins=' +
|
| 362 |
+
str(bins), ax1)
|
| 363 |
+
for bins in [5, 6, 7, 8, 9]:
|
| 364 |
+
model_h.plot_calibration_curve(
|
| 365 |
+
train_target, probabilities_cv, bins, 'uniform', 'Bins=' +
|
| 366 |
+
str(bins), ax2)
|
| 367 |
+
ax1.title.set_text(model[1] + ' uncalibrated model quantile bins')
|
| 368 |
+
ax2.title.set_text(model[1] + ' uncalibrated model uniform bins')
|
| 369 |
+
plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
|
| 370 |
+
plt.tight_layout()
|
| 371 |
+
plt.savefig(
|
| 372 |
+
os.path.join(config['outputs']['artifact_dir'], model[1] + '_uncal_'
|
| 373 |
+
+ model_type + '.png'))
|
| 374 |
+
plt.close()
|
| 375 |
+
|
| 376 |
+
# Plot calibration curves with error bars
|
| 377 |
+
model_h.plot_calibration_plot_with_error_bars(
|
| 378 |
+
probabilities_cv, probs_sig, probs_iso, probs_spline, train_target,
|
| 379 |
+
test_target, model[1])
|
| 380 |
+
plt.close()
|
| 381 |
+
|
| 382 |
+
#### Get total gain and total cover for boosting machine models ####
|
| 383 |
+
if model[1].startswith("xgb"):
|
| 384 |
+
feat_importance_tot_gain_df = model_h.plot_feat_importance_model(
|
| 385 |
+
model[0], model[1], model_type)
|
| 386 |
+
# Save feature importance by total gain
|
| 387 |
+
if model[1].startswith("xgb"):
|
| 388 |
+
feat_importance_tot_gain_df.to_csv(
|
| 389 |
+
'./data/feature_importance_tot_gain_' + model_type + '.csv', index=False)
|
| 390 |
+
|
| 391 |
+
#### Calculate model performance by event type ####
|
| 392 |
+
if model[1] not in ['dummy_classifier']:
|
| 393 |
+
# Create df to contain prediction data and event type data
|
| 394 |
+
preds_event_df_uncalib = model_h.create_df_probabilities_and_predictions(
|
| 395 |
+
probabilities_cv, best_thres_uncal,
|
| 396 |
+
train_data['StudyId'].tolist(),
|
| 397 |
+
train_target,
|
| 398 |
+
train_data[['ExacWithin3Months','HospExacWithin3Months','CommExacWithin3Months']],
|
| 399 |
+
model[1], model_type, output_dir='./data/prediction_and_events/')
|
| 400 |
+
preds_events_df_sig = model_h.create_df_probabilities_and_predictions(
|
| 401 |
+
probs_sig, best_thres_sig, test_data['StudyId'].tolist(),
|
| 402 |
+
test_target,
|
| 403 |
+
test_data[['ExacWithin3Months', 'HospExacWithin3Months','CommExacWithin3Months']],
|
| 404 |
+
model[1], model_type, output_dir='./data/prediction_and_events/',
|
| 405 |
+
calib_type='sig')
|
| 406 |
+
preds_events_df_iso = model_h.create_df_probabilities_and_predictions(
|
| 407 |
+
probs_iso, best_thres_iso, test_data['StudyId'].tolist(),
|
| 408 |
+
test_target,
|
| 409 |
+
test_data[['ExacWithin3Months', 'HospExacWithin3Months','CommExacWithin3Months']],
|
| 410 |
+
model[1], model_type, output_dir='./data/prediction_and_events/',
|
| 411 |
+
calib_type='iso')
|
| 412 |
+
preds_events_df_spline = model_h.create_df_probabilities_and_predictions(
|
| 413 |
+
probs_spline, best_thres_spline, test_data['StudyId'].tolist(),
|
| 414 |
+
test_target,
|
| 415 |
+
test_data[['ExacWithin3Months', 'HospExacWithin3Months','CommExacWithin3Months']],
|
| 416 |
+
model[1], model_type, output_dir='./data/prediction_and_events/',
|
| 417 |
+
calib_type='spline')
|
| 418 |
+
# Subset to each event type and calculate metrics
|
| 419 |
+
metrics_by_event_type_uncalib = model_h.calc_metrics_by_event_type(
|
| 420 |
+
preds_event_df_uncalib, calib_type="uncalib")
|
| 421 |
+
metrics_by_event_type_sig = model_h.calc_metrics_by_event_type(
|
| 422 |
+
preds_events_df_sig, calib_type='sig')
|
| 423 |
+
metrics_by_event_type_iso = model_h.calc_metrics_by_event_type(
|
| 424 |
+
preds_events_df_iso, calib_type='iso')
|
| 425 |
+
metrics_by_event_type_spline = model_h.calc_metrics_by_event_type(
|
| 426 |
+
preds_events_df_spline, calib_type='spline')
|
| 427 |
+
# Subset to each event type and plot ROC curve
|
| 428 |
+
model_h.plot_roc_curve_by_event_type(
|
| 429 |
+
preds_event_df_uncalib, model[1], 'uncalib')
|
| 430 |
+
model_h.plot_roc_curve_by_event_type(
|
| 431 |
+
preds_events_df_sig, model[1], 'sig')
|
| 432 |
+
model_h.plot_roc_curve_by_event_type(
|
| 433 |
+
preds_events_df_iso, model[1], 'iso')
|
| 434 |
+
model_h.plot_roc_curve_by_event_type(
|
| 435 |
+
preds_events_df_spline, model[1], 'spline')
|
| 436 |
+
# Subset to each event type and plot PR curve
|
| 437 |
+
model_h.plot_prec_recall_by_event_type(
|
| 438 |
+
preds_event_df_uncalib, model[1], 'uncalib')
|
| 439 |
+
model_h.plot_prec_recall_by_event_type(
|
| 440 |
+
preds_events_df_sig, model[1], 'sig')
|
| 441 |
+
model_h.plot_prec_recall_by_event_type(
|
| 442 |
+
preds_events_df_iso, model[1], 'iso')
|
| 443 |
+
model_h.plot_prec_recall_by_event_type(
|
| 444 |
+
preds_events_df_spline, model[1], 'spline')
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
#### SHAP ####
|
| 448 |
+
if model[1] not in ['dummy_classifier']:
|
| 449 |
+
### Uncalibrated model ###
|
| 450 |
+
# Get the average SHAP values from CV folds for uncalibrated model
|
| 451 |
+
shap_values_v_uncal, shap_values_t_uncal = model_h.get_uncalibrated_shap(
|
| 452 |
+
crossval['estimator'], test_features, train_features,
|
| 453 |
+
train_data[features_list].columns,
|
| 454 |
+
model[1], model_type)
|
| 455 |
+
|
| 456 |
+
## Plot SHAP summary plots ##
|
| 457 |
+
model_h.plot_averaged_summary_plot(
|
| 458 |
+
shap_values_t_uncal,
|
| 459 |
+
train_data[features_list],
|
| 460 |
+
model[1], 'uncalib', model_type)
|
| 461 |
+
|
| 462 |
+
## Plot SHAP interaction heatmap ##
|
| 463 |
+
model_h.plot_shap_interaction_value_heatmap(
|
| 464 |
+
crossval['estimator'], train_features,
|
| 465 |
+
train_data[features_list].columns,
|
| 466 |
+
model[1], model_type)
|
| 467 |
+
|
| 468 |
+
### Calibrated models ###
|
| 469 |
+
calib_models = {'sig':model_sig, 'iso':model_iso}
|
| 470 |
+
for calib_model_name in calib_models:
|
| 471 |
+
# Get the average SHAP values from CV folds for calibrated model
|
| 472 |
+
shap_values_v, shap_values_t = model_h.get_calibrated_shap_by_classifier(
|
| 473 |
+
calib_models[calib_model_name], test_features, train_features,
|
| 474 |
+
train_data.drop(
|
| 475 |
+
columns=['StudyId', 'ExacWithin3Months', 'IndexDate',
|
| 476 |
+
'HospExacWithin3Months',
|
| 477 |
+
'CommExacWithin3Months']).columns,
|
| 478 |
+
calib_model_name, model[1], model_type)
|
| 479 |
+
|
| 480 |
+
## Plot SHAP summary plots ##
|
| 481 |
+
model_h.plot_averaged_summary_plot(
|
| 482 |
+
shap_values_t,
|
| 483 |
+
train_data.drop(
|
| 484 |
+
columns=['StudyId', 'ExacWithin3Months', 'IndexDate',
|
| 485 |
+
'HospExacWithin3Months','CommExacWithin3Months']),
|
| 486 |
+
model[1], calib_model_name, model_type)
|
| 487 |
+
|
| 488 |
+
## Get feature importance for local SHAP values ##
|
| 489 |
+
feature_imp_df = model_h.get_local_shap_values(
|
| 490 |
+
model[1], model_type, shap_values_v, test_features,
|
| 491 |
+
calib_model_name,shap_ids_dir='./data/prediction_and_events/')
|
| 492 |
+
feature_imp_df.to_csv(
|
| 493 |
+
'./data/prediction_and_events/local_feature_imp_df' + model[1] +
|
| 494 |
+
'_' + calib_model_name + '.csv')
|
| 495 |
+
|
| 496 |
+
## Plot local SHAP plots ##
|
| 497 |
+
test_feat_enc_conv = model_h.plot_local_shap(
|
| 498 |
+
model[1], model_type, shap_values_v, test_features, train_features,
|
| 499 |
+
calib_model_name,
|
| 500 |
+
row_ids_to_plot=['missed', 'incorrect', 'correct'],
|
| 501 |
+
artifact_dir=config['outputs']['artifact_dir'],
|
| 502 |
+
shap_ids_dir='./data/prediction_and_events/',
|
| 503 |
+
reverse_scaling_flag=False,
|
| 504 |
+
convert_target_encodings=True, imputation=model[2],
|
| 505 |
+
target_enc_path="./data/artifacts/target_encodings_" + model_type + ".json",
|
| 506 |
+
return_enc_converted_df=False)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
"""
|
| 510 |
+
### Plot SHAP dependency plots ###
|
| 511 |
+
os.makedirs( "./tmp/dependence_plots", exist_ok=True)
|
| 512 |
+
categorical_cols = [
|
| 513 |
+
"DaysSinceLastExac_te", "FEV1PercentPredicted_te"]
|
| 514 |
+
for categorical_col in categorical_cols:
|
| 515 |
+
shap.dependence_plot(
|
| 516 |
+
categorical_col, shap_values_v, test_feat_enc_conv,
|
| 517 |
+
interaction_index=None, show=False)
|
| 518 |
+
plt.tight_layout()
|
| 519 |
+
plt.savefig(
|
| 520 |
+
"./tmp/dependence_plots/dependence_plot_" + categorical_col
|
| 521 |
+
+ "_" + model[1] + "_" + calib_model_name + file_suffix + ".png")
|
| 522 |
+
plt.close()
|
| 523 |
+
"""
|
| 524 |
+
### Plot distribution of model scores for uncalibrated model ###
|
| 525 |
+
model_h.plot_score_distribution(
|
| 526 |
+
train_target, probabilities_cv, config['outputs']['artifact_dir'],
|
| 527 |
+
model[1], model_type)
|
| 528 |
+
|
| 529 |
+
"""
|
| 530 |
+
### Permutation feature importance ###
|
| 531 |
+
def calc_permutation_importance(model, features, target, scoring, n_repeats):
|
| 532 |
+
permutation_imp = permutation_importance(model, features, target, random_state=0, scoring=scoring, n_repeats=n_repeats)
|
| 533 |
+
for n, score in enumerate(permutation_imp):
|
| 534 |
+
if n == 0:
|
| 535 |
+
df = pd.DataFrame(data=permutation_imp[score]['importances_mean'], index=features.columns)
|
| 536 |
+
df = df.rename(columns={0:score})
|
| 537 |
+
else:
|
| 538 |
+
df[score] = permutation_imp[score]['importances_mean']
|
| 539 |
+
return df, permutation_imp
|
| 540 |
+
def plot_permutation_feature_importance(permutation_imp_full, metric, col_names, n_repeats, train_or_test):
|
| 541 |
+
os.makedirs("./tmp/permutation_feat_imp", exist_ok=True)
|
| 542 |
+
sorted_importances_idx = permutation_imp_full[metric].importances_mean.argsort()
|
| 543 |
+
importances = pd.DataFrame(
|
| 544 |
+
permutation_imp_full[metric].importances[sorted_importances_idx].T,
|
| 545 |
+
columns=col_names[sorted_importances_idx],
|
| 546 |
+
)
|
| 547 |
+
ax = importances.plot.box(vert=False, whis=10)
|
| 548 |
+
ax.set_title("Permutation Importances(" + train_or_test + ")")
|
| 549 |
+
ax.axvline(x=0, color="k", linestyle="--")
|
| 550 |
+
ax.set_xlabel("Decrease in accuracy score")
|
| 551 |
+
ax.figure.tight_layout()
|
| 552 |
+
plt.savefig('./tmp/permutation_feat_imp/' + train_or_test + '_' + metric + '_repeats' + str(n_repeats) +'.png')
|
| 553 |
+
|
| 554 |
+
from scipy.cluster import hierarchy
|
| 555 |
+
from scipy.spatial.distance import squareform
|
| 556 |
+
from scipy.stats import spearmanr
|
| 557 |
+
full_dataset_feat = pd.concat([train_features, test_features], axis=0)
|
| 558 |
+
print(train_features)
|
| 559 |
+
print(full_dataset_feat)
|
| 560 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
|
| 561 |
+
corr = spearmanr(full_dataset_feat).correlation
|
| 562 |
+
|
| 563 |
+
# Ensure the correlation matrix is symmetric
|
| 564 |
+
corr = (corr + corr.T) / 2
|
| 565 |
+
np.fill_diagonal(corr, 1)
|
| 566 |
+
|
| 567 |
+
# We convert the correlation matrix to a distance matrix before performing
|
| 568 |
+
# hierarchical clustering using Ward's linkage.
|
| 569 |
+
distance_matrix = 1 - np.abs(corr)
|
| 570 |
+
dist_linkage = hierarchy.ward(squareform(distance_matrix))
|
| 571 |
+
dendro = hierarchy.dendrogram(
|
| 572 |
+
dist_linkage, labels=full_dataset_feat.columns.to_list(), ax=ax1, leaf_rotation=90
|
| 573 |
+
)
|
| 574 |
+
dendro_idx = np.arange(0, len(dendro["ivl"]))
|
| 575 |
+
|
| 576 |
+
ax2.imshow(corr[dendro["leaves"], :][:, dendro["leaves"]])
|
| 577 |
+
ax2.set_xticks(dendro_idx)
|
| 578 |
+
ax2.set_yticks(dendro_idx)
|
| 579 |
+
ax2.set_xticklabels(dendro["ivl"], rotation="vertical")
|
| 580 |
+
ax2.set_yticklabels(dendro["ivl"])
|
| 581 |
+
_ = fig.tight_layout()
|
| 582 |
+
plt.show()
|
| 583 |
+
plt.close()
|
| 584 |
+
|
| 585 |
+
#features_to_drop = ["TotalEngagementMRC", "NumCommExacPrior6mo", "WeekAvgCATQ2", "WeekAvgCATQ4"]
|
| 586 |
+
|
| 587 |
+
#X_train_sel = train_features.drop(columns=features_to_drop)
|
| 588 |
+
#X_test_sel = test_features.drop(columns=features_to_drop)
|
| 589 |
+
|
| 590 |
+
from collections import defaultdict
|
| 591 |
+
|
| 592 |
+
cluster_ids = hierarchy.fcluster(dist_linkage, 0.5, criterion="distance")
|
| 593 |
+
cluster_id_to_feature_ids = defaultdict(list)
|
| 594 |
+
for idx, cluster_id in enumerate(cluster_ids):
|
| 595 |
+
cluster_id_to_feature_ids[cluster_id].append(idx)
|
| 596 |
+
selected_features = [v[0] for v in cluster_id_to_feature_ids.values()]
|
| 597 |
+
selected_features_names = full_dataset_feat.columns[selected_features]
|
| 598 |
+
|
| 599 |
+
X_train_sel = train_features[selected_features_names]
|
| 600 |
+
X_test_sel = test_features[selected_features_names]
|
| 601 |
+
print(selected_features_names)
|
| 602 |
+
# retrain
|
| 603 |
+
# Perform calibration
|
| 604 |
+
model_sig_perm = CalibratedClassifierCV(
|
| 605 |
+
model[0], method='sigmoid',cv=cross_val_fold_indices)
|
| 606 |
+
model_sig_perm.fit(X_train_sel, train_target)
|
| 607 |
+
probs_sig = model_sig_perm.predict_proba(X_test_sel)[:, 1]
|
| 608 |
+
probs_sig_2 = model_sig_perm.predict_proba(X_test_sel)
|
| 609 |
+
preds_sig = model_sig_perm.predict(X_test_sel)
|
| 610 |
+
print('before')
|
| 611 |
+
print(calib_metrics_sig)
|
| 612 |
+
# Generate metrics for calibrated model
|
| 613 |
+
calib_metrics_sig = copd.calc_eval_metrics_for_model(
|
| 614 |
+
test_target, preds_sig, probs_sig, 'sig')
|
| 615 |
+
print(calib_metrics_sig)
|
| 616 |
+
|
| 617 |
+
def plot_permutation_importance(clf, X, y, ax):
|
| 618 |
+
result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2,scoring='average_precision')
|
| 619 |
+
perm_sorted_idx = result.importances_mean.argsort()
|
| 620 |
+
|
| 621 |
+
ax.boxplot(
|
| 622 |
+
result.importances[perm_sorted_idx].T,
|
| 623 |
+
vert=False,
|
| 624 |
+
labels=X.columns[perm_sorted_idx],
|
| 625 |
+
)
|
| 626 |
+
ax.axvline(x=0, color="k", linestyle="--")
|
| 627 |
+
return ax
|
| 628 |
+
fig, ax = plt.subplots(figsize=(7, 6))
|
| 629 |
+
plot_permutation_importance(model_sig_perm, X_test_sel, test_target, ax)
|
| 630 |
+
ax.set_title("Permutation Importances on selected subset of features\n(test set)")
|
| 631 |
+
ax.set_xlabel("Decrease in accuracy score")
|
| 632 |
+
ax.figure.tight_layout()
|
| 633 |
+
plt.savefig('./tmp/permutation_feat_imp.png')
|
| 634 |
+
|
| 635 |
+
#for metric in ['f1', 'average_precision', 'roc_auc']:
|
| 636 |
+
# for n_repeats in [5,10, 50]:
|
| 637 |
+
# permutation_imp_train_df, permutation_imp_train_dict = calc_permutation_importance(model_sig, train_features, train_target, scoring=scoring, n_repeats=n_repeats)
|
| 638 |
+
# plot_permutation_feature_importance(permutation_imp_train_dict, metric, train_features.columns, n_repeats, 'train')
|
| 639 |
+
# for n_repeats in [5,10, 50]:
|
| 640 |
+
# permutation_imp_test_df, permutation_imp_test_dict = calc_permutation_importance(model_sig, test_features, test_target, scoring=scoring, n_repeats=n_repeats)
|
| 641 |
+
# plot_permutation_feature_importance(permutation_imp_test_dict, metric, test_features.columns, n_repeats, 'test')
|
| 642 |
+
"""
|
| 643 |
+
### Log metrics, parameters, and artifacts ###
|
| 644 |
+
# Log metrics averaged across folds
|
| 645 |
+
for score in scoring:
|
| 646 |
+
mlflow.log_metric(score, crossval['test_' + score].mean())
|
| 647 |
+
mlflow.log_metric(score + '_std', crossval['test_' + score].std())
|
| 648 |
+
# Log metrics for calibrated models
|
| 649 |
+
if model[1] != 'dummy_classifier':
|
| 650 |
+
mlflow.log_metrics(uncalib_metrics_test_mean)
|
| 651 |
+
mlflow.log_metrics(calib_metrics_sig)
|
| 652 |
+
mlflow.log_metrics(calib_metrics_iso)
|
| 653 |
+
mlflow.log_metrics(calib_metrics_spline)
|
| 654 |
+
mlflow.log_metrics(metrics_by_event_type_uncalib)
|
| 655 |
+
mlflow.log_metrics(metrics_by_event_type_sig)
|
| 656 |
+
mlflow.log_metrics(metrics_by_event_type_iso)
|
| 657 |
+
mlflow.log_metrics(metrics_by_event_type_spline)
|
| 658 |
+
# Log model parameters
|
| 659 |
+
params = model[0].get_params()
|
| 660 |
+
for param in params:
|
| 661 |
+
mlflow.log_param(param, params[param])
|
| 662 |
+
# Log artifacts
|
| 663 |
+
mlflow.log_artifacts(config['outputs']['artifact_dir'])
|
| 664 |
+
|
| 665 |
+
# Save sig model
|
| 666 |
+
with open('./data/model/trained_sig_' + model[1] + '_pkl', 'wb') as files:
|
| 667 |
+
pickle.dump(model_sig, files)
|
| 668 |
+
with open('./data/model/trained_iso_' + model[1] + '_pkl', 'wb') as files:
|
| 669 |
+
pickle.dump(model_iso, files)
|
| 670 |
+
with open('./data/model/trained_spline_' + model[1] + '_pkl', 'wb') as files:
|
| 671 |
+
pickle.dump(spline_calib, files)
|
| 672 |
+
|
| 673 |
+
mlflow.end_run()
|
training/cross_val_first_models.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import mlflow
|
| 6 |
+
import model_h
|
| 7 |
+
|
| 8 |
+
# Plotting
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
|
| 12 |
+
# Model training and evaluation
|
| 13 |
+
from sklearn.dummy import DummyClassifier
|
| 14 |
+
from sklearn.linear_model import LogisticRegression
|
| 15 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 16 |
+
from sklearn.model_selection import cross_validate, cross_val_predict
|
| 17 |
+
from sklearn.metrics import confusion_matrix, precision_recall_curve
|
| 18 |
+
from sklearn.calibration import calibration_curve, CalibratedClassifierCV
|
| 19 |
+
from imblearn.ensemble import BalancedRandomForestClassifier, BalancedBaggingClassifier
|
| 20 |
+
import lightgbm as lgb
|
| 21 |
+
import xgboost as xgb
|
| 22 |
+
from catboost import CatBoostClassifier
|
| 23 |
+
import ml_insights as mli
|
| 24 |
+
|
| 25 |
+
# Explainability
|
| 26 |
+
import shap
|
| 27 |
+
|
| 28 |
+
##############################################################
|
| 29 |
+
# Specify which model to perform cross validation on
|
| 30 |
+
##############################################################
|
| 31 |
+
model_only_hosp = True
|
| 32 |
+
if model_only_hosp is True:
|
| 33 |
+
file_suffix = "_only_hosp"
|
| 34 |
+
else:
|
| 35 |
+
file_suffix = "_hosp_comm"
|
| 36 |
+
|
| 37 |
+
##############################################################
|
| 38 |
+
# Load data
|
| 39 |
+
##############################################################
|
| 40 |
+
# Setup log file
|
| 41 |
+
log = open("./training/logging/modelling" + file_suffix + ".log", "w")
|
| 42 |
+
sys.stdout = log
|
| 43 |
+
|
| 44 |
+
# Load CV folds
|
| 45 |
+
fold_patients = np.load(
|
| 46 |
+
'./data/cohort_info/fold_patients' + file_suffix + '.npy', allow_pickle=True)
|
| 47 |
+
|
| 48 |
+
# Load imputed train data
|
| 49 |
+
train_data_imp = model_h.load_data_for_modelling(
|
| 50 |
+
'./data/model_data/train_data_cv_imp' + file_suffix + '.pkl')
|
| 51 |
+
train_data_imp = train_data_imp.drop(columns=['Sex_F', 'Age_TEnc'])
|
| 52 |
+
|
| 53 |
+
# Load not imputed train data
|
| 54 |
+
train_data_no_imp = model_h.load_data_for_modelling(
|
| 55 |
+
'./data/model_data/train_data_cv_no_imp' + file_suffix + '.pkl')
|
| 56 |
+
train_data_no_imp = train_data_no_imp.drop(columns=['Sex_F', 'Age_TEnc'])
|
| 57 |
+
|
| 58 |
+
# Load imputed test data
|
| 59 |
+
test_data_imp = model_h.load_data_for_modelling(
|
| 60 |
+
'./data/model_data/test_data_imp' + file_suffix + '.pkl')
|
| 61 |
+
test_data_imp = test_data_imp.drop(columns=['Sex_F', 'Age_TEnc'])
|
| 62 |
+
|
| 63 |
+
# Load not imputed test data
|
| 64 |
+
test_data_no_imp = model_h.load_data_for_modelling(
|
| 65 |
+
'./data/model_data/test_data_no_imp' + file_suffix + '.pkl')
|
| 66 |
+
test_data_no_imp = test_data_no_imp.drop(columns=['Sex_F', 'Age_TEnc'])
|
| 67 |
+
|
| 68 |
+
# Create a tuple with training and validation indicies for each fold. Can be done with
|
| 69 |
+
# either imputed or not imputed data as both have same patients
|
| 70 |
+
cross_val_fold_indices = []
|
| 71 |
+
for fold in fold_patients:
|
| 72 |
+
fold_val_ids = train_data_no_imp[train_data_no_imp.StudyId.isin(fold)]
|
| 73 |
+
fold_train_ids = train_data_no_imp[~(
|
| 74 |
+
train_data_no_imp.StudyId.isin(fold_val_ids.StudyId))]
|
| 75 |
+
|
| 76 |
+
# Get index of rows in val and train
|
| 77 |
+
fold_val_index = fold_val_ids.index
|
| 78 |
+
fold_train_index = fold_train_ids.index
|
| 79 |
+
|
| 80 |
+
# Append tuple of training and val indices
|
| 81 |
+
cross_val_fold_indices.append((fold_train_index, fold_val_index))
|
| 82 |
+
|
| 83 |
+
# Create list of model features
|
| 84 |
+
cols_to_drop = ['StudyId', 'ExacWithin3Months']
|
| 85 |
+
features_list = [col for col in train_data_no_imp.columns if col not in cols_to_drop]
|
| 86 |
+
|
| 87 |
+
# Train data
|
| 88 |
+
# Separate features from target for data with no imputation performed
|
| 89 |
+
train_features_no_imp = train_data_no_imp[features_list].astype('float')
|
| 90 |
+
train_target_no_imp = train_data_no_imp.ExacWithin3Months.astype('float')
|
| 91 |
+
# Separate features from target for data with no imputation performed
|
| 92 |
+
train_features_imp = train_data_imp[features_list].astype('float')
|
| 93 |
+
train_target_imp = train_data_imp.ExacWithin3Months.astype('float')
|
| 94 |
+
|
| 95 |
+
# Test data
|
| 96 |
+
# Separate features from target for data with no imputation performed
|
| 97 |
+
test_features_no_imp = test_data_no_imp[features_list].astype('float')
|
| 98 |
+
test_target_no_imp = test_data_no_imp.ExacWithin3Months.astype('float')
|
| 99 |
+
# Separate features from target for data with no imputation performed
|
| 100 |
+
test_features_imp = test_data_imp[features_list].astype('float')
|
| 101 |
+
test_target_imp = test_data_imp.ExacWithin3Months.astype('float')
|
| 102 |
+
|
| 103 |
+
# Check that the target in imputed and not imputed datasets are the same. If not,
|
| 104 |
+
# raise an error
|
| 105 |
+
if not train_target_no_imp.equals(train_target_imp):
|
| 106 |
+
raise ValueError(
|
| 107 |
+
'Target variable is not the same in imputed and non imputed datasets in the train set.')
|
| 108 |
+
if not test_target_no_imp.equals(test_target_imp):
|
| 109 |
+
raise ValueError(
|
| 110 |
+
'Target variable is not the same in imputed and non imputed datasets in the test set.')
|
| 111 |
+
train_target = train_target_no_imp
|
| 112 |
+
test_target = test_target_no_imp
|
| 113 |
+
|
| 114 |
+
# Make sure all features are numeric
|
| 115 |
+
for features in [train_features_no_imp, train_features_imp,
|
| 116 |
+
test_features_no_imp, test_features_imp]:
|
| 117 |
+
for col in features:
|
| 118 |
+
features[col] = pd.to_numeric(features[col], errors='coerce')
|
| 119 |
+
|
| 120 |
+
##############################################################
|
| 121 |
+
# Specify which models to evaluate
|
| 122 |
+
##############################################################
|
| 123 |
+
# Set up MLflow
|
| 124 |
+
mlflow.set_tracking_uri("sqlite:///mlruns.db")
|
| 125 |
+
mlflow.set_experiment('model_h_drop_1' + file_suffix)
|
| 126 |
+
|
| 127 |
+
# Set CV scoring strategies and any model parameters
|
| 128 |
+
scoring = ['f1', 'balanced_accuracy', 'accuracy', 'precision', 'recall', 'roc_auc',
|
| 129 |
+
'average_precision', 'neg_brier_score']
|
| 130 |
+
scale_pos_weight = train_target.value_counts()[0] / train_target.value_counts()[1]
|
| 131 |
+
|
| 132 |
+
# Set up models, each tuple contains 4 elements: model, model name, imputation status,
|
| 133 |
+
# type of model
|
| 134 |
+
models = []
|
| 135 |
+
# Dummy classifier
|
| 136 |
+
models.append((DummyClassifier(strategy='stratified'),
|
| 137 |
+
'dummy_classifier', 'imputed'))
|
| 138 |
+
# Logistic regression
|
| 139 |
+
models.append((LogisticRegression(random_state=0, max_iter=200),
|
| 140 |
+
'logistic_regression', 'imputed', 'linear'))
|
| 141 |
+
models.append((LogisticRegression(random_state=0, class_weight='balanced', max_iter=200),
|
| 142 |
+
'logistic_regression_CW_balanced', 'imputed', 'linear'))
|
| 143 |
+
# Random forest
|
| 144 |
+
models.append((RandomForestClassifier(random_state=0),
|
| 145 |
+
'random_forest', 'imputed', 'tree'))
|
| 146 |
+
models.append((RandomForestClassifier(random_state=0, class_weight='balanced'),
|
| 147 |
+
'random_forest_CW_balanced', 'imputed', 'tree'))
|
| 148 |
+
models.append((BalancedRandomForestClassifier(random_state=0),
|
| 149 |
+
'balanced_random_forest', 'imputed', 'tree'))
|
| 150 |
+
# Bagging
|
| 151 |
+
models.append((BalancedBaggingClassifier(random_state=0),
|
| 152 |
+
'balanced_bagging', 'imputed', 'tree'))
|
| 153 |
+
# XGBoost
|
| 154 |
+
models.append((xgb.XGBClassifier(random_state=0, use_label_encoder=False,
|
| 155 |
+
eval_metric='logloss', learning_rate=0.1),
|
| 156 |
+
'xgb', 'not_imputed', 'tree'))
|
| 157 |
+
models.append((xgb.XGBClassifier(random_state=0, use_label_encoder=False,
|
| 158 |
+
eval_metric='logloss', learning_rate=0.1, max_depth=4),
|
| 159 |
+
'xgb_mdepth_4', 'not_imputed', 'tree'))
|
| 160 |
+
models.append((xgb.XGBClassifier(random_state=0, use_label_encoder=False,
|
| 161 |
+
eval_metric='logloss', scale_pos_weight=scale_pos_weight, learning_rate=0.1),
|
| 162 |
+
'xgb_spw', 'not_imputed', 'tree'))
|
| 163 |
+
models.append((xgb.XGBClassifier(random_state=0, use_label_encoder=False,
|
| 164 |
+
eval_metric='logloss', scale_pos_weight=scale_pos_weight, learning_rate=0.1,
|
| 165 |
+
max_depth=4),
|
| 166 |
+
'xgb_spw_mdepth_4', 'not_imputed', 'tree'))
|
| 167 |
+
# Light GBM
|
| 168 |
+
models.append((lgb.LGBMClassifier(random_state=0, learning_rate=0.1, verbose_eval=-1),
|
| 169 |
+
'lgbm', 'not_imputed', 'tree'))
|
| 170 |
+
models.append((lgb.LGBMClassifier(random_state=0, learning_rate=0.1,
|
| 171 |
+
scale_pos_weight=scale_pos_weight, verbose_eval=-1),
|
| 172 |
+
'lgbm_spw', 'not_imputed', 'tree'))
|
| 173 |
+
# CatBoost
|
| 174 |
+
models.append((CatBoostClassifier(random_state=0, learning_rate=0.1),
|
| 175 |
+
'catboost', 'not_imputed', 'tree'))
|
| 176 |
+
|
| 177 |
+
# Convert features and target to a numpy array
|
| 178 |
+
# Train data
|
| 179 |
+
#train_features_no_imp = train_features_no_imp.to_numpy()
|
| 180 |
+
#train_features_imp = train_features_imp.to_numpy()
|
| 181 |
+
#train_target = train_target.to_numpy()
|
| 182 |
+
# Test data
|
| 183 |
+
#test_features_no_imp = test_features_no_imp.to_numpy()
|
| 184 |
+
#test_features_imp = test_features_imp.to_numpy()
|
| 185 |
+
#test_target = test_target.to_numpy()
|
| 186 |
+
|
| 187 |
+
##############################################################
|
| 188 |
+
# Run models
|
| 189 |
+
##############################################################
|
| 190 |
+
#In MLflow run, perform K-fold cross validation and capture mean score across folds.
|
| 191 |
+
with mlflow.start_run(run_name='model_selection_less_features_3rd_iter_minus_sex'):
|
| 192 |
+
for model in models:
|
| 193 |
+
with mlflow.start_run(run_name=model[1], nested=True):
|
| 194 |
+
print(model[1])
|
| 195 |
+
# Create the artifacts directory if it doesn't exist
|
| 196 |
+
artifact_dir = './tmp'
|
| 197 |
+
os.makedirs(artifact_dir, exist_ok=True)
|
| 198 |
+
# Remove existing directory contents to not mix files between different runs
|
| 199 |
+
for f in os.listdir(artifact_dir):
|
| 200 |
+
os.remove(os.path.join(artifact_dir, f))
|
| 201 |
+
|
| 202 |
+
# Perform K-fold cross validation with custom folds using imputed dataset for
|
| 203 |
+
# non-sparsity aware models
|
| 204 |
+
if model[2] == 'imputed':
|
| 205 |
+
crossval = cross_validate(model[0], train_features_imp, train_target,
|
| 206 |
+
cv=cross_val_fold_indices,
|
| 207 |
+
return_estimator=True, scoring=scoring,
|
| 208 |
+
return_indices=True)
|
| 209 |
+
|
| 210 |
+
# Get the predicted probabilities from each models
|
| 211 |
+
probabilities_cv = cross_val_predict(model[0], train_features_imp,
|
| 212 |
+
train_target, cv=cross_val_fold_indices,
|
| 213 |
+
method='predict_proba')[:, 1]
|
| 214 |
+
else:
|
| 215 |
+
crossval = cross_validate(model[0], train_features_no_imp, train_target,
|
| 216 |
+
cv=cross_val_fold_indices, return_estimator=True,
|
| 217 |
+
scoring=scoring, return_indices=True)
|
| 218 |
+
|
| 219 |
+
# Get the predicted probabilities from each models
|
| 220 |
+
probabilities_cv = cross_val_predict(model[0], train_features_no_imp,
|
| 221 |
+
train_target, cv=cross_val_fold_indices,
|
| 222 |
+
method='predict_proba')[:, 1]
|
| 223 |
+
|
| 224 |
+
# Get threshold that gives best F1 score
|
| 225 |
+
precision, recall, thresholds = precision_recall_curve(
|
| 226 |
+
train_target, probabilities_cv)
|
| 227 |
+
fscore = (2 * precision * recall) / (precision + recall)
|
| 228 |
+
# When getting the max fscore, if fscore is nan, nan will be returned as the
|
| 229 |
+
# max. Iterate until nan not returned.
|
| 230 |
+
fscore_zero = True
|
| 231 |
+
position = -1
|
| 232 |
+
while fscore_zero is True:
|
| 233 |
+
best_thres_idx = np.argsort(fscore, axis=0)[position]
|
| 234 |
+
if np.isnan(fscore[best_thres_idx]) == True:
|
| 235 |
+
position = position - 1
|
| 236 |
+
else:
|
| 237 |
+
fscore_zero = False
|
| 238 |
+
best_threshold = thresholds[best_thres_idx]
|
| 239 |
+
print('Best Threshold=%f, F-Score=%.3f, Precision=%.3f, Recall=%.3f' % (
|
| 240 |
+
best_threshold, fscore[best_thres_idx], precision[best_thres_idx],
|
| 241 |
+
recall[best_thres_idx]))
|
| 242 |
+
# Save f1 score, precision and recall for the best threshold
|
| 243 |
+
mlflow.log_metric('best_threshold', best_threshold)
|
| 244 |
+
mlflow.log_metric('f1_best_thres', fscore[best_thres_idx])
|
| 245 |
+
mlflow.log_metric('precision_best_thres', precision[best_thres_idx])
|
| 246 |
+
mlflow.log_metric('recall_best_thres', recall[best_thres_idx])
|
| 247 |
+
|
| 248 |
+
# Plot confusion matrix at different thresholds
|
| 249 |
+
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, best_threshold]
|
| 250 |
+
for threshold in thresholds:
|
| 251 |
+
y_predicted = probabilities_cv > threshold
|
| 252 |
+
model_h.plot_confusion_matrix(
|
| 253 |
+
train_target, y_predicted, model[1], threshold, file_suffix)
|
| 254 |
+
|
| 255 |
+
# Generate calibration curves
|
| 256 |
+
if model[1] != 'dummy_classifier':
|
| 257 |
+
# Calibrated model (Sigmoid)
|
| 258 |
+
model_sig = CalibratedClassifierCV(
|
| 259 |
+
model[0], method='sigmoid',cv=cross_val_fold_indices)
|
| 260 |
+
if model[2] == 'imputed':
|
| 261 |
+
model_sig.fit(train_features_imp, train_target)
|
| 262 |
+
probs_sig = model_sig.predict_proba(test_features_imp)[:, 1]
|
| 263 |
+
else:
|
| 264 |
+
model_sig.fit(train_features_no_imp, train_target)
|
| 265 |
+
probs_sig = model_sig.predict_proba(test_features_no_imp)[:, 1]
|
| 266 |
+
|
| 267 |
+
# Calibrated model (Isotonic)
|
| 268 |
+
model_iso = CalibratedClassifierCV(
|
| 269 |
+
model[0], method='isotonic', cv=cross_val_fold_indices)
|
| 270 |
+
if model[2] == 'imputed':
|
| 271 |
+
model_iso.fit(train_features_imp, train_target)
|
| 272 |
+
probs_iso = model_iso.predict_proba(test_features_imp)[:, 1]
|
| 273 |
+
else:
|
| 274 |
+
model_iso.fit(train_features_no_imp, train_target)
|
| 275 |
+
probs_iso = model_iso.predict_proba(test_features_no_imp)[:, 1]
|
| 276 |
+
|
| 277 |
+
# Spline calibration
|
| 278 |
+
spline_calib = mli.SplineCalib()
|
| 279 |
+
spline_calib.fit(probabilities_cv, train_target)
|
| 280 |
+
|
| 281 |
+
if model[2] == 'imputed':
|
| 282 |
+
model[0].fit(train_features_imp,train_target)
|
| 283 |
+
preds_test_uncalib = model[0].predict_proba(test_features_imp)[:,1]
|
| 284 |
+
else:
|
| 285 |
+
model[0].fit(train_features_no_imp,train_target)
|
| 286 |
+
preds_test_uncalib = model[0].predict_proba(test_features_no_imp)[:,1]
|
| 287 |
+
probs_spline = spline_calib.calibrate(preds_test_uncalib)
|
| 288 |
+
|
| 289 |
+
# Plot calibration curves for equal width bins (each bin has same width) and
|
| 290 |
+
# equal frequency bins (each bin has same number of observations)
|
| 291 |
+
for strategy in ['uniform', 'quantile']:
|
| 292 |
+
for bin_num in [5, 10]:
|
| 293 |
+
if strategy == 'uniform':
|
| 294 |
+
print('--- Creating calibration curve with equal width bins ---')
|
| 295 |
+
print('-- Num bins:', bin_num, ' --')
|
| 296 |
+
else:
|
| 297 |
+
print('--- Creating calibration curve with equal frequency bins ---')
|
| 298 |
+
print('-- Num bins:', bin_num, ' --')
|
| 299 |
+
print('Uncalibrated model:')
|
| 300 |
+
prob_true_uncal, prob_pred_uncal = calibration_curve(
|
| 301 |
+
train_target, probabilities_cv,n_bins=bin_num, strategy=strategy)
|
| 302 |
+
print('Calibrated model (sigmoid):')
|
| 303 |
+
prob_true_sig, prob_pred_sig = calibration_curve(
|
| 304 |
+
test_target, probs_sig, n_bins=bin_num, strategy=strategy)
|
| 305 |
+
print('Calibrated model (isotonic):')
|
| 306 |
+
prob_true_iso, prob_pred_iso = calibration_curve(
|
| 307 |
+
test_target, probs_iso, n_bins=bin_num, strategy=strategy)
|
| 308 |
+
print('Calibrated model (spline):')
|
| 309 |
+
prob_true_spline, prob_pred_spline = calibration_curve(
|
| 310 |
+
test_target, probs_spline, n_bins=bin_num, strategy=strategy)
|
| 311 |
+
|
| 312 |
+
plt.figure(figsize=(8,8))
|
| 313 |
+
plt.plot([0, 1], [0, 1], linestyle='--')
|
| 314 |
+
plt.plot(prob_pred_uncal, prob_true_uncal, marker='.',
|
| 315 |
+
label='Uncalibrated\n' + model[1])
|
| 316 |
+
plt.plot(prob_pred_sig, prob_true_sig, marker='.',
|
| 317 |
+
label='Calibrated (Sigmoid)\n' + model[1])
|
| 318 |
+
plt.plot(prob_pred_iso, prob_true_iso, marker='.',
|
| 319 |
+
label='Calibrated (Isotonic)\n' + model[1])
|
| 320 |
+
plt.plot(prob_pred_spline, prob_true_spline, marker='.',
|
| 321 |
+
label='Calibrated (Spline)\n' + model[1])
|
| 322 |
+
plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
|
| 323 |
+
plt.tight_layout()
|
| 324 |
+
plt.savefig(os.path.join(artifact_dir, model[1] + '_uncal_' +
|
| 325 |
+
strategy + '_bins' + str(bin_num) +
|
| 326 |
+
file_suffix + '.png'))
|
| 327 |
+
plt.close()
|
| 328 |
+
|
| 329 |
+
# Get total gain and total cover for boosting machine models
|
| 330 |
+
if model[1].startswith("xgb"):
|
| 331 |
+
feat_importance_tot_gain_df = model_h.plot_feat_importance_model(
|
| 332 |
+
model[0], model[1], file_suffix=file_suffix)
|
| 333 |
+
if (model[1].startswith("lgbm")):
|
| 334 |
+
feature_names = train_features_no_imp.columns.tolist()
|
| 335 |
+
feat_importance_tot_gain_df = model_h.plot_feat_importance_model(
|
| 336 |
+
model[0], model[1], file_suffix=file_suffix, feature_names=feature_names)
|
| 337 |
+
# Save feature importance by total gain
|
| 338 |
+
if (model[1].startswith("xgb")) | (model[1].startswith("lgbm")):
|
| 339 |
+
feat_importance_tot_gain_df.to_csv(
|
| 340 |
+
'./data/feature_importance_tot_gain' + file_suffix + '.csv', index=False)
|
| 341 |
+
|
| 342 |
+
# SHAP
|
| 343 |
+
if model[1] not in ['dummy_classifier', 'balanced_bagging']:
|
| 344 |
+
shap_values_list_train = []
|
| 345 |
+
shap_vals_per_cv = {}
|
| 346 |
+
|
| 347 |
+
# Create a dictionary to contain shap values. Dictionary is structured as
|
| 348 |
+
# index : fold_num : shap_values
|
| 349 |
+
for idx in range(0, len(train_data_imp)):
|
| 350 |
+
shap_vals_per_cv[idx] = {}
|
| 351 |
+
for n_fold in range(0, 5):
|
| 352 |
+
shap_vals_per_cv[idx][n_fold] = {}
|
| 353 |
+
|
| 354 |
+
# Get SHAP values for each fold
|
| 355 |
+
fold_num = 0
|
| 356 |
+
for i, estimator in enumerate(crossval['estimator']):
|
| 357 |
+
fold_num = fold_num + 1
|
| 358 |
+
# If imputation needed for model, use imputed features
|
| 359 |
+
if model[1] in ['logistic_regression',
|
| 360 |
+
'logistic_regression_CW_balanced', 'random_forest',
|
| 361 |
+
'random_forest_CW_balanced', 'balanced_bagging',
|
| 362 |
+
'balanced_random_forest']:
|
| 363 |
+
#X_test = train_features_imp[crossval['indices']['test'][i]]
|
| 364 |
+
X_train = train_features_imp.iloc[crossval['indices']['train'][i]]
|
| 365 |
+
X_test = train_features_imp.iloc[crossval['indices']['test'][i]]
|
| 366 |
+
else:
|
| 367 |
+
X_train = train_features_no_imp.iloc[crossval['indices']['train'][i]]
|
| 368 |
+
X_test = train_features_no_imp.iloc[crossval['indices']['test'][i]]
|
| 369 |
+
|
| 370 |
+
# Apply different explainers depending on type of model
|
| 371 |
+
if model[3] == 'linear':
|
| 372 |
+
explainer = shap.LinearExplainer(estimator, X_train)
|
| 373 |
+
if model[3] == 'tree':
|
| 374 |
+
explainer = shap.TreeExplainer(estimator)
|
| 375 |
+
|
| 376 |
+
# Get shap values
|
| 377 |
+
shap_values_train = explainer.shap_values(X_train)
|
| 378 |
+
# Output of shap values for some models is (class, num samples,
|
| 379 |
+
# num features). Get these in the format of (num samples, num features)
|
| 380 |
+
if len(np.shape(shap_values_train)) == 3:
|
| 381 |
+
shap_values_train = shap_values_train[1]
|
| 382 |
+
|
| 383 |
+
# Plot SHAP plots for each cv fold
|
| 384 |
+
shap.summary_plot(np.array(shap_values_train), X_train, show=False)
|
| 385 |
+
plt.savefig(os.path.join(artifact_dir, model[1] + '_shap_cv_fold_' +
|
| 386 |
+
str(fold_num) + file_suffix + '.png'))
|
| 387 |
+
plt.close()
|
| 388 |
+
|
| 389 |
+
# Add shap values to a dictionary.
|
| 390 |
+
train_idxs = X_train.index.tolist()
|
| 391 |
+
for n, train_idx in enumerate(train_idxs):
|
| 392 |
+
shap_vals_per_cv[train_idx][i] = shap_values_train[n]
|
| 393 |
+
|
| 394 |
+
# Calculate average shap values
|
| 395 |
+
average_shap_values, stds, ranges = [],[],[]
|
| 396 |
+
for i in range(0,len(train_data_imp)):
|
| 397 |
+
for n in range(0,5):
|
| 398 |
+
# If a cv fold is empty as that set has not been used in training,
|
| 399 |
+
# replace empty fold with NaN
|
| 400 |
+
try:
|
| 401 |
+
if not shap_vals_per_cv[i][n]:
|
| 402 |
+
shap_vals_per_cv[i][n] = np.NaN
|
| 403 |
+
except:
|
| 404 |
+
pass
|
| 405 |
+
# Create a df for each index that contains all shap values for each cv
|
| 406 |
+
# fold
|
| 407 |
+
df_per_obs = pd.DataFrame.from_dict(shap_vals_per_cv[i])
|
| 408 |
+
# Get relevant statistics for every sample
|
| 409 |
+
average_shap_values.append(df_per_obs.mean(axis=1).values)
|
| 410 |
+
stds.append(df_per_obs.std(axis=1).values)
|
| 411 |
+
ranges.append(df_per_obs.max(axis=1).values-df_per_obs.min(axis=1).values)
|
| 412 |
+
|
| 413 |
+
# Plot SHAP plots
|
| 414 |
+
if model[2] == 'imputed':
|
| 415 |
+
shap.summary_plot(np.array(average_shap_values), train_data_imp.drop(
|
| 416 |
+
columns=['StudyId', 'ExacWithin3Months']), show=False)
|
| 417 |
+
if model[2] == 'not_imputed':
|
| 418 |
+
shap.summary_plot(np.array(average_shap_values), train_data_no_imp.drop(
|
| 419 |
+
columns=['StudyId', 'ExacWithin3Months']), show=False)
|
| 420 |
+
plt.savefig(
|
| 421 |
+
os.path.join(artifact_dir, model[1] + '_shap' + file_suffix + '.png'))
|
| 422 |
+
plt.close()
|
| 423 |
+
|
| 424 |
+
# Get list of most important features in order
|
| 425 |
+
feat_importance_df = model_h.get_shap_feat_importance(
|
| 426 |
+
model[1], average_shap_values, features_list, file_suffix)
|
| 427 |
+
feat_importance_df.to_csv(
|
| 428 |
+
'./data/feature_importance_shap' + file_suffix + '.csv', index=False)
|
| 429 |
+
|
| 430 |
+
# Plot distribution of model scores (histogram plus KDE)
|
| 431 |
+
model_scores = pd.DataFrame({'model_score': probabilities_cv,
|
| 432 |
+
'true_label': train_target})
|
| 433 |
+
sns.displot(model_scores, x="model_score", hue="true_label", kde=True)
|
| 434 |
+
plt.savefig(os.path.join(artifact_dir, model[1] + 'score_distribution' +
|
| 435 |
+
file_suffix + '.png'))
|
| 436 |
+
plt.close()
|
| 437 |
+
|
| 438 |
+
# Log metrics averaged across folds
|
| 439 |
+
for score in scoring:
|
| 440 |
+
mlflow.log_metric(score, crossval['test_' + score].mean())
|
| 441 |
+
mlflow.log_metric(score + '_std', crossval['test_' + score].std())
|
| 442 |
+
# Log model parameters
|
| 443 |
+
params = model[0].get_params()
|
| 444 |
+
for param in params:
|
| 445 |
+
mlflow.log_param(param, params[param])
|
| 446 |
+
# Log artifacts
|
| 447 |
+
mlflow.log_artifacts(artifact_dir)
|
| 448 |
+
|
| 449 |
+
mlflow.end_run()
|
| 450 |
+
|
| 451 |
+
# Join shap feature importance and total gain
|
| 452 |
+
shap_feat_importance = pd.read_csv(
|
| 453 |
+
'./data/feature_importance_shap' + file_suffix + '.csv')
|
| 454 |
+
tot_gain_feat_importance = pd.read_csv(
|
| 455 |
+
'./data/feature_importance_tot_gain' + file_suffix + '.csv')
|
| 456 |
+
tot_gain_feat_importance = tot_gain_feat_importance.rename(columns={'index':'col_name'})
|
| 457 |
+
feat_importance_hierarchy = shap_feat_importance.merge(
|
| 458 |
+
tot_gain_feat_importance, on='col_name', how='left')
|
| 459 |
+
feat_importance_hierarchy.to_csv(
|
| 460 |
+
'./data/feat_importance_hierarchy' + file_suffix + '.csv', index=False)
|
training/encode_and_impute.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Script that performs encoding of categorical features and imputation.
|
| 2 |
+
|
| 3 |
+
Performs encoding of categorical features, and imputation of missing values. After encoding
|
| 4 |
+
and imputation are performed, features are dropped. Two versions of the data is saved:
|
| 5 |
+
imputed and not imputed dataframes.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import yaml
|
| 13 |
+
import json
|
| 14 |
+
import joblib
|
| 15 |
+
import encoding
|
| 16 |
+
import imputation
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
with open("./training/config.yaml", "r") as config:
|
| 20 |
+
config = yaml.safe_load(config)
|
| 21 |
+
|
| 22 |
+
# Specify which model to generate features for
|
| 23 |
+
model_type = config["model_settings"]["model_type"]
|
| 24 |
+
|
| 25 |
+
# Setup log file
|
| 26 |
+
log = open("./training/logging/encode_and_impute_" + model_type + ".log", "w")
|
| 27 |
+
sys.stdout = log
|
| 28 |
+
|
| 29 |
+
# Dataset to process - set through config file
|
| 30 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 31 |
+
|
| 32 |
+
# Load data
|
| 33 |
+
data = pd.read_pickle(
|
| 34 |
+
os.path.join(
|
| 35 |
+
config["outputs"]["processed_data_dir"],
|
| 36 |
+
"{}_combined_{}.pkl".format(data_to_process, model_type),
|
| 37 |
+
)
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
############################################################################
|
| 41 |
+
# Target encode categorical data
|
| 42 |
+
############################################################################
|
| 43 |
+
|
| 44 |
+
categorical_cols = [
|
| 45 |
+
"LatestSymptomDiaryQ8",
|
| 46 |
+
"LatestSymptomDiaryQ9",
|
| 47 |
+
"LatestSymptomDiaryQ10",
|
| 48 |
+
"DaysSinceLastExac",
|
| 49 |
+
"AgeBinned",
|
| 50 |
+
"Comorbidities",
|
| 51 |
+
"FEV1PercentPredicted",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
# Multiple types of nans present in data ('nan' and np.NaN). Convert all these to 'nan' for
|
| 55 |
+
# categorical columns
|
| 56 |
+
for categorical_col in categorical_cols:
|
| 57 |
+
data[categorical_col] = data[categorical_col].replace(np.nan, "nan")
|
| 58 |
+
|
| 59 |
+
if data_to_process == "train":
|
| 60 |
+
# Get target encodings for entire train set
|
| 61 |
+
target_encodings = encoding.get_target_encodings(
|
| 62 |
+
train_data=data,
|
| 63 |
+
cols_to_encode=categorical_cols,
|
| 64 |
+
target_col="ExacWithin3Months",
|
| 65 |
+
smooth="auto",
|
| 66 |
+
)
|
| 67 |
+
train_encoded = encoding.apply_target_encodings(
|
| 68 |
+
data=data,
|
| 69 |
+
cols_to_encode=categorical_cols,
|
| 70 |
+
encodings=target_encodings,
|
| 71 |
+
drop_categorical_cols=False,
|
| 72 |
+
)
|
| 73 |
+
json.dump(
|
| 74 |
+
target_encodings,
|
| 75 |
+
open("./data/artifacts/target_encodings_" + model_type + ".json", "w"),
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# K-fold target encode
|
| 79 |
+
# Get info on which patients belong to which fold
|
| 80 |
+
fold_patients = np.load(
|
| 81 |
+
os.path.join(
|
| 82 |
+
config["outputs"]["cohort_info_dir"],
|
| 83 |
+
"fold_patients_{}.npy".format(model_type),
|
| 84 |
+
),
|
| 85 |
+
allow_pickle=True,
|
| 86 |
+
)
|
| 87 |
+
train_encoded_cv, target_encodings = encoding.kfold_target_encode(
|
| 88 |
+
df=data,
|
| 89 |
+
fold_ids=fold_patients,
|
| 90 |
+
cols_to_encode=categorical_cols,
|
| 91 |
+
id_col="StudyId",
|
| 92 |
+
target="ExacWithin3Months",
|
| 93 |
+
smooth="auto",
|
| 94 |
+
drop_categorical_cols=False,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Drop categorical cols except for AgeBinned as it is needed in imputation step
|
| 98 |
+
categorical_cols.remove("AgeBinned")
|
| 99 |
+
train_encoded = train_encoded.drop(columns=categorical_cols)
|
| 100 |
+
train_encoded_cv = train_encoded_cv.drop(columns=categorical_cols)
|
| 101 |
+
|
| 102 |
+
if (data_to_process == "test") | (data_to_process == "forward_val"):
|
| 103 |
+
# Encode test set/forward val set based on entire train set
|
| 104 |
+
target_encodings = json.load(
|
| 105 |
+
open("./data/artifacts/target_encodings_" + model_type + ".json")
|
| 106 |
+
)
|
| 107 |
+
test_encoded = encoding.apply_target_encodings(
|
| 108 |
+
data=data,
|
| 109 |
+
cols_to_encode=categorical_cols,
|
| 110 |
+
encodings=target_encodings,
|
| 111 |
+
drop_categorical_cols=False,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# Drop categorical cols except for AgeBinned as it is needed in imputation step
|
| 115 |
+
categorical_cols.remove("AgeBinned")
|
| 116 |
+
test_encoded = test_encoded.drop(columns=categorical_cols)
|
| 117 |
+
|
| 118 |
+
############################################################################
|
| 119 |
+
# Impute missing data
|
| 120 |
+
############################################################################
|
| 121 |
+
|
| 122 |
+
cols_to_ignore = [
|
| 123 |
+
"StudyId",
|
| 124 |
+
"PatientId",
|
| 125 |
+
"IndexDate",
|
| 126 |
+
"ExacWithin3Months",
|
| 127 |
+
"HospExacWithin3Months",
|
| 128 |
+
"CommExacWithin3Months",
|
| 129 |
+
"Age",
|
| 130 |
+
"Sex_F",
|
| 131 |
+
"SafeHavenID",
|
| 132 |
+
"AgeBinned",
|
| 133 |
+
]
|
| 134 |
+
|
| 135 |
+
if data_to_process == "train":
|
| 136 |
+
# Impute entire train set
|
| 137 |
+
not_imputed_train = train_encoded.copy()
|
| 138 |
+
cols_to_impute = train_encoded.drop(columns=cols_to_ignore).columns
|
| 139 |
+
|
| 140 |
+
imputer = imputation.get_imputer(
|
| 141 |
+
train_data=train_encoded,
|
| 142 |
+
cols_to_impute=cols_to_impute,
|
| 143 |
+
average_type="median",
|
| 144 |
+
cols_to_groupby=["AgeBinned", "Sex_F"],
|
| 145 |
+
)
|
| 146 |
+
imputed_train = imputation.apply_imputer(
|
| 147 |
+
data=train_encoded,
|
| 148 |
+
cols_to_impute=cols_to_impute,
|
| 149 |
+
imputer=imputer,
|
| 150 |
+
cols_to_groupby=["AgeBinned", "Sex_F"],
|
| 151 |
+
)
|
| 152 |
+
joblib.dump(imputer, "./data/artifacts/imputer_" + model_type + ".pkl")
|
| 153 |
+
|
| 154 |
+
# K-fold impute
|
| 155 |
+
not_imputed_train_cv = train_encoded_cv.copy()
|
| 156 |
+
imputed_train_cv = imputation.kfold_impute(
|
| 157 |
+
df=train_encoded,
|
| 158 |
+
fold_ids=fold_patients,
|
| 159 |
+
cols_to_impute=cols_to_impute,
|
| 160 |
+
average_type="median",
|
| 161 |
+
cols_to_groupby=["AgeBinned", "Sex_F"],
|
| 162 |
+
id_col="StudyId",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
df_columns = imputed_train.columns.tolist()
|
| 166 |
+
|
| 167 |
+
if (data_to_process == "test") | (data_to_process == "forward_val"):
|
| 168 |
+
not_imputed_test = test_encoded.copy()
|
| 169 |
+
cols_to_impute = test_encoded.drop(columns=cols_to_ignore).columns
|
| 170 |
+
|
| 171 |
+
# Impute test set/forward val set based on entire train set
|
| 172 |
+
imputer = joblib.load("./data/artifacts/imputer_" + model_type + ".pkl")
|
| 173 |
+
imputed_test = imputation.apply_imputer(
|
| 174 |
+
data=test_encoded,
|
| 175 |
+
cols_to_impute=cols_to_impute,
|
| 176 |
+
imputer=imputer,
|
| 177 |
+
cols_to_groupby=["AgeBinned", "Sex_F"],
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
df_columns = imputed_test.columns.tolist()
|
| 181 |
+
|
| 182 |
+
############################################################################
|
| 183 |
+
# Reduce feature space
|
| 184 |
+
############################################################################
|
| 185 |
+
cols_to_drop_startswith = (
|
| 186 |
+
"DiffLatest",
|
| 187 |
+
"Var",
|
| 188 |
+
"LatestEQ5D",
|
| 189 |
+
"TotalEngagement",
|
| 190 |
+
"Age",
|
| 191 |
+
"NumHosp",
|
| 192 |
+
"Required",
|
| 193 |
+
"LungFunction",
|
| 194 |
+
"EngagementCAT",
|
| 195 |
+
"LatestSymptomDiary",
|
| 196 |
+
"LatestAlbumin",
|
| 197 |
+
"LatestEosinophils",
|
| 198 |
+
"LatestNeutrophils",
|
| 199 |
+
"LatestWhite Blood Count",
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
additional_cols_to_drop = [
|
| 203 |
+
"PatientId",
|
| 204 |
+
"SafeHavenID",
|
| 205 |
+
"Sex_F",
|
| 206 |
+
"NumCommExacPrior6mo",
|
| 207 |
+
"AsthmaOverlap",
|
| 208 |
+
"TimeSinceLungFunc",
|
| 209 |
+
"LatestNeutLymphRatio",
|
| 210 |
+
"EngagementEQ5DTW1",
|
| 211 |
+
"EngagementMRCTW1",
|
| 212 |
+
"LatestMRCQ1",
|
| 213 |
+
"WeekAvgCATQ1",
|
| 214 |
+
"WeekAvgCATQ3",
|
| 215 |
+
"WeekAvgCATQ4",
|
| 216 |
+
"WeekAvgCATQ5",
|
| 217 |
+
"WeekAvgCATQ6",
|
| 218 |
+
"WeekAvgCATQ7",
|
| 219 |
+
"WeekAvgCATQ8",
|
| 220 |
+
"WeekAvgSymptomDiaryQ1",
|
| 221 |
+
"WeekAvgSymptomDiaryQ3",
|
| 222 |
+
"WeekAvgSymptomDiaryScore",
|
| 223 |
+
"EngagementSymptomDiaryTW1",
|
| 224 |
+
"ScaledSumSymptomDiaryQ3TW1",
|
| 225 |
+
# "Comorbidities_te",
|
| 226 |
+
]
|
| 227 |
+
|
| 228 |
+
cols_to_drop = []
|
| 229 |
+
cols_to_drop.extend(
|
| 230 |
+
[item for item in df_columns if item.startswith(cols_to_drop_startswith)]
|
| 231 |
+
)
|
| 232 |
+
cols_to_drop.extend(additional_cols_to_drop)
|
| 233 |
+
|
| 234 |
+
if data_to_process == "train":
|
| 235 |
+
imputed_train = imputed_train.drop(columns=cols_to_drop)
|
| 236 |
+
not_imputed_train = not_imputed_train.drop(columns=cols_to_drop)
|
| 237 |
+
imputed_train_cv = imputed_train_cv.drop(columns=cols_to_drop)
|
| 238 |
+
not_imputed_train_cv = not_imputed_train_cv.drop(columns=cols_to_drop)
|
| 239 |
+
if (data_to_process == "test") | (data_to_process == "forward_val"):
|
| 240 |
+
imputed_test = imputed_test.drop(columns=cols_to_drop)
|
| 241 |
+
not_imputed_test = not_imputed_test.drop(columns=cols_to_drop)
|
| 242 |
+
|
| 243 |
+
############################################################################
|
| 244 |
+
# Save data
|
| 245 |
+
############################################################################
|
| 246 |
+
os.makedirs(config["outputs"]["model_input_data_dir"], exist_ok=True)
|
| 247 |
+
|
| 248 |
+
if data_to_process == "train":
|
| 249 |
+
imputed_train.to_pickle(
|
| 250 |
+
os.path.join(
|
| 251 |
+
config["outputs"]["model_input_data_dir"],
|
| 252 |
+
"{}_imputed_{}.pkl".format(data_to_process, model_type),
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
not_imputed_train.to_pickle(
|
| 256 |
+
os.path.join(
|
| 257 |
+
config["outputs"]["model_input_data_dir"],
|
| 258 |
+
"{}_not_imputed_{}.pkl".format(data_to_process, model_type),
|
| 259 |
+
)
|
| 260 |
+
)
|
| 261 |
+
imputed_train_cv.to_pickle(
|
| 262 |
+
os.path.join(
|
| 263 |
+
config["outputs"]["model_input_data_dir"],
|
| 264 |
+
"{}_imputed_cv_{}.pkl".format(data_to_process, model_type),
|
| 265 |
+
)
|
| 266 |
+
)
|
| 267 |
+
not_imputed_train_cv.to_pickle(
|
| 268 |
+
os.path.join(
|
| 269 |
+
config["outputs"]["model_input_data_dir"],
|
| 270 |
+
"{}_not_imputed_cv_{}.pkl".format(data_to_process, model_type),
|
| 271 |
+
)
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
if (data_to_process == "test") | (data_to_process == "forward_val"):
|
| 275 |
+
imputed_test.to_pickle(
|
| 276 |
+
os.path.join(
|
| 277 |
+
config["outputs"]["model_input_data_dir"],
|
| 278 |
+
"{}_imputed_{}.pkl".format(data_to_process, model_type),
|
| 279 |
+
)
|
| 280 |
+
)
|
| 281 |
+
not_imputed_test.to_pickle(
|
| 282 |
+
os.path.join(
|
| 283 |
+
config["outputs"]["model_input_data_dir"],
|
| 284 |
+
"{}_not_imputed_{}.pkl".format(data_to_process, model_type),
|
| 285 |
+
)
|
| 286 |
+
)
|
training/encoding.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Functions for encoding categorical data with additive smoothing techniques applied."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import itertools
|
| 6 |
+
from sklearn.preprocessing import TargetEncoder
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_target_encodings(
|
| 10 |
+
*,
|
| 11 |
+
train_data,
|
| 12 |
+
cols_to_encode,
|
| 13 |
+
target_col,
|
| 14 |
+
smooth="auto",
|
| 15 |
+
keep_nans_as_category=False,
|
| 16 |
+
cols_to_keep_nan_category=None,
|
| 17 |
+
):
|
| 18 |
+
"""
|
| 19 |
+
Retrieve target encodings of input data for later use (performs no encoding).
|
| 20 |
+
|
| 21 |
+
The complete train data set is used to target encode the holdout test data set. This
|
| 22 |
+
function is used to obtain encodings for storage and later use on separate data.
|
| 23 |
+
|
| 24 |
+
Smoothing addresses overfitting caused by sparse data by relying on the global mean
|
| 25 |
+
(mean of target across all rows) rather than the local mean (mean of target across a
|
| 26 |
+
specific category) when there are a small number of observations in a category. The
|
| 27 |
+
degree of smoothing is controlled by the parameter 'smooth'. Higher values of 'smooth'
|
| 28 |
+
increases the influence of the global mean on the target encoding. A 'smooth' value of
|
| 29 |
+
100 can be interpreted as: there must be at least 100 values in the category for the
|
| 30 |
+
sample mean to overtake the global mean.
|
| 31 |
+
|
| 32 |
+
There is also an option to keep nan's as a category for cases on data missing not at
|
| 33 |
+
random. The format of nan for categorical columns required for the function is 'nan'.
|
| 34 |
+
|
| 35 |
+
Use kfold_target_encode to perform kfold encoding, and apply_target_encodings to use the
|
| 36 |
+
output of this function on the test data.
|
| 37 |
+
|
| 38 |
+
Parameters
|
| 39 |
+
----------
|
| 40 |
+
train_data : dataframe
|
| 41 |
+
data to be used to for target encoding at a later stage. This is likely the full
|
| 42 |
+
train data set.
|
| 43 |
+
cols_to_encode : list of strings
|
| 44 |
+
names of columns to be encoded.
|
| 45 |
+
target_col : str
|
| 46 |
+
name of the target variable column.
|
| 47 |
+
smooth : str or float, optional
|
| 48 |
+
controls the amount of smoothing applied. A larger smooth value will put more
|
| 49 |
+
weight on the global target mean. If "auto", then smooth is set to an
|
| 50 |
+
empirical Bayes estimate, defaults to "auto".
|
| 51 |
+
keep_nans_as_category : bool, optional
|
| 52 |
+
option to retain nans as a category for cases of data missing not at random, by
|
| 53 |
+
default False.
|
| 54 |
+
cols_to_keep_nan_category : list of strings, optional
|
| 55 |
+
names of columns to keep the encoded nan category, by default None. Need to state
|
| 56 |
+
names of columns if keep_nans_as_category is True.
|
| 57 |
+
|
| 58 |
+
Returns
|
| 59 |
+
-------
|
| 60 |
+
encodings_all : dict
|
| 61 |
+
encodings used for each column.
|
| 62 |
+
|
| 63 |
+
Raises
|
| 64 |
+
-------
|
| 65 |
+
ValueError
|
| 66 |
+
error raised if there are multiple types of nan's in columns to be encoded.
|
| 67 |
+
ValueError
|
| 68 |
+
error raised if nans are not in the correct format: 'nan'.
|
| 69 |
+
ValueError
|
| 70 |
+
error raised if keep_nans_as_category is True but columns not provided.
|
| 71 |
+
|
| 72 |
+
"""
|
| 73 |
+
train_data_to_encode = train_data[cols_to_encode]
|
| 74 |
+
train_target = train_data[target_col]
|
| 75 |
+
|
| 76 |
+
# Raise an error if there are multiple types of nan's
|
| 77 |
+
all_nan_types = [None, "None", np.NaN, "nan", "NAN", "N/A"]
|
| 78 |
+
incorrect_nan_types = ["None", np.NaN, "nan", "NAN", "N/A"]
|
| 79 |
+
for col in train_data_to_encode:
|
| 80 |
+
cat_present = train_data_to_encode[col].unique().tolist()
|
| 81 |
+
if len(list(set(all_nan_types) & set(cat_present))) > 1:
|
| 82 |
+
raise ValueError(
|
| 83 |
+
"Multiple types of nans present in data. Make sure that missing values in"
|
| 84 |
+
"categorical columns are all recorded as 'nan'."
|
| 85 |
+
)
|
| 86 |
+
# Raise an error if nan not in correct format for function
|
| 87 |
+
if any(element in all_nan_types for element in cat_present):
|
| 88 |
+
if not "nan" in cat_present:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
"Missing values in categorical columns are not recorded as 'nan'."
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
encoder = TargetEncoder(smooth=smooth)
|
| 94 |
+
encoder = encoder.fit(train_data_to_encode, train_target)
|
| 95 |
+
|
| 96 |
+
# Get dictionary with encodings
|
| 97 |
+
paired_dicts = []
|
| 98 |
+
paired_arrays = zip(encoder.categories_, encoder.encodings_)
|
| 99 |
+
for category_array, value_array in paired_arrays:
|
| 100 |
+
paired_dict = dict(zip(category_array, value_array))
|
| 101 |
+
paired_dicts.append(paired_dict)
|
| 102 |
+
encodings_all = dict(zip(encoder.feature_names_in_, paired_dicts))
|
| 103 |
+
|
| 104 |
+
# Sklearn treats nans as a category. The default in this function is to convert nan
|
| 105 |
+
# categories back to np.NaN unless stated otherwise.
|
| 106 |
+
if keep_nans_as_category is False:
|
| 107 |
+
for col in encodings_all:
|
| 108 |
+
encodings_all[col].update({"nan": np.nan})
|
| 109 |
+
# If it is specified to keep nan categories for specific features, only those features
|
| 110 |
+
# not specified in cols_to_keep_nan_category are converted to np.NaN.
|
| 111 |
+
if (keep_nans_as_category is True) and (cols_to_keep_nan_category is None):
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"Parameter keep_nans_as_category is True but cols_to_keep_nan_category not provided."
|
| 114 |
+
)
|
| 115 |
+
if (keep_nans_as_category is True) and not (cols_to_keep_nan_category is None):
|
| 116 |
+
cols_to_remove_nan_cat = set(cols_to_encode) - set(cols_to_keep_nan_category)
|
| 117 |
+
for col_to_remove_nan in cols_to_remove_nan_cat:
|
| 118 |
+
encodings_all[col_to_remove_nan].update({"nan": np.nan})
|
| 119 |
+
|
| 120 |
+
return encodings_all
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def apply_target_encodings(*, data, cols_to_encode, encodings, drop_categorical_cols=False):
|
| 124 |
+
"""Target encode input data with supplied encodings.
|
| 125 |
+
|
| 126 |
+
Parameters
|
| 127 |
+
----------
|
| 128 |
+
data : dataframe
|
| 129 |
+
data with columns to be target encoded.
|
| 130 |
+
cols_to_encode : list of strings
|
| 131 |
+
list of columns to target encode.
|
| 132 |
+
encodings : dict
|
| 133 |
+
target encodings to use on input data (from training data).
|
| 134 |
+
drop_categorical_cols: bool, optional
|
| 135 |
+
option to drop categorical columns after encoding, defaults to False.
|
| 136 |
+
|
| 137 |
+
Returns
|
| 138 |
+
-------
|
| 139 |
+
data : dataframe
|
| 140 |
+
target encoded version of the input data.
|
| 141 |
+
|
| 142 |
+
Raises
|
| 143 |
+
-------
|
| 144 |
+
AssertionError
|
| 145 |
+
raises an error if the column to be encoded is not in the passed data.
|
| 146 |
+
ValueError
|
| 147 |
+
error raised if nans are not in the correct format: 'nan'.
|
| 148 |
+
ValueError
|
| 149 |
+
error raised if keep_nans_as_category is True but columns not provided.
|
| 150 |
+
|
| 151 |
+
"""
|
| 152 |
+
data_to_encode = data[cols_to_encode]
|
| 153 |
+
# Raise an error if there are multiple types of nan's
|
| 154 |
+
all_nan_types = [None, "None", np.NaN, "nan", "NAN", "N/A"]
|
| 155 |
+
incorrect_nan_types = ["None", np.NaN, "nan", "NAN", "N/A"]
|
| 156 |
+
for col in data_to_encode:
|
| 157 |
+
cat_present = data_to_encode[col].unique().tolist()
|
| 158 |
+
if len(list(set(all_nan_types) & set(cat_present))) > 1:
|
| 159 |
+
raise ValueError(
|
| 160 |
+
"Multiple types of nans present in data. Make sure that missing values in"
|
| 161 |
+
"categorical columns are all recorded as 'nan'."
|
| 162 |
+
)
|
| 163 |
+
# Raise an error if nan not in correct format for function
|
| 164 |
+
if any(element in all_nan_types for element in cat_present):
|
| 165 |
+
if not "nan" in cat_present:
|
| 166 |
+
raise ValueError(
|
| 167 |
+
"Missing values in categorical columns are not recorded as 'nan'."
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
encoded_data = data.copy()
|
| 171 |
+
for col in cols_to_encode:
|
| 172 |
+
assert (
|
| 173 |
+
col in encodings.keys()
|
| 174 |
+
), "No target encodings found for {} column".format(col)
|
| 175 |
+
encodings_col = encodings[col]
|
| 176 |
+
|
| 177 |
+
# Account for the case where the new data includes a category not present
|
| 178 |
+
# in the train data encodings and set that category encoding nan
|
| 179 |
+
data_unique = data[col].unique().tolist()
|
| 180 |
+
encodings_unique = list(set(encodings_col.keys()))
|
| 181 |
+
diffs = np.setdiff1d(data_unique, encodings_unique)
|
| 182 |
+
encodings_col.update(zip(diffs, itertools.repeat(np.nan)))
|
| 183 |
+
|
| 184 |
+
# Use the lookup table to place each category in the current fold with its
|
| 185 |
+
# encoded value from the train data (in the new _te column)
|
| 186 |
+
filtered = encoded_data.filter(items=[col])
|
| 187 |
+
filtered_encodings = filtered.replace(encodings_col)
|
| 188 |
+
filtered_encodings = filtered_encodings.rename(columns={col: col + "_te"})
|
| 189 |
+
encoded_data = pd.concat([encoded_data, filtered_encodings], axis=1)
|
| 190 |
+
|
| 191 |
+
if drop_categorical_cols is True:
|
| 192 |
+
encoded_data = encoded_data.drop(columns=cols_to_encode)
|
| 193 |
+
return encoded_data
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def kfold_target_encode(
|
| 197 |
+
*,
|
| 198 |
+
df,
|
| 199 |
+
fold_ids,
|
| 200 |
+
cols_to_encode,
|
| 201 |
+
id_col,
|
| 202 |
+
target,
|
| 203 |
+
smooth="auto",
|
| 204 |
+
keep_nans_as_category=False,
|
| 205 |
+
cols_to_keep_nan_category=None,
|
| 206 |
+
drop_categorical_cols=False,
|
| 207 |
+
):
|
| 208 |
+
"""Perform K-fold target encoding.
|
| 209 |
+
|
| 210 |
+
Fold by fold target encoding of train data is used to prevent data leakage in cross-
|
| 211 |
+
validation (the same folds are used for encoding and CV). For example, in 5-fold
|
| 212 |
+
target encoding, each fold is encoded using the other 4 folds and that fold is
|
| 213 |
+
then used as the validation fold in CV. Smoothing is performed on each K-fold.
|
| 214 |
+
|
| 215 |
+
Parameters
|
| 216 |
+
----------
|
| 217 |
+
df : dataframe
|
| 218 |
+
data with columns to be target encoded. Will generally be the train data.
|
| 219 |
+
fold_ids : list of arrays
|
| 220 |
+
each array contains the validation patient IDs for each fold.
|
| 221 |
+
cols_to_encode : list of strings
|
| 222 |
+
columns to target encode.
|
| 223 |
+
id_col : str
|
| 224 |
+
name of patient ID column.
|
| 225 |
+
target : str
|
| 226 |
+
name of target column.
|
| 227 |
+
smooth : str or float, optional
|
| 228 |
+
controls the amount of smoothing applied. A larger smooth value will put more
|
| 229 |
+
weight on the global target mean. If "auto", then smooth is set to an
|
| 230 |
+
empirical Bayes estimate, defaults to "auto".
|
| 231 |
+
keep_nans_as_category : bool, optional
|
| 232 |
+
option to retain nans as a category for cases of data missing not at random, by
|
| 233 |
+
default False.
|
| 234 |
+
cols_to_keep_nan_category : list of strings, optional
|
| 235 |
+
names of columns to keep the encoded nan category, by default None. Need to state
|
| 236 |
+
names of columns if keep_nans_as_category is True.
|
| 237 |
+
drop_categorical_cols: bool, optional
|
| 238 |
+
option to drop categorical columns after encoding, defaults to False.
|
| 239 |
+
|
| 240 |
+
Returns
|
| 241 |
+
-------
|
| 242 |
+
encoded_df_cv : dataframe
|
| 243 |
+
k-fold target encoded version of the input data.
|
| 244 |
+
fold_encodings_all : dataframe
|
| 245 |
+
contains target encodings for each fold.
|
| 246 |
+
|
| 247 |
+
"""
|
| 248 |
+
# Loop over CV folds and perform K-fold target encoding
|
| 249 |
+
encoded_data_cv = []
|
| 250 |
+
fold_encodings_all = []
|
| 251 |
+
for fold in fold_ids:
|
| 252 |
+
# Divide data into train folds and validation fold
|
| 253 |
+
validation_fold = df[df[id_col].isin(fold)]
|
| 254 |
+
train_folds = df[~df[id_col].isin(fold)]
|
| 255 |
+
|
| 256 |
+
# Obtain target encodings from train folds
|
| 257 |
+
fold_encodings = get_target_encodings(
|
| 258 |
+
train_data=train_folds,
|
| 259 |
+
cols_to_encode=cols_to_encode,
|
| 260 |
+
target_col=target,
|
| 261 |
+
smooth=smooth,
|
| 262 |
+
keep_nans_as_category=keep_nans_as_category,
|
| 263 |
+
cols_to_keep_nan_category=cols_to_keep_nan_category,
|
| 264 |
+
)
|
| 265 |
+
fold_encodings_all.append(fold_encodings)
|
| 266 |
+
|
| 267 |
+
# Apply to validation fold
|
| 268 |
+
encoded_data_fold = apply_target_encodings(
|
| 269 |
+
data=validation_fold,
|
| 270 |
+
cols_to_encode=cols_to_encode,
|
| 271 |
+
encodings=fold_encodings,
|
| 272 |
+
drop_categorical_cols=drop_categorical_cols,
|
| 273 |
+
)
|
| 274 |
+
encoded_data_cv.append(encoded_data_fold)
|
| 275 |
+
|
| 276 |
+
# Place the encoded validation fold data into a single df
|
| 277 |
+
encoded_df_cv = pd.concat(encoded_data_cv)
|
| 278 |
+
|
| 279 |
+
# Place the encodings for all folds into a df
|
| 280 |
+
fold_encodings_all = pd.json_normalize(fold_encodings_all).T
|
| 281 |
+
return encoded_df_cv, fold_encodings_all
|
training/imputation.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Functions for imputing missing data."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pandas.api.types import is_numeric_dtype
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def replace_nan_with_mode(df, col, col_mode, random_state):
|
| 9 |
+
"""Replaces nan in categorical columns with the mode.
|
| 10 |
+
|
| 11 |
+
Function only used for categorical columns. Replaces nan in categorical columns with the
|
| 12 |
+
mode. If there are multiple modes, one mode is randomly chosen.
|
| 13 |
+
|
| 14 |
+
Parameters
|
| 15 |
+
----------
|
| 16 |
+
df : dataframe
|
| 17 |
+
dataframe containing the categorical columns to impute and the columns with the mode.
|
| 18 |
+
col : str
|
| 19 |
+
name of column to impute.
|
| 20 |
+
col_mode : str
|
| 21 |
+
name of column containing the calculated modes.
|
| 22 |
+
random_state : int
|
| 23 |
+
to seed the random generator.
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
df : dataframe
|
| 28 |
+
input dataframe with missing values in categorical column imputed.
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
np.random.seed(seed=random_state)
|
| 32 |
+
for index, row in df.iterrows():
|
| 33 |
+
if row[col] == "nan":
|
| 34 |
+
# Deals with cases where there are multiple modes. If there are multiple modes,
|
| 35 |
+
# one mode is chosen at random
|
| 36 |
+
if (not isinstance(row[col_mode], str)) and len(row[col_mode]) > 1:
|
| 37 |
+
df.at[index, col] = np.random.choice(list(row[col_mode]))
|
| 38 |
+
else:
|
| 39 |
+
df.at[index, col] = row[col_mode]
|
| 40 |
+
return df
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_imputer(*, train_data, cols_to_impute, average_type, cols_to_groupby=None):
|
| 44 |
+
"""Retrieve imputer of input data for later use (performs no imputing).
|
| 45 |
+
|
| 46 |
+
The complete train data set is used to impute the holdout test data set. This function
|
| 47 |
+
is used to obtain imputations for storage and later use on separate data. The average
|
| 48 |
+
specified (e.g. median) can be calculated on the data provided or the data can be
|
| 49 |
+
grouped by specified features (e.g. binned age and sex) prior to average calculation.
|
| 50 |
+
For categorical columns, the mode is used. The format of nan for categorical columns
|
| 51 |
+
required for the function is 'nan'.
|
| 52 |
+
|
| 53 |
+
Use apply_imputer to perform imputation.
|
| 54 |
+
|
| 55 |
+
Parameters
|
| 56 |
+
----------
|
| 57 |
+
train_data : dataframe
|
| 58 |
+
data to be used to for imputation at a later stage. This is likely the full train
|
| 59 |
+
data set.
|
| 60 |
+
cols_to_impute : list of strings
|
| 61 |
+
names of columns to perform imputation on. If cols_to_groupby is not None, these
|
| 62 |
+
columns cannot appear in cols_to_impute.
|
| 63 |
+
average_type : str
|
| 64 |
+
type of average to calculate. Must be either 'median' or 'mean'. For categorical
|
| 65 |
+
columns, the 'mode' will automatically be calculated.
|
| 66 |
+
cols_to_groupby : list of strings, optional
|
| 67 |
+
option to group data before calculating average.
|
| 68 |
+
|
| 69 |
+
Returns
|
| 70 |
+
-------
|
| 71 |
+
imputer : dataframe
|
| 72 |
+
contains average values calculated, to be used in imputation.
|
| 73 |
+
|
| 74 |
+
Raises
|
| 75 |
+
-------
|
| 76 |
+
ValueError
|
| 77 |
+
raises an error if elements in cols_to_groupby appear in cols_to_impute
|
| 78 |
+
ValueError
|
| 79 |
+
error raised if nans are not in the correct format: 'nan'.
|
| 80 |
+
|
| 81 |
+
"""
|
| 82 |
+
if average_type not in ["mean", "median"]:
|
| 83 |
+
raise ValueError("average_type must be either 'mean or 'median'.")
|
| 84 |
+
|
| 85 |
+
if not cols_to_groupby is None:
|
| 86 |
+
if any(column in cols_to_groupby for column in cols_to_impute):
|
| 87 |
+
raise ValueError(
|
| 88 |
+
"Elements in cols_to_groupby should not appear in cols_to_impute"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
imputer = []
|
| 92 |
+
for col in cols_to_impute:
|
| 93 |
+
is_numeric = is_numeric_dtype(train_data[col])
|
| 94 |
+
# For numeric columns, calculate specified average_type
|
| 95 |
+
if is_numeric:
|
| 96 |
+
imputer_for_col = train_data.groupby(cols_to_groupby)[col].agg(
|
| 97 |
+
[average_type]
|
| 98 |
+
)
|
| 99 |
+
imputer_for_col = imputer_for_col.rename(columns={average_type: col})
|
| 100 |
+
# For categorical columns, calculate mode
|
| 101 |
+
else:
|
| 102 |
+
# Raise an error if nans are in the incorrect format
|
| 103 |
+
all_nan_types = [None, "None", np.NaN, "nan", "NAN", "N/A"]
|
| 104 |
+
cat_present = train_data[col].unique().tolist()
|
| 105 |
+
if any(element in all_nan_types for element in cat_present):
|
| 106 |
+
if not "nan" in cat_present:
|
| 107 |
+
raise ValueError(
|
| 108 |
+
"Missing values in categorical columns are not recorded as 'nan'."
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Drop any categories with 'nan' so that when getting the mode, nan isn't
|
| 112 |
+
# treated as a category
|
| 113 |
+
cols_to_groupby_plus_impute = cols_to_groupby + [col]
|
| 114 |
+
train_data_col = train_data[cols_to_groupby_plus_impute]
|
| 115 |
+
train_data_col = train_data_col[train_data_col[col] != "nan"]
|
| 116 |
+
imputer_for_col = train_data_col.groupby(cols_to_groupby).agg(
|
| 117 |
+
pd.Series.mode
|
| 118 |
+
)
|
| 119 |
+
imputer.append(imputer_for_col)
|
| 120 |
+
imputer = pd.concat(imputer, axis=1)
|
| 121 |
+
|
| 122 |
+
# If there are any nans after grouping, fill nans with values from similar groups
|
| 123 |
+
imputer = imputer.sort_values(cols_to_groupby)
|
| 124 |
+
imputer = imputer.ffill().bfill()
|
| 125 |
+
else:
|
| 126 |
+
imputer = train_data[cols_to_impute].agg(average_type)
|
| 127 |
+
return imputer
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def apply_imputer(*, data, cols_to_impute, imputer, cols_to_groupby=None):
|
| 131 |
+
"""Impute input data with supplied imputer.
|
| 132 |
+
|
| 133 |
+
Parameters
|
| 134 |
+
----------
|
| 135 |
+
data : dataframe
|
| 136 |
+
data with columns to be imputed.
|
| 137 |
+
cols_to_impute : list of strings
|
| 138 |
+
names of columns to be imputed. If cols_to_groupby is not None, these columns cannot
|
| 139 |
+
appear in cols_to_impute.
|
| 140 |
+
imputer : dataframe
|
| 141 |
+
contains average values calculated, to be used in imputation.
|
| 142 |
+
cols_to_groupby : list of strings, optional
|
| 143 |
+
option to group data before calculating average. Must be the same as in get_imputer.
|
| 144 |
+
|
| 145 |
+
Returns
|
| 146 |
+
-------
|
| 147 |
+
data : dataframe
|
| 148 |
+
imputed dataframe.
|
| 149 |
+
|
| 150 |
+
Raises
|
| 151 |
+
-------
|
| 152 |
+
ValueError
|
| 153 |
+
raises an error if cols_to_groupby in apply_imputer do not match cols_to_groupby in
|
| 154 |
+
get_imputer.
|
| 155 |
+
|
| 156 |
+
"""
|
| 157 |
+
if (cols_to_groupby) != (imputer.index.names):
|
| 158 |
+
raise ValueError(
|
| 159 |
+
"Groups used to generate the imputer and apply imputer must be the same. \
|
| 160 |
+
Groups used to generate the imputer are: {}, groups used to apply the \
|
| 161 |
+
imputer are: {}".format(
|
| 162 |
+
imputer.index.names, cols_to_groupby
|
| 163 |
+
)
|
| 164 |
+
)
|
| 165 |
+
if not cols_to_groupby is None:
|
| 166 |
+
imputer = imputer.add_suffix("_avg")
|
| 167 |
+
imputer = imputer.reset_index()
|
| 168 |
+
|
| 169 |
+
data_imputed = data.merge(imputer, on=cols_to_groupby, how="left")
|
| 170 |
+
for col in cols_to_impute:
|
| 171 |
+
is_numeric = is_numeric_dtype(data[col])
|
| 172 |
+
if is_numeric:
|
| 173 |
+
data_imputed[col] = np.where(
|
| 174 |
+
data_imputed[col].isna(),
|
| 175 |
+
data_imputed[col + "_avg"],
|
| 176 |
+
data_imputed[col],
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
data_imputed = replace_nan_with_mode(
|
| 180 |
+
data_imputed, col, col + "_avg", random_state=0
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# Drop columns containing the average values
|
| 184 |
+
data_imputed = data_imputed.loc[:, ~data_imputed.columns.str.endswith("_avg")]
|
| 185 |
+
else:
|
| 186 |
+
data_imputed = data_imputed.fillna(imputer)
|
| 187 |
+
return data_imputed
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def kfold_impute(
|
| 191 |
+
*,
|
| 192 |
+
df,
|
| 193 |
+
fold_ids,
|
| 194 |
+
cols_to_impute,
|
| 195 |
+
average_type,
|
| 196 |
+
cols_to_groupby,
|
| 197 |
+
id_col,
|
| 198 |
+
):
|
| 199 |
+
"""Perform K-fold imputation.
|
| 200 |
+
|
| 201 |
+
Fold by fold imputation of train data is used to prevent data leakage in cross-
|
| 202 |
+
validation (the same folds are used for imputation and CV). For example, in 5-fold
|
| 203 |
+
imputation, each fold is imputed using the other 4 folds and that fold is then used as
|
| 204 |
+
the validation fold in CV.
|
| 205 |
+
|
| 206 |
+
Parameters
|
| 207 |
+
----------
|
| 208 |
+
df : dataframe
|
| 209 |
+
data with columns to be imputed. Will generally be the train data.
|
| 210 |
+
fold_ids : list of arrays
|
| 211 |
+
each array contains the validation patient IDs for each fold.
|
| 212 |
+
cols_to_impute : list of strings
|
| 213 |
+
columns to impute.
|
| 214 |
+
average_type : str
|
| 215 |
+
type of average to calculate (e.g. median, mean).
|
| 216 |
+
cols_to_groupby : list of strings, optional
|
| 217 |
+
option to group data before calculating average.
|
| 218 |
+
id_col : str
|
| 219 |
+
name of patient ID column.
|
| 220 |
+
|
| 221 |
+
Returns
|
| 222 |
+
-------
|
| 223 |
+
imputed_df_cv : dataframe
|
| 224 |
+
k-fold imputed version of the input data.
|
| 225 |
+
|
| 226 |
+
"""
|
| 227 |
+
# Loop over CV folds and perform K-fold imputation
|
| 228 |
+
imputed_data_cv = []
|
| 229 |
+
for fold in fold_ids:
|
| 230 |
+
# Divide data into train folds and validation fold
|
| 231 |
+
validation_fold = df[df[id_col].isin(fold)]
|
| 232 |
+
train_folds = df[~df[id_col].isin(fold)]
|
| 233 |
+
|
| 234 |
+
# Obtain imputers from train folds
|
| 235 |
+
fold_imputer = get_imputer(
|
| 236 |
+
train_data=train_folds,
|
| 237 |
+
cols_to_impute=cols_to_impute,
|
| 238 |
+
average_type=average_type,
|
| 239 |
+
cols_to_groupby=cols_to_groupby,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Apply to validation fold
|
| 243 |
+
imputed_data_fold = apply_imputer(
|
| 244 |
+
data=validation_fold,
|
| 245 |
+
cols_to_impute=cols_to_impute,
|
| 246 |
+
imputer=fold_imputer,
|
| 247 |
+
cols_to_groupby=cols_to_groupby,
|
| 248 |
+
)
|
| 249 |
+
imputed_data_cv.append(imputed_data_fold)
|
| 250 |
+
|
| 251 |
+
# Place the imputed validation fold data into a single df
|
| 252 |
+
imputed_df_cv = pd.concat(imputed_data_cv)
|
| 253 |
+
|
| 254 |
+
return imputed_df_cv
|
| 255 |
+
|
training/model_h.py
ADDED
|
@@ -0,0 +1,2061 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing code for model H (longer term exacerbation prediction)."""
|
| 2 |
+
|
| 3 |
+
# General
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import random
|
| 10 |
+
import json
|
| 11 |
+
import joblib
|
| 12 |
+
import mlflow
|
| 13 |
+
|
| 14 |
+
# Feature engineering
|
| 15 |
+
from sklearn import base
|
| 16 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 17 |
+
from imblearn.over_sampling import RandomOverSampler, SMOTE
|
| 18 |
+
|
| 19 |
+
# Calibration
|
| 20 |
+
from sklearn.calibration import calibration_curve, CalibratedClassifierCV
|
| 21 |
+
import ml_insights as mli
|
| 22 |
+
|
| 23 |
+
# Metrics
|
| 24 |
+
from sklearn.metrics import (
|
| 25 |
+
confusion_matrix,
|
| 26 |
+
precision_recall_curve,
|
| 27 |
+
auc,
|
| 28 |
+
average_precision_score,
|
| 29 |
+
roc_auc_score,
|
| 30 |
+
brier_score_loss,
|
| 31 |
+
f1_score,
|
| 32 |
+
precision_score,
|
| 33 |
+
recall_score,
|
| 34 |
+
roc_curve,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Explainability
|
| 38 |
+
import shap
|
| 39 |
+
|
| 40 |
+
# Plotting
|
| 41 |
+
import matplotlib.pyplot as plt
|
| 42 |
+
import seaborn as sns
|
| 43 |
+
|
| 44 |
+
##############################################################
|
| 45 |
+
# Functions for setting up model labels
|
| 46 |
+
##############################################################
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def apply_logic_response_criterion(df, N=2, minimum_period=14, maximum_period=35):
|
| 50 |
+
"""
|
| 51 |
+
Apply PRO LOGIC criterion 2 (consecutive negative Q5 replies required between events).
|
| 52 |
+
|
| 53 |
+
For events that occur after the minimum required period following a previous exac,
|
| 54 |
+
e.g. longer than 14 days, but before they are automatically considered as a new exac
|
| 55 |
+
event, e.g. 35 days, PRO LOGIC considers weekly PRO responses between the two events.
|
| 56 |
+
For subsequent events to count as separate events, there must be at least N
|
| 57 |
+
consecutive negative responses (no rescue meds taken) to weekly PROs between each
|
| 58 |
+
postive reply. Note PRO LOGIC is applied to both hospital and patient reported events.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
df (pd.DataFrame): must contain columns for PatientId, DateOfEvent, Q5Answered,
|
| 62 |
+
NegativeQ5, IsExac and DaysSinceLastExac.
|
| 63 |
+
minimum_period (int): minimum number of days since the previous exac (any exacs
|
| 64 |
+
within this window will already be removed with PRO LOGIC criterion 1).
|
| 65 |
+
Default value is 14 days.
|
| 66 |
+
maximum_period (int): maximum number of days since the previous exac (any exacs
|
| 67 |
+
occurring after this period will automatically count as a separate event).
|
| 68 |
+
Default is 35 days.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
pd.DataFrame: input df with a new boolean column 'RemoveExac'.
|
| 72 |
+
|
| 73 |
+
"""
|
| 74 |
+
# Retrieve dataframe indices of exacs falling under PRO LOGIC criterion 2 (Q5 replies)
|
| 75 |
+
indices = get_logic_exacerbation_indices(
|
| 76 |
+
df, minimum_period=minimum_period, maximum_period=maximum_period
|
| 77 |
+
)
|
| 78 |
+
remove_exac = []
|
| 79 |
+
# Loop over each exac and evaluate PRO LOGIC criterion, returning 1 (remove) or 0
|
| 80 |
+
for exac_index in indices:
|
| 81 |
+
remove_exac.append(logic_consecutive_negative_responses(df, exac_index, N))
|
| 82 |
+
# Create dataframe containing exac indices and a boolean column stating whether to
|
| 83 |
+
# remove that exac due to failing Q5 response criterion and merge with original df
|
| 84 |
+
remove_exac = pd.DataFrame({"ind": indices, "RemoveExac": remove_exac})
|
| 85 |
+
df = df.merge(
|
| 86 |
+
remove_exac.set_index("ind"), left_index=True, right_index=True, how="left"
|
| 87 |
+
)
|
| 88 |
+
return df
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def bin_numeric_column(*, col, bins, labels):
|
| 92 |
+
"""
|
| 93 |
+
Use pd.cut to bin numeric data into categories.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
col (pd.Series): dataframe column to be binned.
|
| 97 |
+
bins (list): numeric values of bins.
|
| 98 |
+
labels (list): corresponding labels for the bins.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
pd.Series: binned column.
|
| 102 |
+
"""
|
| 103 |
+
return pd.cut(col, bins=bins, labels=labels, right=False).astype("str")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def calculate_days_since_last_event(*, df, event_col, output_col):
|
| 107 |
+
"""
|
| 108 |
+
Calculate the days since the last event, e.g. exacerbation or rescue med prescription.
|
| 109 |
+
|
| 110 |
+
Restarts the count from one the day following an event. Any days without a
|
| 111 |
+
previous event have the output column set to -1
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
df (pd.DataFrame): dataframe with a column containing dates and a boolean column
|
| 115 |
+
stating whether an event occurred on that date.
|
| 116 |
+
event_col (str): name of the boolean column for whether an event occurred.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
df: the input dateframe with an additional column stating the number of days since
|
| 120 |
+
the previous event occurred (or -1 if no previous event).
|
| 121 |
+
|
| 122 |
+
"""
|
| 123 |
+
# Get all events
|
| 124 |
+
all_events = df[df[event_col].eq(1)].copy()
|
| 125 |
+
all_events["PrevEvent"] = all_events.index
|
| 126 |
+
# Merge the full df with the event df on their indices to the closest date in the past
|
| 127 |
+
# i.e. the most recent exacerbation
|
| 128 |
+
df = pd.merge_asof(
|
| 129 |
+
df,
|
| 130 |
+
all_events["PrevEvent"],
|
| 131 |
+
left_index=True,
|
| 132 |
+
right_index=True,
|
| 133 |
+
direction="backward",
|
| 134 |
+
)
|
| 135 |
+
# Calculate the days since the previous event, restarting the count from 1 the
|
| 136 |
+
# day following an exacerbation (using shift)
|
| 137 |
+
df[output_col] = df.index - df["PrevEvent"].shift(1)
|
| 138 |
+
# Set to -1 for any rows without a prior exacerbation
|
| 139 |
+
df[output_col] = df[output_col].fillna(-1).astype("int64")
|
| 140 |
+
df = df.drop(columns=["PrevEvent"])
|
| 141 |
+
return df
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def extract_clinician_verified_exacerbations(df):
|
| 145 |
+
"""
|
| 146 |
+
Extract verified events from clinician verification spreadsheets.
|
| 147 |
+
|
| 148 |
+
Extract only clinician verified events from verification spreadsheets and set the date
|
| 149 |
+
to the clinician supplied date if entered. Include a flag column for if the date was
|
| 150 |
+
changed from the PRO question response date.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
df (pd.DataFrame): event verification data supplied by clinicians.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
pd.DataFrame: contains StudyId, DateOfEvent (a mix of true event dates and PRO
|
| 157 |
+
response dates if true dates unknown), IsCommExac (set to 1 here, used
|
| 158 |
+
after merging later) and ExacDateUnknown (boolean, 1 if clinicians did not
|
| 159 |
+
change the date).
|
| 160 |
+
|
| 161 |
+
"""
|
| 162 |
+
# Filter for only verified events
|
| 163 |
+
df = df[df["Exacerbation confirmed"] == 1].copy()
|
| 164 |
+
df["DateRecorded"] = pd.to_datetime(df.DateRecorded, utc=True).dt.normalize()
|
| 165 |
+
df["New Date"] = pd.to_datetime(df["New Date"], utc=True).dt.normalize()
|
| 166 |
+
# Change the event date to the clinician supplied date if entered. This is considered
|
| 167 |
+
# the true event date. Set the event date to the PRO response date otherwise and flag
|
| 168 |
+
# that the true date is unknown
|
| 169 |
+
df["DateOfEvent"] = np.where(
|
| 170 |
+
df["Date changed"] == 1, df["New Date"], df["DateRecorded"]
|
| 171 |
+
)
|
| 172 |
+
df["DateOfEvent"] = np.where(
|
| 173 |
+
(df["Date changed"] == 1) & (df["New Date"].isna()),
|
| 174 |
+
df["DateRecorded"],
|
| 175 |
+
df["DateOfEvent"],
|
| 176 |
+
)
|
| 177 |
+
df["ExacDateUnknown"] = np.int64(np.where(df["Date changed"] == 1, 0, 1))
|
| 178 |
+
df["ExacDateUnknown"] = np.int64(
|
| 179 |
+
np.where(
|
| 180 |
+
(df["Date changed"] == 1) & (df["New Date"].isna()),
|
| 181 |
+
1,
|
| 182 |
+
df["ExacDateUnknown"],
|
| 183 |
+
)
|
| 184 |
+
)
|
| 185 |
+
# Flag all events as community events (this df will merge with hospital events later)
|
| 186 |
+
df["IsCommExac"] = 1
|
| 187 |
+
df = df[["StudyId", "DateOfEvent", "IsCommExac", "ExacDateUnknown"]]
|
| 188 |
+
return df
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def define_hospital_admission(events):
|
| 192 |
+
"""
|
| 193 |
+
Define whether a COPD service event was an admission and return 1 (yes) or 0 (no).
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
events (pd.DataFrame): events from COPD service previously merged with
|
| 197 |
+
PatientEventTypes.txt to get a column containing EventTypeId
|
| 198 |
+
event_name_col (str): name of column containing COPD service EventTypeId
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
array: boolean stating whether an event was a hospital admission.
|
| 202 |
+
|
| 203 |
+
"""
|
| 204 |
+
hospital_event_names = [
|
| 205 |
+
"Hospital admission - emergency, COPD related",
|
| 206 |
+
"Hospital admission - emergency, COPD unrelated",
|
| 207 |
+
]
|
| 208 |
+
return np.where(events.isin(hospital_event_names), 1, 0)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def define_service_exac_event(
|
| 212 |
+
*, events, event_name_col="EventType", include_community=False
|
| 213 |
+
):
|
| 214 |
+
"""State if a COPD service event was an exacerbation and return 1 (yes) or 0 (no).
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
events (pd.DataFrame): events from COPD service previously merged with
|
| 218 |
+
PatientEventTypes.txt to get a column containing EventTypeId
|
| 219 |
+
event_name_col (str): name of column containing COPD service EventTypeId
|
| 220 |
+
include_community (bool): whether to include event types corresponding to
|
| 221 |
+
patient reported exacerbations (e.g. community managed with rescue meds).
|
| 222 |
+
Defaults to False.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
array: boolean stating whether an event was an exacerbation.
|
| 226 |
+
|
| 227 |
+
"""
|
| 228 |
+
if include_community is True:
|
| 229 |
+
exacerbation_event_names = [
|
| 230 |
+
"Hospital admission - emergency, COPD related",
|
| 231 |
+
"Exacerbation - self-managed with rescue pack",
|
| 232 |
+
"GP review - emergency, COPD related",
|
| 233 |
+
"Emergency department attendance, COPD related",
|
| 234 |
+
"Exacerbation - started abs/steroid by clinical team",
|
| 235 |
+
]
|
| 236 |
+
else:
|
| 237 |
+
exacerbation_event_names = [
|
| 238 |
+
"Hospital admission - emergency, COPD related",
|
| 239 |
+
"GP review - emergency, COPD related",
|
| 240 |
+
"Emergency department attendance, COPD related",
|
| 241 |
+
"Exacerbation - started abs/steroid by clinical team",
|
| 242 |
+
]
|
| 243 |
+
return np.where(events.isin(exacerbation_event_names), 1, 0)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def fill_column_by_patient(*, df, id_col, col):
|
| 247 |
+
"""
|
| 248 |
+
Forward and back fill data by patient to fill gaps, e.g. from merges.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
df (pd.DataFrame): patient data. Must contain col and id_col columns.
|
| 252 |
+
id_col (str): name of column containing unique patient identifiers.
|
| 253 |
+
col (str): name of column to be filled.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
pd.DataFrame: input data with col infilled.
|
| 257 |
+
"""
|
| 258 |
+
groupby_df = df.groupby(id_col)[col].apply(lambda x: x)
|
| 259 |
+
groupby_df = groupby_df.reset_index(level=0)
|
| 260 |
+
groupby_df = groupby_df.apply(lambda x: x.ffill().bfill())
|
| 261 |
+
df[col] = groupby_df[col]
|
| 262 |
+
# df[col] = df.groupby(id_col)[col].apply(lambda x: x.ffill().bfill())
|
| 263 |
+
return df
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def filter_symptom_diary(*, df, patients, date_cutoff=None):
|
| 267 |
+
"""
|
| 268 |
+
Filter COPD symptom diary data for patients and dates of interest.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
df (pd.DataFrame): symptom diary data. Must contain 'SubmissionTime' and
|
| 272 |
+
'PatientId' columns.
|
| 273 |
+
patients (list): patient IDs of interest.
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
pd.DataFrame: filtered symptom diary.
|
| 277 |
+
"""
|
| 278 |
+
df["SubmissionTime"] = pd.to_datetime(df.SubmissionTime, utc=True).dt.normalize()
|
| 279 |
+
# Take only data from after the cutoff if provided (e.g. weekly Q5 change)
|
| 280 |
+
if date_cutoff:
|
| 281 |
+
df = df[df.SubmissionTime >= date_cutoff]
|
| 282 |
+
# Filter for patients of interest
|
| 283 |
+
df = df[df.PatientId.isin(patients)]
|
| 284 |
+
return df
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def get_logic_exacerbation_indices(df, minimum_period=14, maximum_period=35):
|
| 288 |
+
"""
|
| 289 |
+
Return dataframe indices of exacs that need checking for PRO reponses since last exac.
|
| 290 |
+
|
| 291 |
+
Get the indices of exacerbations that occur long enough after the previous event to
|
| 292 |
+
not be removed by PRO LOGIC criterion 1 (e.g. within 14 days of previous exac) but
|
| 293 |
+
not long enough after to be counted as a separate event without further analysis.
|
| 294 |
+
Called by apply_logic_response_criterion.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
df (pd.DataFrame): must contain IsExac and DaysSinceLastExac columns.
|
| 298 |
+
minimum_period (int): minimum number of days since the previous exac (any exacs
|
| 299 |
+
within this window will already be removed with PRO LOGIC criterion 1).
|
| 300 |
+
Default value is 14 days.
|
| 301 |
+
maximum_period (int): maximum number of days since the previous exac (any exacs
|
| 302 |
+
occurring after this period will automatically count as a separate event).
|
| 303 |
+
Default is 35 days.
|
| 304 |
+
|
| 305 |
+
Returns:
|
| 306 |
+
list: dataframe indices of relevant events.
|
| 307 |
+
"""
|
| 308 |
+
# Get the dataframe indices for all exacerbations occurring within period of interest
|
| 309 |
+
indices = df[
|
| 310 |
+
(df.IsExac.eq(1))
|
| 311 |
+
& (df.DaysSinceLastExac > minimum_period)
|
| 312 |
+
& (df.DaysSinceLastExac <= maximum_period)
|
| 313 |
+
].index.to_list()
|
| 314 |
+
return indices
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def get_rescue_med_pro_responses(df):
|
| 318 |
+
"""Extract all responses to weekly PRO Q5 (rescue meds).
|
| 319 |
+
|
| 320 |
+
Add new boolean columns stating if Q5 was answered, whether it was a negative response
|
| 321 |
+
(no rescue meds taken in previous week) and whether the reply means a community
|
| 322 |
+
exacerbation. The latter two columns will be opposites.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
df (pd.DataFrame): PRO symptom diary responses.
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
pd.DataFrame: filtered weekly PROs with additional boolean columns Q5Answered,
|
| 329 |
+
NegativeQ5 and IsCommExac.
|
| 330 |
+
|
| 331 |
+
"""
|
| 332 |
+
# Extract responses to weekly PRO Q5 (rescue meds)
|
| 333 |
+
df = df[df.SymptomDiaryQ5.notna()].copy()
|
| 334 |
+
df["SymptomDiaryQ5"] = df["SymptomDiaryQ5"].astype("int64")
|
| 335 |
+
# Columns for whether Q5 was answered and if the response was negative (no exac)
|
| 336 |
+
df["Q5Answered"] = 1
|
| 337 |
+
df["NegativeQ5"] = np.int64(np.where(df.SymptomDiaryQ5 == 0, 1, 0))
|
| 338 |
+
# Define community exacerbation as a positive reply to Q5
|
| 339 |
+
df["IsCommExac"] = np.int64(np.where(df.SymptomDiaryQ5 == 1, 1, 0))
|
| 340 |
+
return df
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def logic_consecutive_negative_responses(df, i, N=2):
|
| 344 |
+
"""
|
| 345 |
+
Calculate number of consecutive -ve Q5 replies since previous exac (PRO LOGIC).
|
| 346 |
+
|
| 347 |
+
Given the dataframe index of the current exac identified as falling under the Q5
|
| 348 |
+
criterion, calculate the number of negative replies to the weeky rescue med question
|
| 349 |
+
and check if there are enough for the event to count as distinct from the previous.
|
| 350 |
+
Called by apply_logic_response_criterion.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
df (pd.DataFrame): must contain weekly PRO replies and output from
|
| 354 |
+
get_rescue_med_pro_responses, set_pro_exac_dates and
|
| 355 |
+
calculate_days_since_exacerbation.
|
| 356 |
+
i (int): index of exac of interest.
|
| 357 |
+
N (int): number fo consecutive negative rescue meds required for event to be
|
| 358 |
+
counted as a separate event and retained in data. Default is 2.
|
| 359 |
+
|
| 360 |
+
Returns:
|
| 361 |
+
int: flag for whether the exac failed the criterion. Returns 1 for failed (exac to
|
| 362 |
+
be removed) and 0 for passed (exac to be retained).
|
| 363 |
+
|
| 364 |
+
"""
|
| 365 |
+
# Select data since the previous exacerbation
|
| 366 |
+
days = int(df.iloc[i].DaysSinceLastExac)
|
| 367 |
+
data = df.iloc[i - days + 1 : i]
|
| 368 |
+
|
| 369 |
+
# Select replies to Q5
|
| 370 |
+
data = data[data.Q5Answered.eq(1)][
|
| 371 |
+
["PatientId", "DateOfEvent", "Q5Answered", "NegativeQ5"]
|
| 372 |
+
]
|
| 373 |
+
# Check if there are sufficient responses
|
| 374 |
+
if len(data) < N:
|
| 375 |
+
return 1
|
| 376 |
+
else:
|
| 377 |
+
# Resample to 7 days (weekly) to account for missing responses. Resampling using
|
| 378 |
+
# the 'W' option can give spurious nans - use '7D' instead
|
| 379 |
+
data = (
|
| 380 |
+
data.set_index("DateOfEvent")
|
| 381 |
+
.resample("7D", origin="start")
|
| 382 |
+
.sum()
|
| 383 |
+
.reset_index()
|
| 384 |
+
)
|
| 385 |
+
# Calculate number of consecutive negative replies to Q5 (no rescue meds taken)
|
| 386 |
+
consecutive_negative_responses = (
|
| 387 |
+
data[data.NegativeQ5.eq(1)]["NegativeQ5"]
|
| 388 |
+
.groupby(data.NegativeQ5.eq(0).cumsum())
|
| 389 |
+
.sum()
|
| 390 |
+
.reset_index(drop=True)
|
| 391 |
+
.max()
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
return 1 if consecutive_negative_responses < N else 0
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def minimum_period_between_exacerbations(df, minimum_days=14):
|
| 398 |
+
"""
|
| 399 |
+
Identify exacs occurring too soon after the previous exac based on DaysSinceLastExac.
|
| 400 |
+
|
| 401 |
+
Returns 1 if the exacerbation occurred within minimum_days of that patient's previous
|
| 402 |
+
exacerbation and 0 if not.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
df (pd.DataFrame): must contain DaysSinceLastExac column.
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
array: contains 1 or 0.
|
| 409 |
+
"""
|
| 410 |
+
return np.where(
|
| 411 |
+
(df["DaysSinceLastExac"] > 0) & (df["DaysSinceLastExac"] <= minimum_days), 1, 0
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def remove_data_between_exacerbations(df):
|
| 416 |
+
"""
|
| 417 |
+
Remove data between first exac and subsequent exacs that failed PRO LOGIC criterion 2.
|
| 418 |
+
|
| 419 |
+
Ensures only the first in a series of related events are counted. Any subsequent exacs
|
| 420 |
+
that occurred too close to the initial event without sufficient negative weekly PRO
|
| 421 |
+
responses in the interim will be flagged for removal. This function removes flags for
|
| 422 |
+
removal all data from the day after the first event up to the date of events to be
|
| 423 |
+
removed. Data following the final event in the series will be removed by
|
| 424 |
+
minimum_period_between_exacerbations.
|
| 425 |
+
|
| 426 |
+
Args:
|
| 427 |
+
df (pd.DataFrame): must contain RemoveExac and DaysSinceLastExac columns.
|
| 428 |
+
|
| 429 |
+
Returns:
|
| 430 |
+
pd.DataFrame: days between first event and subsequent event(s) that failed the Q5
|
| 431 |
+
criterion are now flagged for removal in RemoveRow.
|
| 432 |
+
|
| 433 |
+
"""
|
| 434 |
+
indices = df[df.RemoveExac.eq(1)].index.to_list()
|
| 435 |
+
# Check there are exacerbations that failed the logic criterion for N consecutive
|
| 436 |
+
# negative reponses to Q5 of weekly PROs (rescue meds)
|
| 437 |
+
if len(indices) > 0:
|
| 438 |
+
for exac_index in indices:
|
| 439 |
+
# Select data since the previous exacerbation
|
| 440 |
+
days = int(df.iloc[exac_index].DaysSinceLastExac)
|
| 441 |
+
# Set data since last exac up to and including current exac to be removed
|
| 442 |
+
df.loc[exac_index - days + 1 : exac_index, "RemoveRow"] = 1
|
| 443 |
+
return df
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def remove_unknown_date_exacerbations(df, days_to_remove=7):
|
| 447 |
+
"""
|
| 448 |
+
Remove data prior to and including an exacerbation whose date is unknown.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
df (pd.DataFrame): one row per day per patient for full data window. Must include
|
| 452 |
+
ExacDateUnknown column.
|
| 453 |
+
days_to_remove (int): number of days of data to remove leading up to (and
|
| 454 |
+
including) the PRO response date. Default is 7 days.
|
| 455 |
+
|
| 456 |
+
Returns:
|
| 457 |
+
pd.DataFrame: input dataframe with updated RemoveRow column.
|
| 458 |
+
|
| 459 |
+
"""
|
| 460 |
+
# Get indices of all exacs whose dates are flagged as unknown.
|
| 461 |
+
indices = df[df.ExacDateUnknown.eq(1)].index.to_list()
|
| 462 |
+
# Check there are exacerbations with unknown dates (answer=1 in SymptomDiaryQ11a)
|
| 463 |
+
if len(indices) > 0:
|
| 464 |
+
for exac_index in indices:
|
| 465 |
+
# Set specified number of previous days data up to and including current exac
|
| 466 |
+
# to be removed
|
| 467 |
+
df.loc[exac_index - days_to_remove + 1 : exac_index, "RemoveRow"] = 1
|
| 468 |
+
return df
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def set_pro_exac_dates(df):
|
| 472 |
+
"""
|
| 473 |
+
Set date of community exacerbations reported in weekly PROs Q5 and flag unknown dates.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
df (pd.DataFrame: processed weekly PROs Q5 respnses, e.g. output of
|
| 477 |
+
get_rescue_med_pro_responses
|
| 478 |
+
|
| 479 |
+
Returns:
|
| 480 |
+
pd.DataFrame: input dataframe with additional columns for DateOfEvent (datetime)
|
| 481 |
+
and ExacDateUnknown (0 or 1).
|
| 482 |
+
"""
|
| 483 |
+
# Take known exacerbation (rescue med) dates from SymptomDiaryQ11b, otherwise set the
|
| 484 |
+
# date to the date of PRO response
|
| 485 |
+
df["DateOfEvent"] = np.where(
|
| 486 |
+
df.SymptomDiaryQ11a == 2, df.SymptomDiaryQ11b, df.SubmissionTime
|
| 487 |
+
)
|
| 488 |
+
# Flag which dates were unknown from the PRO response
|
| 489 |
+
df["ExacDateUnknown"] = np.int64(
|
| 490 |
+
np.where((df.IsCommExac == 1) & (df.SymptomDiaryQ11a != 2), 1, 0)
|
| 491 |
+
)
|
| 492 |
+
df["DateOfEvent"] = pd.to_datetime(df.DateOfEvent, utc=True).dt.normalize()
|
| 493 |
+
df = df.drop_duplicates(keep="last", subset=["PatientId", "DateOfEvent"])
|
| 494 |
+
return df
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
##############################################################
|
| 498 |
+
# Functions for generating features
|
| 499 |
+
##############################################################
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def weigh_features_by_recency(
|
| 503 |
+
*, df, feature, feature_recency_days, median_value, decay_rate=0.01
|
| 504 |
+
):
|
| 505 |
+
"""Gives more weight to more recent observations.
|
| 506 |
+
|
| 507 |
+
More weight is given to more recent observations. The older the observation, the value
|
| 508 |
+
will be scaled more towards the median. This is because abnormal values observed in the
|
| 509 |
+
past may not be reflective of current status for the patient.
|
| 510 |
+
|
| 511 |
+
Parameters
|
| 512 |
+
----------
|
| 513 |
+
df : dataframe
|
| 514 |
+
input dataframe containing data to be scaled.
|
| 515 |
+
feature : str
|
| 516 |
+
name of feature to scale.
|
| 517 |
+
feature_recency_days : int
|
| 518 |
+
number of days prior to index date when the feature was observed.
|
| 519 |
+
median_value : int
|
| 520 |
+
median value for feature across all patients.
|
| 521 |
+
decay_rate : float, optional
|
| 522 |
+
the rate at which the observation is scaled, by default 0.01. Higher decay rate
|
| 523 |
+
leads to more extreme scaling towards the median.
|
| 524 |
+
|
| 525 |
+
Returns
|
| 526 |
+
-------
|
| 527 |
+
df : dataframe
|
| 528 |
+
input dataframe with weighted columns.
|
| 529 |
+
"""
|
| 530 |
+
df["Weights"] = np.exp(-decay_rate * df[feature_recency_days])
|
| 531 |
+
df["LowerThanMedian"] = np.where((df[feature] - median_value) < 0, 1, 0)
|
| 532 |
+
|
| 533 |
+
df[feature + "Weighted"] = np.where(
|
| 534 |
+
df["LowerThanMedian"] == 1,
|
| 535 |
+
df[feature] / df["Weights"],
|
| 536 |
+
df[feature] * df["Weights"],
|
| 537 |
+
)
|
| 538 |
+
# df[feature + 'Weighted'] = np.where(((df['LowerThanMedian'] == 1) & (df[feature + 'Weighted'] > median_value)), np.NaN, df[feature + 'Weighted'])
|
| 539 |
+
# df[feature + 'Weighted'] = np.where(((df['LowerThanMedian'] == 0) & (df[feature + 'Weighted'] < median_value)), np.NaN, df[feature + 'Weighted'])
|
| 540 |
+
df = df.drop(columns=[feature_recency_days, "Weights", "LowerThanMedian"])
|
| 541 |
+
return df
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
##############################################################
|
| 545 |
+
# Functions for modelling
|
| 546 |
+
##############################################################
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def load_data_for_modelling(data_path):
|
| 550 |
+
"""Loading and sorting data by StudyId.
|
| 551 |
+
|
| 552 |
+
Args:
|
| 553 |
+
data_path (str): filepath of data.
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
df: df containing features and target.
|
| 557 |
+
|
| 558 |
+
"""
|
| 559 |
+
data = pd.read_pickle(data_path)
|
| 560 |
+
data = data.sort_values(by=["StudyId", "IndexDate"])
|
| 561 |
+
data["IndexDate"] = pd.to_datetime(data.IndexDate, utc=True)
|
| 562 |
+
data = data.reset_index(drop=True)
|
| 563 |
+
return data
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def get_mlflow_run_params(model_name, run_id, mlflow_db, model_type, parent_run=True):
|
| 567 |
+
"""Searches the mlflow database runs to return the parameters used in the run(s).
|
| 568 |
+
|
| 569 |
+
Searches the database containing previously logged mlflow runs to find the specified
|
| 570 |
+
run. If parent_run is True, run_id is the run id of the parent run that contains
|
| 571 |
+
multiple child runs and will return the parameters of all the runs recorded under the
|
| 572 |
+
parent run id. If parent_run is False, the run_id is the run id of a specific run, only
|
| 573 |
+
the parameters of that specific run will be returned.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
model_name (str): name of model.
|
| 577 |
+
run_id (str): run id value of the run(s) that are to be searched.
|
| 578 |
+
mlflow_db (str): database where the mlflow runs are recorded.
|
| 579 |
+
model_type (str): contains information on type of model being run.
|
| 580 |
+
parent_run (bool, optional): specifies whether to query all child runs in the parent
|
| 581 |
+
run specified or a specific child run. Defaults to True.
|
| 582 |
+
|
| 583 |
+
Returns:
|
| 584 |
+
dict: contains parameters used in a specific run or multiple runs depending on
|
| 585 |
+
whether parent_run is True.
|
| 586 |
+
|
| 587 |
+
"""
|
| 588 |
+
# Set the tracking uri to the database where runs are recorded
|
| 589 |
+
mlflow.set_tracking_uri(mlflow_db)
|
| 590 |
+
|
| 591 |
+
# Get the run_ids of the best models from hyperparameter tuning
|
| 592 |
+
if parent_run is True:
|
| 593 |
+
param_tuning_ids = mlflow.search_runs(
|
| 594 |
+
experiment_names=["model_h_drop_1_" + model_type],
|
| 595 |
+
filter_string="tags.mlflow.parentRunId = '"
|
| 596 |
+
+ run_id
|
| 597 |
+
+ "' and run_name = '"
|
| 598 |
+
+ model_name
|
| 599 |
+
+ "'",
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
# Get parameters for the best models from hyperparameter tuning
|
| 603 |
+
best_params = {}
|
| 604 |
+
if parent_run is True:
|
| 605 |
+
for index, row in param_tuning_ids.iterrows():
|
| 606 |
+
best_param = mlflow.get_run(row["run_id"]).data.params
|
| 607 |
+
else:
|
| 608 |
+
best_param = mlflow.get_run(run_id).data.params
|
| 609 |
+
|
| 610 |
+
# Values are all given as strings. Convert the values into the correct format
|
| 611 |
+
for key in best_param:
|
| 612 |
+
try:
|
| 613 |
+
best_param[key] = int(best_param[key])
|
| 614 |
+
except:
|
| 615 |
+
try:
|
| 616 |
+
best_param[key] = float(best_param[key])
|
| 617 |
+
except:
|
| 618 |
+
best_param[key] = best_param[key]
|
| 619 |
+
if best_param[key] == "None":
|
| 620 |
+
best_param[key] = None
|
| 621 |
+
if best_param[key] == "True":
|
| 622 |
+
best_param[key] = True
|
| 623 |
+
if best_param[key] == "False":
|
| 624 |
+
best_param[key] = False
|
| 625 |
+
|
| 626 |
+
# Return a dictionary with multiple runs if parent_run is True
|
| 627 |
+
if parent_run is True:
|
| 628 |
+
best_params[best_param["opt_scorer"]] = best_param
|
| 629 |
+
for key in best_params:
|
| 630 |
+
best_params[key].pop("opt_scorer")
|
| 631 |
+
# Return a dictionary with a single run if parent_run is False
|
| 632 |
+
else:
|
| 633 |
+
best_params = best_param
|
| 634 |
+
return best_params
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
##############################################################
|
| 638 |
+
# Functions for calculating metrics
|
| 639 |
+
##############################################################
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def plot_confusion_matrix(
|
| 643 |
+
thresholds,
|
| 644 |
+
probs,
|
| 645 |
+
target,
|
| 646 |
+
model_name,
|
| 647 |
+
file_suffix,
|
| 648 |
+
calibration_type,
|
| 649 |
+
split_by_event=False,
|
| 650 |
+
event_type="hosp_comm",
|
| 651 |
+
):
|
| 652 |
+
"""Plot confusion matrices for calibrated models at different thresholds.
|
| 653 |
+
|
| 654 |
+
Args:
|
| 655 |
+
thresholds (list): list of thresholds to plot the confusion matrices at.
|
| 656 |
+
probs (array): probability estimates for positive class for the test set.
|
| 657 |
+
target (array): true values.
|
| 658 |
+
model_name (str): name of model.
|
| 659 |
+
file_suffix (str): type of model run.
|
| 660 |
+
calibration_type (str): type of calibration.
|
| 661 |
+
split_by_event (bool, optional): determines whether to plot confusion matrix by event
|
| 662 |
+
type (hospital vs community events). Defaults to False.
|
| 663 |
+
event_type (str, optional): specifies the event type(s) that the confusion matrices
|
| 664 |
+
are being plotted for. Defaults to 'hosp_comm'.
|
| 665 |
+
|
| 666 |
+
Returns:
|
| 667 |
+
None.
|
| 668 |
+
|
| 669 |
+
"""
|
| 670 |
+
# Create folders to contain confusion matrices for each calibration type
|
| 671 |
+
os.makedirs("./tmp/cm_" + calibration_type, exist_ok=True)
|
| 672 |
+
|
| 673 |
+
for threshold in thresholds:
|
| 674 |
+
y_predicted = probs > threshold
|
| 675 |
+
cm = confusion_matrix(target, y_predicted)
|
| 676 |
+
group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]
|
| 677 |
+
group_counts = ["{0:0.0f}".format(value) for value in cm.flatten()]
|
| 678 |
+
group_percentages = [
|
| 679 |
+
"{0:.2%}".format(value) for value in cm.flatten() / np.sum(cm)
|
| 680 |
+
]
|
| 681 |
+
labels = [
|
| 682 |
+
f"{v1}\n{v2}\n{v3}"
|
| 683 |
+
for v1, v2, v3 in zip(group_names, group_counts, group_percentages)
|
| 684 |
+
]
|
| 685 |
+
labels = np.asarray(labels).reshape(2, 2)
|
| 686 |
+
sns.heatmap(cm, annot=labels, fmt="", cmap="Blues")
|
| 687 |
+
if split_by_event is False:
|
| 688 |
+
output_filename = (
|
| 689 |
+
"./tmp/cm_"
|
| 690 |
+
+ calibration_type
|
| 691 |
+
+ "/"
|
| 692 |
+
+ model_name
|
| 693 |
+
+ "_cm_"
|
| 694 |
+
+ calibration_type
|
| 695 |
+
+ "_thres_"
|
| 696 |
+
+ str(threshold)
|
| 697 |
+
+ "_"
|
| 698 |
+
+ file_suffix
|
| 699 |
+
+ ".png"
|
| 700 |
+
)
|
| 701 |
+
else:
|
| 702 |
+
output_filename = (
|
| 703 |
+
"./tmp/cm_"
|
| 704 |
+
+ calibration_type
|
| 705 |
+
+ "/"
|
| 706 |
+
+ model_name
|
| 707 |
+
+ "_cm_"
|
| 708 |
+
+ calibration_type
|
| 709 |
+
+ "_thres_"
|
| 710 |
+
+ str(threshold)
|
| 711 |
+
+ event_type
|
| 712 |
+
+ ".png"
|
| 713 |
+
)
|
| 714 |
+
plt.savefig(output_filename)
|
| 715 |
+
plt.close()
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def calc_best_f1_score(y_true, y_probs, best_threshold=None):
|
| 719 |
+
"""Finds the best f1 score and the respective threshold, precision, and recall.
|
| 720 |
+
|
| 721 |
+
Args:
|
| 722 |
+
y_true (array): ground truth target values.
|
| 723 |
+
y_probs (array): probabilites of the positive class for the target.
|
| 724 |
+
best_threshold (bool, optional): ability to provide a threshold used previously.
|
| 725 |
+
Defaults to None.
|
| 726 |
+
|
| 727 |
+
Returns:
|
| 728 |
+
best_threshold: int.
|
| 729 |
+
f1_best_thres: int.
|
| 730 |
+
precision_best_thres: int.
|
| 731 |
+
recall_best_thres: int.
|
| 732 |
+
|
| 733 |
+
"""
|
| 734 |
+
# Get threshold that gives best F1 score
|
| 735 |
+
precision, recall, thresholds = precision_recall_curve(y_true, y_probs)
|
| 736 |
+
fscore = (2 * precision * recall) / (precision + recall)
|
| 737 |
+
|
| 738 |
+
if best_threshold is None:
|
| 739 |
+
# When getting the max fscore, if fscore is nan, nan will be returned as the max.
|
| 740 |
+
# Iterate until nan not returned.
|
| 741 |
+
fscore_zero = True
|
| 742 |
+
position = -1
|
| 743 |
+
while fscore_zero is True:
|
| 744 |
+
best_thres_idx = np.argsort(fscore, axis=0)[position]
|
| 745 |
+
if np.isnan(fscore[best_thres_idx]) == True:
|
| 746 |
+
position = position - 1
|
| 747 |
+
else:
|
| 748 |
+
fscore_zero = False
|
| 749 |
+
else:
|
| 750 |
+
# Find the id of the threshold that is most similar to the threshold provided.
|
| 751 |
+
best_thres_idx = (np.abs(thresholds - best_threshold)).argmin()
|
| 752 |
+
|
| 753 |
+
# Get the scores at the threshold that gives the best f1 score
|
| 754 |
+
best_threshold = thresholds[best_thres_idx]
|
| 755 |
+
f1_best_thres = fscore[best_thres_idx]
|
| 756 |
+
precision_best_thres = precision[best_thres_idx]
|
| 757 |
+
recall_best_thres = recall[best_thres_idx]
|
| 758 |
+
|
| 759 |
+
print(
|
| 760 |
+
"Best Threshold=%f, F-Score=%.3f, Precision=%.3f, Recall=%.3f"
|
| 761 |
+
% (best_threshold, f1_best_thres, precision_best_thres, recall_best_thres)
|
| 762 |
+
)
|
| 763 |
+
return best_threshold, f1_best_thres, precision_best_thres, recall_best_thres
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
def get_threshold_with_best_f1_score(target, probs):
|
| 767 |
+
"""Calculate threshold at which the best F1 score is obtained.
|
| 768 |
+
|
| 769 |
+
Args:
|
| 770 |
+
target (array): true values for target.
|
| 771 |
+
probs (array): probability estimates for positive class.
|
| 772 |
+
|
| 773 |
+
Returns:
|
| 774 |
+
best_threshold: int.
|
| 775 |
+
f1_best_thres: int.
|
| 776 |
+
precision_best_thres: int.
|
| 777 |
+
recall_best_thres: int.
|
| 778 |
+
|
| 779 |
+
"""
|
| 780 |
+
# Calculate precision, recall and f1 for all thresholds
|
| 781 |
+
precision, recall, thresholds = precision_recall_curve(target, probs)
|
| 782 |
+
f1score = (2 * precision * recall) / (precision + recall)
|
| 783 |
+
|
| 784 |
+
# When getting the max f1score, if f1score is nan, nan will be returned as the
|
| 785 |
+
# max. Iterate until nan not returned.
|
| 786 |
+
f1score_zero = True
|
| 787 |
+
position = -1
|
| 788 |
+
while f1score_zero is True:
|
| 789 |
+
best_thres_idx = np.argsort(f1score, axis=0)[position]
|
| 790 |
+
if np.isnan(f1score[best_thres_idx]) == True:
|
| 791 |
+
position = position - 1
|
| 792 |
+
else:
|
| 793 |
+
f1score_zero = False
|
| 794 |
+
best_threshold = thresholds[best_thres_idx]
|
| 795 |
+
f1_best_thres = f1score[best_thres_idx]
|
| 796 |
+
precision_best_thres = precision[best_thres_idx]
|
| 797 |
+
recall_best_thres = recall[best_thres_idx]
|
| 798 |
+
|
| 799 |
+
print(
|
| 800 |
+
"Best Threshold=%f, F-Score=%.3f, Precision=%.3f, Recall=%.3f"
|
| 801 |
+
% (best_threshold, f1_best_thres, precision_best_thres, recall_best_thres)
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
return best_threshold, f1_best_thres, precision_best_thres, recall_best_thres
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def calc_eval_metrics_for_model(
|
| 808 |
+
y_true, y_pred, y_probs, calib_type, best_threshold=None
|
| 809 |
+
):
|
| 810 |
+
"""Calculates evaluation metrics for models.
|
| 811 |
+
|
| 812 |
+
Calculates the evaluation metrics for uncalibrated and calibrated models. Additionally,
|
| 813 |
+
also calculates metrics for specific event types (e.g. only hospital events, only
|
| 814 |
+
community events). Metrics for specific events types are calculated based on the
|
| 815 |
+
best_threshold provided to allow direct comparison between the performance of the general
|
| 816 |
+
model and the performance broken down by event type.
|
| 817 |
+
|
| 818 |
+
Calculates the following metrics:
|
| 819 |
+
- f1 score at a threshold of 0.5.
|
| 820 |
+
- precision at a threshold of 0.5.
|
| 821 |
+
- recall at a threshold of 0.5.
|
| 822 |
+
- highest f1 score across thresholds.
|
| 823 |
+
- precision at threshold that has highest f1 score.
|
| 824 |
+
- recall at threshold that has highest f1 score.
|
| 825 |
+
- auc-pr.
|
| 826 |
+
- average precision.
|
| 827 |
+
- roc-auc.
|
| 828 |
+
- negative brier score.
|
| 829 |
+
|
| 830 |
+
Args:
|
| 831 |
+
y_true (array): ground truth target values.
|
| 832 |
+
y_pred (array): estimated targets as returned by classifier.
|
| 833 |
+
y_probs (array): probabilites of the positive class for the target.
|
| 834 |
+
calib_type (str): type of calibrated model to calculate metrics.
|
| 835 |
+
best_threshold (bool, optional): ability to provide a threshold value at which to
|
| 836 |
+
calculate metrics.
|
| 837 |
+
|
| 838 |
+
Returns:
|
| 839 |
+
dict: contains metrics for calibrated model.
|
| 840 |
+
|
| 841 |
+
"""
|
| 842 |
+
# Calculate precision, recall and f1 score at a threshold of 0.5
|
| 843 |
+
precision = precision_score(y_true, y_pred)
|
| 844 |
+
recall = recall_score(y_true, y_pred)
|
| 845 |
+
f1 = f1_score(y_true, y_pred)
|
| 846 |
+
|
| 847 |
+
# Calculate f1 score at best threshold
|
| 848 |
+
if best_threshold is None:
|
| 849 |
+
(
|
| 850 |
+
best_thres,
|
| 851 |
+
f1_best_thres,
|
| 852 |
+
prec_best_thres,
|
| 853 |
+
recall_best_thres,
|
| 854 |
+
) = calc_best_f1_score(y_true, y_probs)
|
| 855 |
+
else:
|
| 856 |
+
(
|
| 857 |
+
best_thres,
|
| 858 |
+
f1_best_thres,
|
| 859 |
+
prec_best_thres,
|
| 860 |
+
recall_best_thres,
|
| 861 |
+
) = calc_best_f1_score(y_true, y_probs, best_threshold=best_threshold)
|
| 862 |
+
|
| 863 |
+
# Calculate area under curves
|
| 864 |
+
precision_, recall_, thresholds_ = precision_recall_curve(y_true, y_probs)
|
| 865 |
+
fscore_ = (2 * precision_ * recall_) / (precision_ + recall_)
|
| 866 |
+
auc_pr = auc(recall_, precision_)
|
| 867 |
+
average_precision = average_precision_score(y_true, y_probs)
|
| 868 |
+
roc_auc = roc_auc_score(y_true, y_probs)
|
| 869 |
+
|
| 870 |
+
# Calculate negative brier loss
|
| 871 |
+
neg_brier_score = -abs(brier_score_loss(y_true, y_probs))
|
| 872 |
+
|
| 873 |
+
# Create dict with metrics
|
| 874 |
+
calib_metrics = {
|
| 875 |
+
"precision_" + calib_type: precision,
|
| 876 |
+
"recall_" + calib_type: recall,
|
| 877 |
+
"f1_" + calib_type: f1,
|
| 878 |
+
"precision_best_thres_" + calib_type: prec_best_thres,
|
| 879 |
+
"recall_best_thres_" + calib_type: recall_best_thres,
|
| 880 |
+
"f1_best_thres_" + calib_type: f1_best_thres,
|
| 881 |
+
"auc_pr_" + calib_type: auc_pr,
|
| 882 |
+
"average_precision_" + calib_type: average_precision,
|
| 883 |
+
"roc_auc_" + calib_type: roc_auc,
|
| 884 |
+
"neg_brier_score_" + calib_type: neg_brier_score,
|
| 885 |
+
# "best_thres_" + calib_type: best_thres,
|
| 886 |
+
}
|
| 887 |
+
return calib_metrics
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def create_df_probabilities_and_predictions(
|
| 891 |
+
probs,
|
| 892 |
+
best_threshold,
|
| 893 |
+
patient_id,
|
| 894 |
+
target,
|
| 895 |
+
hosp_comm_exac_df,
|
| 896 |
+
model_name,
|
| 897 |
+
file_suffix,
|
| 898 |
+
output_dir,
|
| 899 |
+
calib_type="uncalib",
|
| 900 |
+
):
|
| 901 |
+
"""Creates dataframe that allows plotting of shap local plots.
|
| 902 |
+
|
| 903 |
+
Creates a dataframe that contains patient identifier, model prediction probability,
|
| 904 |
+
threshold, model prediction, ground truth and an explanation column that describes
|
| 905 |
+
whether the model prediction was correct, and whether an exacerbation occurred. The
|
| 906 |
+
dataframe is saved in the output directory specified.
|
| 907 |
+
|
| 908 |
+
Args:
|
| 909 |
+
probs (array): model output probabilities for class 1 (exacerbation).
|
| 910 |
+
best_threshold (float): threshold where f1 score is highest, at which classification
|
| 911 |
+
is performed.
|
| 912 |
+
patient_id (list): list of patient ids in the same order as probs.
|
| 913 |
+
target (array): ground truth label.
|
| 914 |
+
hosp_comm_exac_df (pd.Dataframe): contains the event data broken down by event type.
|
| 915 |
+
Possible event types are hospital or community.
|
| 916 |
+
model_name (str): name of model.
|
| 917 |
+
file_suffix (str): type of model.
|
| 918 |
+
output_dir (str): output directory where df is saved.
|
| 919 |
+
calib_type (str, optional): type of calibration performed. Defaults to 'uncalib'.
|
| 920 |
+
|
| 921 |
+
Returns:
|
| 922 |
+
pd.DataFrame: contains probabilities, predictions and event types.
|
| 923 |
+
|
| 924 |
+
"""
|
| 925 |
+
predicted_best_thres = probs > best_threshold
|
| 926 |
+
predicted_best_thres = predicted_best_thres.astype(int)
|
| 927 |
+
probs_target = pd.DataFrame(
|
| 928 |
+
{
|
| 929 |
+
"StudyId": patient_id,
|
| 930 |
+
"Probs": probs,
|
| 931 |
+
"Threshold": best_threshold,
|
| 932 |
+
"Predicted": predicted_best_thres,
|
| 933 |
+
"Target": target,
|
| 934 |
+
"Explanation": np.NaN,
|
| 935 |
+
}
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
probs_target = probs_target.merge(
|
| 939 |
+
hosp_comm_exac_df, right_index=True, left_index=True, how="left"
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
probs_target["Explanation"] = np.where(
|
| 943 |
+
(probs_target["Predicted"] == probs_target["Target"])
|
| 944 |
+
& (probs_target["Predicted"] == 0),
|
| 945 |
+
"true negative",
|
| 946 |
+
probs_target["Explanation"],
|
| 947 |
+
)
|
| 948 |
+
probs_target["Explanation"] = np.where(
|
| 949 |
+
(probs_target["Predicted"] == probs_target["Target"])
|
| 950 |
+
& (probs_target["Predicted"] == 1),
|
| 951 |
+
"correct",
|
| 952 |
+
probs_target["Explanation"],
|
| 953 |
+
)
|
| 954 |
+
probs_target["Explanation"] = np.where(
|
| 955 |
+
(probs_target["Predicted"] != probs_target["Target"])
|
| 956 |
+
& (probs_target["Predicted"] == 0),
|
| 957 |
+
"missed",
|
| 958 |
+
probs_target["Explanation"],
|
| 959 |
+
)
|
| 960 |
+
probs_target["Explanation"] = np.where(
|
| 961 |
+
(probs_target["Predicted"] != probs_target["Target"])
|
| 962 |
+
& (probs_target["Predicted"] == 1),
|
| 963 |
+
"incorrect",
|
| 964 |
+
probs_target["Explanation"],
|
| 965 |
+
)
|
| 966 |
+
probs_target.to_csv(
|
| 967 |
+
os.path.join(
|
| 968 |
+
output_dir,
|
| 969 |
+
"preds_and_events_"
|
| 970 |
+
+ calib_type
|
| 971 |
+
+ "_"
|
| 972 |
+
+ model_name
|
| 973 |
+
+ "_"
|
| 974 |
+
+ file_suffix
|
| 975 |
+
+ ".csv",
|
| 976 |
+
)
|
| 977 |
+
)
|
| 978 |
+
return probs_target
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def calc_metrics_by_event_type(preds_event_df, calib_type):
|
| 982 |
+
"""Calculates performance metrics by event type (hospital or community).
|
| 983 |
+
|
| 984 |
+
Args:
|
| 985 |
+
preds_event_df (pd.Dataframe): contains values required for calculating metrics.
|
| 986 |
+
calib_type (str): type of calibration performed.
|
| 987 |
+
|
| 988 |
+
Returns:
|
| 989 |
+
dict: contains performance metrics for both community and hospital events.
|
| 990 |
+
|
| 991 |
+
"""
|
| 992 |
+
preds_events_comm = preds_event_df[preds_event_df["HospExacWithin3Months"] == 0]
|
| 993 |
+
metrics_comm = calc_eval_metrics_for_model(
|
| 994 |
+
preds_events_comm["Target"],
|
| 995 |
+
preds_events_comm["Predicted"],
|
| 996 |
+
preds_events_comm["Probs"],
|
| 997 |
+
calib_type,
|
| 998 |
+
best_threshold=preds_events_comm["Threshold"][0],
|
| 999 |
+
)
|
| 1000 |
+
metrics_comm = {f"{k}_comm": v for k, v in metrics_comm.items()}
|
| 1001 |
+
|
| 1002 |
+
preds_events_hosp = preds_event_df[preds_event_df["CommExacWithin3Months"] == 0]
|
| 1003 |
+
metrics_hosp = calc_eval_metrics_for_model(
|
| 1004 |
+
preds_events_hosp["Target"],
|
| 1005 |
+
preds_events_hosp["Predicted"],
|
| 1006 |
+
preds_events_hosp["Probs"],
|
| 1007 |
+
calib_type,
|
| 1008 |
+
best_threshold=preds_events_comm["Threshold"][0],
|
| 1009 |
+
)
|
| 1010 |
+
metrics_hosp = {f"{k}_hosp": v for k, v in metrics_hosp.items()}
|
| 1011 |
+
metrics_by_event_type = metrics_comm.copy()
|
| 1012 |
+
metrics_by_event_type.update(metrics_hosp)
|
| 1013 |
+
return metrics_by_event_type
|
| 1014 |
+
|
| 1015 |
+
|
| 1016 |
+
def plot_roc_curve_by_event_type(preds_event_df, model_name, calib_type):
|
| 1017 |
+
"""Plots ROC curve for multiple event types.
|
| 1018 |
+
|
| 1019 |
+
Args:
|
| 1020 |
+
preds_event_df (pd.Dataframe): contains values required for calculating metrics.
|
| 1021 |
+
model_name (str): name of model.
|
| 1022 |
+
calib_type (str): type of calibration.
|
| 1023 |
+
|
| 1024 |
+
Returns:
|
| 1025 |
+
None.
|
| 1026 |
+
"""
|
| 1027 |
+
os.makedirs("./tmp/performance", exist_ok=True)
|
| 1028 |
+
fpr, tpr, _ = roc_curve(preds_event_df["Target"], preds_event_df["Probs"])
|
| 1029 |
+
auc_roc = roc_auc_score(preds_event_df["Target"], preds_event_df["Probs"])
|
| 1030 |
+
plt.plot(fpr, tpr, label="Hosp+Comm exacs" + "\n AUC=" + str(round(auc_roc, 2)))
|
| 1031 |
+
mapping_dict = {
|
| 1032 |
+
"HospExacWithin3Months": "Hosp exacs",
|
| 1033 |
+
"CommExacWithin3Months": "Comm exacs",
|
| 1034 |
+
}
|
| 1035 |
+
for key in mapping_dict:
|
| 1036 |
+
preds_events_df_subset = preds_event_df[
|
| 1037 |
+
(preds_event_df[key] == 1) | (preds_event_df["ExacWithin3Months"] == 0)
|
| 1038 |
+
]
|
| 1039 |
+
fpr, tpr, _ = roc_curve(
|
| 1040 |
+
preds_events_df_subset["Target"], preds_events_df_subset["Probs"]
|
| 1041 |
+
)
|
| 1042 |
+
auc_roc = roc_auc_score(
|
| 1043 |
+
preds_events_df_subset["Target"], preds_events_df_subset["Probs"]
|
| 1044 |
+
)
|
| 1045 |
+
plt.plot(fpr, tpr, label=mapping_dict[key] + "\n AUC=" + str(round(auc_roc, 2)))
|
| 1046 |
+
plt.plot([0, 1], [0, 1], linestyle="--", color="black")
|
| 1047 |
+
plt.title("ROC-AUC Curve")
|
| 1048 |
+
plt.ylabel("True Positive Rate")
|
| 1049 |
+
plt.xlabel("False Positive Rate")
|
| 1050 |
+
plt.legend(loc="lower right")
|
| 1051 |
+
plt.savefig("./tmp/performance/" + model_name + "_roc_curve_" + calib_type + ".png")
|
| 1052 |
+
plt.close()
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
def plot_prec_recall_by_event_type(preds_event_df, model_name, calib_type):
|
| 1056 |
+
"""Plots a precision recall curve for multiple event types.
|
| 1057 |
+
|
| 1058 |
+
Args:
|
| 1059 |
+
preds_event_df (pd.Dataframe): contains values required for calculating metrics.
|
| 1060 |
+
model_name (str): name of model.
|
| 1061 |
+
calib_type (str): type of calibration.
|
| 1062 |
+
|
| 1063 |
+
Returns:
|
| 1064 |
+
None.
|
| 1065 |
+
"""
|
| 1066 |
+
os.makedirs("./tmp/performance", exist_ok=True)
|
| 1067 |
+
precision, recall, thresholds = precision_recall_curve(
|
| 1068 |
+
preds_event_df["Target"], preds_event_df["Probs"]
|
| 1069 |
+
)
|
| 1070 |
+
auc_pr = auc(recall, precision)
|
| 1071 |
+
plt.plot(
|
| 1072 |
+
recall,
|
| 1073 |
+
precision,
|
| 1074 |
+
label="Hosp+Comm exacs" + "\n AUC-PR=" + str(round(auc_pr, 2)),
|
| 1075 |
+
)
|
| 1076 |
+
mapping_dict = {
|
| 1077 |
+
"HospExacWithin3Months": "Hosp exacs",
|
| 1078 |
+
"CommExacWithin3Months": "Comm exacs",
|
| 1079 |
+
}
|
| 1080 |
+
for key in mapping_dict:
|
| 1081 |
+
preds_event_df_subset = preds_event_df[
|
| 1082 |
+
(preds_event_df[key] == 1) | (preds_event_df["ExacWithin3Months"] == 0)
|
| 1083 |
+
]
|
| 1084 |
+
precision, recall, thresholds = precision_recall_curve(
|
| 1085 |
+
preds_event_df_subset["Target"], preds_event_df_subset["Probs"]
|
| 1086 |
+
)
|
| 1087 |
+
auc_pr = auc(recall, precision)
|
| 1088 |
+
plt.plot(
|
| 1089 |
+
recall,
|
| 1090 |
+
precision,
|
| 1091 |
+
label=mapping_dict[key] + "\n AUC-PR=" + str(round(auc_pr, 2)),
|
| 1092 |
+
)
|
| 1093 |
+
plt.title("Precision-Recall Curve")
|
| 1094 |
+
plt.ylabel("Precision")
|
| 1095 |
+
plt.xlabel("Recall")
|
| 1096 |
+
plt.legend(loc="upper right")
|
| 1097 |
+
plt.savefig("./tmp/performance/" + model_name + "_pr_curve_" + calib_type + ".png")
|
| 1098 |
+
plt.close()
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
def plot_cm_by_event_type(preds_event_df, model_name, file_suffix, calib_type):
|
| 1102 |
+
"""Plots confusion matrices by event type (hospital or community).
|
| 1103 |
+
|
| 1104 |
+
Args:
|
| 1105 |
+
preds_event_df (pd.Dataframe): contains values required for plotting confusion matrices.
|
| 1106 |
+
model_name (str): name of model.
|
| 1107 |
+
file_suffix (str): type of model run.
|
| 1108 |
+
calib_type (str): type of calibration performed.
|
| 1109 |
+
|
| 1110 |
+
Returns:
|
| 1111 |
+
None.
|
| 1112 |
+
|
| 1113 |
+
"""
|
| 1114 |
+
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, preds_event_df["Threshold"][0]]
|
| 1115 |
+
# Community events
|
| 1116 |
+
preds_events_comm = preds_event_df[preds_event_df["HospExacWithin3Months"] == 0]
|
| 1117 |
+
plot_confusion_matrix(
|
| 1118 |
+
thresholds,
|
| 1119 |
+
preds_events_comm["Probs"],
|
| 1120 |
+
preds_events_comm["Target"],
|
| 1121 |
+
model_name,
|
| 1122 |
+
file_suffix,
|
| 1123 |
+
calib_type,
|
| 1124 |
+
split_by_event=True,
|
| 1125 |
+
event_type="comm",
|
| 1126 |
+
)
|
| 1127 |
+
|
| 1128 |
+
# Hospital events
|
| 1129 |
+
preds_events_hosp = preds_event_df[preds_event_df["CommExacWithin3Months"] == 0]
|
| 1130 |
+
plot_confusion_matrix(
|
| 1131 |
+
thresholds,
|
| 1132 |
+
preds_events_hosp["Probs"],
|
| 1133 |
+
preds_events_hosp["Target"],
|
| 1134 |
+
model_name,
|
| 1135 |
+
file_suffix,
|
| 1136 |
+
calib_type,
|
| 1137 |
+
split_by_event=True,
|
| 1138 |
+
event_type="hosp",
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
def reverse_scaling(data, imputation, file_suffix):
|
| 1143 |
+
"""Reverse scaling performed.
|
| 1144 |
+
|
| 1145 |
+
Args:
|
| 1146 |
+
data (pd.Dataframe): data after scaling performed.
|
| 1147 |
+
imputation (str): to identify whether imputation was performed for the model.
|
| 1148 |
+
file_suffix (str): type of model being run.
|
| 1149 |
+
|
| 1150 |
+
Returns:
|
| 1151 |
+
pd.Dataframe: input dataframe after scaling has been reversed.
|
| 1152 |
+
|
| 1153 |
+
"""
|
| 1154 |
+
if imputation == "imputed":
|
| 1155 |
+
scaler = joblib.load("./data/artifacts/scaler_imp" + file_suffix + ".pkl")
|
| 1156 |
+
else:
|
| 1157 |
+
scaler = joblib.load("./data/artifacts/scaler_no_imp" + file_suffix + ".pkl")
|
| 1158 |
+
data_scaling_reversed = scaler.inverse_transform(data)
|
| 1159 |
+
data_scaling_reversed = pd.DataFrame(
|
| 1160 |
+
data=data_scaling_reversed, columns=data.columns.tolist()
|
| 1161 |
+
)
|
| 1162 |
+
return data_scaling_reversed
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
def convert_target_encodings_into_groups(target_enc_path, test_features):
|
| 1166 |
+
"""Convert target encodings back into original category names.
|
| 1167 |
+
|
| 1168 |
+
Target encodings are converted back to their original category names to simplify
|
| 1169 |
+
interpretation using SHAP.
|
| 1170 |
+
|
| 1171 |
+
Args:
|
| 1172 |
+
target_enc_path (str): path where the target encodings are saved.
|
| 1173 |
+
test_features (pd.Dataframe): test values after scaling reversed.
|
| 1174 |
+
|
| 1175 |
+
Returns:
|
| 1176 |
+
pd.Dataframe: dataframe after converting target encodings into categories.
|
| 1177 |
+
|
| 1178 |
+
"""
|
| 1179 |
+
target_encodings = json.load(open(target_enc_path))
|
| 1180 |
+
# Add _te suffix as all columns that have been target encoded have the suffix.
|
| 1181 |
+
target_encodings = {key + "_te": val for key, val in target_encodings.items()}
|
| 1182 |
+
|
| 1183 |
+
# Invert the target encodings so that target encoded value maps to the categorical value.
|
| 1184 |
+
target_encodings_inv = {}
|
| 1185 |
+
for col_name in target_encodings:
|
| 1186 |
+
for category in target_encodings[col_name]:
|
| 1187 |
+
target_encodings[col_name][category] = round(
|
| 1188 |
+
target_encodings[col_name][category], 6
|
| 1189 |
+
)
|
| 1190 |
+
target_encodings_inv[col_name] = dict(
|
| 1191 |
+
(v, k) for k, v in target_encodings[col_name].items()
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
# Round categorical column values to 6 decimal places as in the target encoding dict.
|
| 1195 |
+
test_features_enc_conv = test_features.copy()
|
| 1196 |
+
target_encoded_cols = test_features_enc_conv.columns[
|
| 1197 |
+
test_features_enc_conv.columns.str.endswith("_te")
|
| 1198 |
+
]
|
| 1199 |
+
for col in target_encoded_cols:
|
| 1200 |
+
test_features_enc_conv[col] = round(test_features_enc_conv[col], 6)
|
| 1201 |
+
|
| 1202 |
+
# Replace the target encoding with the corresponding categorical value.
|
| 1203 |
+
test_features_enc_conv = test_features_enc_conv.replace(target_encodings_inv)
|
| 1204 |
+
return test_features_enc_conv
|
| 1205 |
+
|
| 1206 |
+
|
| 1207 |
+
def plot_score_distribution(
|
| 1208 |
+
target, probs, artifact_dir, model_name, file_suffix, calibration_type="uncal"
|
| 1209 |
+
):
|
| 1210 |
+
"""Plots distribution of probabilties for each class.
|
| 1211 |
+
|
| 1212 |
+
Args:
|
| 1213 |
+
target (array): true values for target.
|
| 1214 |
+
probs (array): probability estimates for positive class.
|
| 1215 |
+
artifact_dir (str): output directory where plot is saved.
|
| 1216 |
+
model_name (str): name of model.
|
| 1217 |
+
file_suffix (str): type of model run.
|
| 1218 |
+
calibration_type (str): type of calibration performed. Defaults to 'uncal'.
|
| 1219 |
+
|
| 1220 |
+
Returns:
|
| 1221 |
+
None.
|
| 1222 |
+
|
| 1223 |
+
"""
|
| 1224 |
+
# Create folders to contain score distributions
|
| 1225 |
+
os.makedirs(os.path.join(artifact_dir, "score_distribution"), exist_ok=True)
|
| 1226 |
+
|
| 1227 |
+
# Plot score distribution
|
| 1228 |
+
model_scores = pd.DataFrame({"model_score": probs, "true_label": target})
|
| 1229 |
+
sns.displot(model_scores, x="model_score", hue="true_label", kde=True, bins=20)
|
| 1230 |
+
plt.savefig(
|
| 1231 |
+
os.path.join(
|
| 1232 |
+
artifact_dir,
|
| 1233 |
+
"score_distribution",
|
| 1234 |
+
model_name + "_score_dist_" + calibration_type + "_" + file_suffix + ".png",
|
| 1235 |
+
)
|
| 1236 |
+
)
|
| 1237 |
+
plt.close()
|
| 1238 |
+
|
| 1239 |
+
|
| 1240 |
+
##############################################################
|
| 1241 |
+
# Functions for model calibration
|
| 1242 |
+
##############################################################
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
def plot_calibration_curve(target, probs, bins, strategy, calib_type, ax=None):
|
| 1246 |
+
"""Plot calibration curve.
|
| 1247 |
+
|
| 1248 |
+
Args:
|
| 1249 |
+
target (array): true target values.
|
| 1250 |
+
probs (array): probability estimates for the positive class.
|
| 1251 |
+
bins (int): number of bins for plotting calibration curve.
|
| 1252 |
+
strategy (str): strategy used to define the widths of the bin. Possible options are
|
| 1253 |
+
'uniform' or 'quantile'.
|
| 1254 |
+
calib_type (str): type of calibration performed.
|
| 1255 |
+
ax (None or matplotlib axis): allows plotting of multiple calibration curves on one
|
| 1256 |
+
plot. Defaults to None.
|
| 1257 |
+
|
| 1258 |
+
Returns:
|
| 1259 |
+
None.
|
| 1260 |
+
|
| 1261 |
+
"""
|
| 1262 |
+
prob_true, prob_pred = calibration_curve(
|
| 1263 |
+
target, probs, n_bins=bins, strategy=strategy
|
| 1264 |
+
)
|
| 1265 |
+
if ax is None:
|
| 1266 |
+
plt.plot(prob_pred, prob_true, marker=".", label=calib_type)
|
| 1267 |
+
else:
|
| 1268 |
+
ax.plot(prob_pred, prob_true, marker=".", label=calib_type)
|
| 1269 |
+
|
| 1270 |
+
|
| 1271 |
+
def equalObs(x, nbin):
|
| 1272 |
+
"""Function to calculate equal frequency bins.
|
| 1273 |
+
|
| 1274 |
+
Args:
|
| 1275 |
+
x (array): data to be divided into bins.
|
| 1276 |
+
nbin (int): number of bins required.
|
| 1277 |
+
|
| 1278 |
+
Returns:
|
| 1279 |
+
array: bin edges that give equal frequency bins.
|
| 1280 |
+
"""
|
| 1281 |
+
nlen = len(x)
|
| 1282 |
+
bin_edges = np.interp(np.linspace(0, nlen, nbin + 1), np.arange(nlen), np.sort(x))
|
| 1283 |
+
return bin_edges
|
| 1284 |
+
|
| 1285 |
+
|
| 1286 |
+
def plot_calibration_plot_with_error_bars(
|
| 1287 |
+
probs_uncal,
|
| 1288 |
+
probs_sig,
|
| 1289 |
+
probs_iso,
|
| 1290 |
+
probs_spline,
|
| 1291 |
+
train_target,
|
| 1292 |
+
test_target,
|
| 1293 |
+
model_name,
|
| 1294 |
+
artifact_dir="./tmp",
|
| 1295 |
+
):
|
| 1296 |
+
"""Plots calibration plots with 95% confidence interval bars.
|
| 1297 |
+
|
| 1298 |
+
Args:
|
| 1299 |
+
probs_uncal (array): probability estimates for the positive class for the
|
| 1300 |
+
uncalibrated model.
|
| 1301 |
+
probs_sig (array): probability estimates for the positive class for the sigmoid model.
|
| 1302 |
+
probs_iso (array): probability estimates for the positive class for the isotonic model.
|
| 1303 |
+
probs_spline (array): probability estimates for the positive class for the spline
|
| 1304 |
+
model.
|
| 1305 |
+
train_target (array): true target values for the train set.
|
| 1306 |
+
test_target (array): true target values for the test set.
|
| 1307 |
+
model_name (str): name of the model.
|
| 1308 |
+
artifact_dir (str, optional): path of output directory. Defaults to "./tmp".
|
| 1309 |
+
Returns:
|
| 1310 |
+
None.
|
| 1311 |
+
"""
|
| 1312 |
+
# Create histogram with equal-frequency bins
|
| 1313 |
+
n, bins_uncal, patches = plt.hist(
|
| 1314 |
+
probs_uncal, equalObs(probs_uncal, 5), edgecolor="black"
|
| 1315 |
+
)
|
| 1316 |
+
n, bins_sig, patches = plt.hist(
|
| 1317 |
+
probs_sig, equalObs(probs_sig, 5), edgecolor="black"
|
| 1318 |
+
)
|
| 1319 |
+
n, bins_iso, patches = plt.hist(
|
| 1320 |
+
probs_iso, equalObs(probs_iso, 5), edgecolor="black"
|
| 1321 |
+
)
|
| 1322 |
+
n, bins_spline, patches = plt.hist(
|
| 1323 |
+
probs_spline, equalObs(probs_spline, 5), edgecolor="black"
|
| 1324 |
+
)
|
| 1325 |
+
|
| 1326 |
+
# Plot calibration plots
|
| 1327 |
+
fig = plt.figure(figsize=(15, 15))
|
| 1328 |
+
plt.subplot(2, 2, 1)
|
| 1329 |
+
mli.plot_reliability_diagram(
|
| 1330 |
+
train_target,
|
| 1331 |
+
probs_uncal,
|
| 1332 |
+
bins=bins_uncal[:-1],
|
| 1333 |
+
ci_ref="point",
|
| 1334 |
+
reliability_title="Uncalibrated",
|
| 1335 |
+
)
|
| 1336 |
+
plt.subplot(2, 2, 2)
|
| 1337 |
+
mli.plot_reliability_diagram(
|
| 1338 |
+
test_target,
|
| 1339 |
+
probs_sig,
|
| 1340 |
+
bins=bins_sig[:-1],
|
| 1341 |
+
ci_ref="point",
|
| 1342 |
+
reliability_title="Sigmoid",
|
| 1343 |
+
)
|
| 1344 |
+
plt.subplot(2, 2, 3)
|
| 1345 |
+
mli.plot_reliability_diagram(
|
| 1346 |
+
test_target,
|
| 1347 |
+
probs_iso,
|
| 1348 |
+
bins=bins_iso[:-1],
|
| 1349 |
+
ci_ref="point",
|
| 1350 |
+
reliability_title="Isotonic",
|
| 1351 |
+
)
|
| 1352 |
+
plt.subplot(2, 2, 4)
|
| 1353 |
+
mli.plot_reliability_diagram(
|
| 1354 |
+
test_target,
|
| 1355 |
+
probs_spline,
|
| 1356 |
+
bins=bins_spline[:-1],
|
| 1357 |
+
ci_ref="point",
|
| 1358 |
+
reliability_title="Spline",
|
| 1359 |
+
)
|
| 1360 |
+
plt.savefig(
|
| 1361 |
+
os.path.join(artifact_dir, model_name + "calibration_plots_error_bars.png")
|
| 1362 |
+
)
|
| 1363 |
+
plt.close()
|
| 1364 |
+
|
| 1365 |
+
|
| 1366 |
+
def calc_std_for_calibrated_classifiers(
|
| 1367 |
+
calib_model, calib_model_name, test_features, test_target
|
| 1368 |
+
):
|
| 1369 |
+
auc_prs = []
|
| 1370 |
+
for i, clf in enumerate(calib_model.calibrated_classifiers_):
|
| 1371 |
+
probs_calib = clf.predict_proba(test_features)[:, 1]
|
| 1372 |
+
|
| 1373 |
+
precision_, recall_, thresholds_ = precision_recall_curve(
|
| 1374 |
+
test_target, probs_calib
|
| 1375 |
+
)
|
| 1376 |
+
auc_pr = auc(recall_, precision_)
|
| 1377 |
+
|
| 1378 |
+
auc_prs.append(auc_pr)
|
| 1379 |
+
print("Calibrated model:", calib_model_name)
|
| 1380 |
+
print("Mean AUC-PR:", np.mean(auc_prs))
|
| 1381 |
+
print("Std AUC-PR:", np.std(auc_prs))
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
##############################################################
|
| 1385 |
+
# Functions for model explainability
|
| 1386 |
+
##############################################################
|
| 1387 |
+
|
| 1388 |
+
|
| 1389 |
+
def plot_feat_importance_model(model, model_name, file_suffix, feature_names=None):
|
| 1390 |
+
"""Plotting model feature importances derived from model.
|
| 1391 |
+
|
| 1392 |
+
Plots feature importance metrics derived from models. Only performed for XGBoost and
|
| 1393 |
+
Light GBM models currently. For the XGBoost model, total cover and total gain plotted.
|
| 1394 |
+
For the Light GBM model, total gain is plotted. Function also returns a table with a
|
| 1395 |
+
number that is representative of the feature importance.
|
| 1396 |
+
|
| 1397 |
+
Args:
|
| 1398 |
+
model (variable): model that has been fit on train data.
|
| 1399 |
+
model_name (str): name of model.
|
| 1400 |
+
file_suffix (str): contains information on type of model being run.
|
| 1401 |
+
feature_names (list): list of feature names.
|
| 1402 |
+
|
| 1403 |
+
Returns:
|
| 1404 |
+
pd.DataFrame: df containing feature importance position.
|
| 1405 |
+
|
| 1406 |
+
"""
|
| 1407 |
+
if model_name.startswith("xgb"):
|
| 1408 |
+
total_gain = model.get_booster().get_score(importance_type="total_gain")
|
| 1409 |
+
total_cover = model.get_booster().get_score(importance_type="total_cover")
|
| 1410 |
+
total_gain = pd.DataFrame.from_dict(
|
| 1411 |
+
total_gain, orient="index", columns=["total_gain"]
|
| 1412 |
+
)
|
| 1413 |
+
total_cover = pd.DataFrame.from_dict(
|
| 1414 |
+
total_cover, orient="index", columns=["total_cover"]
|
| 1415 |
+
)
|
| 1416 |
+
feat_importance = total_gain.join(total_cover)
|
| 1417 |
+
if model_name.startswith("lgbm"):
|
| 1418 |
+
total_gain = model.booster_.feature_importance(importance_type="gain")
|
| 1419 |
+
feat_importance = dict(zip(feature_names, total_gain))
|
| 1420 |
+
feat_importance = pd.DataFrame.from_dict(
|
| 1421 |
+
feat_importance, orient="index", columns=["total_gain"]
|
| 1422 |
+
)
|
| 1423 |
+
feat_importance = feat_importance.sort_values(by="total_gain", ascending=False)
|
| 1424 |
+
print("Total gain and total cover\n", feat_importance)
|
| 1425 |
+
feat_importance.plot.barh(figsize=(10, 10))
|
| 1426 |
+
plt.tight_layout()
|
| 1427 |
+
plt.savefig(
|
| 1428 |
+
"./tmp/" + model_name + "_feat_importance_model_" + file_suffix + ".png"
|
| 1429 |
+
)
|
| 1430 |
+
plt.close()
|
| 1431 |
+
|
| 1432 |
+
# Create feature importance table
|
| 1433 |
+
feat_importance[model_name] = range(1, 1 + len(feat_importance))
|
| 1434 |
+
feat_importance = feat_importance.drop(
|
| 1435 |
+
columns=["total_gain", "total_cover"], errors="ignore"
|
| 1436 |
+
)
|
| 1437 |
+
feat_importance = feat_importance.reset_index()
|
| 1438 |
+
if model_name == "xgb":
|
| 1439 |
+
feat_importance_tot_gain_df = feat_importance
|
| 1440 |
+
else:
|
| 1441 |
+
feat_importance_tot_gain_df = pd.read_csv(
|
| 1442 |
+
"./data/feature_importance_tot_gain_" + file_suffix + ".csv"
|
| 1443 |
+
)
|
| 1444 |
+
feat_importance_tot_gain_df = feat_importance_tot_gain_df.merge(
|
| 1445 |
+
feat_importance, on="index", how="left"
|
| 1446 |
+
)
|
| 1447 |
+
return feat_importance_tot_gain_df
|
| 1448 |
+
|
| 1449 |
+
|
| 1450 |
+
def get_shap_feat_importance(model_name, shap_values, feature_names, file_suffix):
|
| 1451 |
+
"""Creates a table with feature importance hierarchy using shap values.
|
| 1452 |
+
|
| 1453 |
+
A table containing feature importance position created for all models except for the
|
| 1454 |
+
dummy classifier.
|
| 1455 |
+
|
| 1456 |
+
Args:
|
| 1457 |
+
model_name (str): name of model.
|
| 1458 |
+
shap_values (array): contains shap values.
|
| 1459 |
+
feature_names (list): list of feature names.
|
| 1460 |
+
file_suffix (str): contains information on type of model being run.
|
| 1461 |
+
|
| 1462 |
+
Returns:
|
| 1463 |
+
pd.DataFrame: df containing feature importance position.
|
| 1464 |
+
|
| 1465 |
+
"""
|
| 1466 |
+
if model_name != "dummy_classifier":
|
| 1467 |
+
shap_vals_df = pd.DataFrame(shap_values, columns=feature_names)
|
| 1468 |
+
vals = np.abs(shap_vals_df.values).mean(0)
|
| 1469 |
+
shap_importance = pd.DataFrame(
|
| 1470 |
+
list(zip(feature_names, vals)),
|
| 1471 |
+
columns=["col_name", "feature_importance_vals"],
|
| 1472 |
+
)
|
| 1473 |
+
shap_importance = shap_importance.sort_values(
|
| 1474 |
+
by=["feature_importance_vals"], ascending=False
|
| 1475 |
+
)
|
| 1476 |
+
shap_importance[model_name] = range(1, 1 + len(shap_importance))
|
| 1477 |
+
shap_importance = shap_importance.drop(columns=["feature_importance_vals"])
|
| 1478 |
+
shap_importance = shap_importance.reset_index(drop=True)
|
| 1479 |
+
if model_name == "logistic_regression":
|
| 1480 |
+
feat_importance_df = shap_importance
|
| 1481 |
+
else:
|
| 1482 |
+
feat_importance_df = pd.read_csv(
|
| 1483 |
+
"./data/feature_importance_shap" + file_suffix + ".csv"
|
| 1484 |
+
)
|
| 1485 |
+
feat_importance_df = feat_importance_df.merge(
|
| 1486 |
+
shap_importance, on="col_name", how="left"
|
| 1487 |
+
)
|
| 1488 |
+
else:
|
| 1489 |
+
pass
|
| 1490 |
+
return feat_importance_df
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
def get_local_shap_values(
|
| 1494 |
+
model_name, file_suffix, shap_values, test_data, calib_name, shap_ids_dir
|
| 1495 |
+
):
|
| 1496 |
+
# Read df containing probabilities and predictions
|
| 1497 |
+
probs_target = pd.read_csv(
|
| 1498 |
+
os.path.join(
|
| 1499 |
+
shap_ids_dir,
|
| 1500 |
+
"preds_and_events_"
|
| 1501 |
+
+ calib_name
|
| 1502 |
+
+ "_"
|
| 1503 |
+
+ model_name
|
| 1504 |
+
+ "_"
|
| 1505 |
+
+ file_suffix
|
| 1506 |
+
+ ".csv",
|
| 1507 |
+
)
|
| 1508 |
+
)
|
| 1509 |
+
|
| 1510 |
+
explanation = shap.Explanation(shap_values, data=test_data, display_data=True)
|
| 1511 |
+
|
| 1512 |
+
feature_imp_all = []
|
| 1513 |
+
for row_id in range(0, len(probs_target)):
|
| 1514 |
+
feature_names = explanation[row_id].data.index
|
| 1515 |
+
shap_values = explanation[row_id].values
|
| 1516 |
+
feature_importance = pd.DataFrame(data=shap_values, index=feature_names)
|
| 1517 |
+
feature_importance = feature_importance.sort_values(
|
| 1518 |
+
by=0, ascending=False
|
| 1519 |
+
).reset_index()
|
| 1520 |
+
feature_importance = feature_importance["index"].rename(row_id)
|
| 1521 |
+
feature_imp_all.append(feature_importance)
|
| 1522 |
+
feature_imp_all = pd.concat(feature_imp_all, axis=1)
|
| 1523 |
+
return feature_imp_all
|
| 1524 |
+
|
| 1525 |
+
|
| 1526 |
+
def plot_local_shap(
|
| 1527 |
+
model_name,
|
| 1528 |
+
file_suffix,
|
| 1529 |
+
shap_values,
|
| 1530 |
+
test_data,
|
| 1531 |
+
train_data,
|
| 1532 |
+
calib_name,
|
| 1533 |
+
row_ids_to_plot,
|
| 1534 |
+
artifact_dir,
|
| 1535 |
+
shap_ids_dir,
|
| 1536 |
+
reverse_scaling_flag=False,
|
| 1537 |
+
convert_target_encodings=False,
|
| 1538 |
+
imputation=None,
|
| 1539 |
+
target_enc_path=None,
|
| 1540 |
+
return_enc_converted_df=False,
|
| 1541 |
+
):
|
| 1542 |
+
"""Plots local shap plots for specified row ids.
|
| 1543 |
+
|
| 1544 |
+
Local shap plots are plotted for the specified row ids. The bars are colored according
|
| 1545 |
+
to their values. If the value is higher than the mean, the bar is colored red. If the
|
| 1546 |
+
value is lower than the mean, the bar is colored blue. A text box is added with patient
|
| 1547 |
+
details and prediction.
|
| 1548 |
+
|
| 1549 |
+
Args:
|
| 1550 |
+
model_name (str): name of model.
|
| 1551 |
+
file_suffix (str): type of model.
|
| 1552 |
+
test_data (pd.DataFrame): dataframe containing feature names and values. Must be in
|
| 1553 |
+
the same shape as the shap_values.
|
| 1554 |
+
train_data (pd.DataFrame): dataframe containing train feature names and values.
|
| 1555 |
+
Needed for calculating the median for custom coloring of bars.
|
| 1556 |
+
shap_values (array): matrix of SHAP values (# samples x # features).
|
| 1557 |
+
calib_name (str): type of calibration performed.
|
| 1558 |
+
row_ids_to_plot (list): list of ids for which to plot the local SHAP plot for.
|
| 1559 |
+
artifact_dir (str): output directory where plot is saved.
|
| 1560 |
+
shap_ids_dir (str): directory where data used for SHAP plots is stored.
|
| 1561 |
+
reverse_scaling_flag (bool, optional): flag to identify whether scaling should be
|
| 1562 |
+
reversed. Defaults to False.
|
| 1563 |
+
convert_target_encodings (bool, optional): flag to allow conversion of target
|
| 1564 |
+
encodings to category groups. Defaults to False.
|
| 1565 |
+
imputation (None or str, optional): whether imputation was performed. Defaults to None.
|
| 1566 |
+
target_enc_path (None or str, optional): path to target encodings. Defaults to None.
|
| 1567 |
+
return_enc_converted_df (bool, optional): flag to identify whether to return the
|
| 1568 |
+
dataframe with converted target encodings. Defaults to False.
|
| 1569 |
+
|
| 1570 |
+
Returns:
|
| 1571 |
+
pd.Dataframe: contains target encodings converted to categories. If reverse_scaling is
|
| 1572 |
+
True, values in the df will be converted back to their original values.
|
| 1573 |
+
|
| 1574 |
+
"""
|
| 1575 |
+
# Create folder to contain local plots
|
| 1576 |
+
os.makedirs(os.path.join(artifact_dir, "shap_local_plots"), exist_ok=True)
|
| 1577 |
+
|
| 1578 |
+
# Read df containing probabilities and predictions
|
| 1579 |
+
probs_target = pd.read_csv(
|
| 1580 |
+
os.path.join(
|
| 1581 |
+
shap_ids_dir,
|
| 1582 |
+
"preds_and_events_"
|
| 1583 |
+
+ calib_name
|
| 1584 |
+
+ "_"
|
| 1585 |
+
+ model_name
|
| 1586 |
+
+ "_"
|
| 1587 |
+
+ file_suffix
|
| 1588 |
+
+ ".csv",
|
| 1589 |
+
)
|
| 1590 |
+
)
|
| 1591 |
+
|
| 1592 |
+
# Calculate median values for each feature
|
| 1593 |
+
median_values = pd.DataFrame(
|
| 1594 |
+
data=[np.median(test_data, axis=0)], columns=test_data.columns
|
| 1595 |
+
)
|
| 1596 |
+
for col_name in median_values.columns:
|
| 1597 |
+
if col_name.endswith("te"):
|
| 1598 |
+
median_values[col_name] = np.NaN
|
| 1599 |
+
median_values = median_values.T
|
| 1600 |
+
median_values = median_values.rename(columns={0: "median"})
|
| 1601 |
+
|
| 1602 |
+
# TODO Add median for when reverse scaling flag is True
|
| 1603 |
+
|
| 1604 |
+
# Reverse scaling in order to convert target encodings into groups if flag is True
|
| 1605 |
+
if convert_target_encodings is True:
|
| 1606 |
+
# data_scaling_reversed = reverse_scaling(test_data, imputation, file_suffix)
|
| 1607 |
+
data_enc_conv = convert_target_encodings_into_groups(target_enc_path, test_data)
|
| 1608 |
+
# if reverse_scaling_flag is False:
|
| 1609 |
+
# data_enc_conv = data_enc_conv.loc[
|
| 1610 |
+
# :, data_enc_conv.columns.str.endswith("te")
|
| 1611 |
+
# ]
|
| 1612 |
+
# data_no_categorical = test_data.loc[
|
| 1613 |
+
# :, ~test_data.columns.str.endswith("te")
|
| 1614 |
+
# ]
|
| 1615 |
+
# data_enc_conv = data_no_categorical.merge(
|
| 1616 |
+
# data_enc_conv, left_index=True, right_index=True, how="left"
|
| 1617 |
+
# )
|
| 1618 |
+
|
| 1619 |
+
for i in row_ids_to_plot:
|
| 1620 |
+
if type(i) == str:
|
| 1621 |
+
probs_target_to_plot = probs_target[probs_target["Explanation"] == i]
|
| 1622 |
+
row_ids_to_plot_section = probs_target_to_plot.index
|
| 1623 |
+
# Create folder to contain local plots
|
| 1624 |
+
os.makedirs(
|
| 1625 |
+
os.path.join(artifact_dir, "shap_local_plots", i), exist_ok=True
|
| 1626 |
+
)
|
| 1627 |
+
|
| 1628 |
+
# Plot local SHAP plots of selected ids
|
| 1629 |
+
for id in row_ids_to_plot_section:
|
| 1630 |
+
fig, ax = plt.subplots()
|
| 1631 |
+
if convert_target_encodings is True:
|
| 1632 |
+
explanation = shap.Explanation(
|
| 1633 |
+
shap_values, data=data_enc_conv, display_data=True
|
| 1634 |
+
)
|
| 1635 |
+
else:
|
| 1636 |
+
explanation = shap.Explanation(
|
| 1637 |
+
shap_values, data=test_data, display_data=True
|
| 1638 |
+
)
|
| 1639 |
+
shap.plots.bar(
|
| 1640 |
+
explanation[id], show_data=True, show=False, max_display=None
|
| 1641 |
+
)
|
| 1642 |
+
|
| 1643 |
+
# Create a text box with probability and threshold
|
| 1644 |
+
print(probs_target)
|
| 1645 |
+
probability = round(probs_target.iloc[id]["Probs"], 3)
|
| 1646 |
+
threshold = round(probs_target.iloc[id]["Threshold"], 3)
|
| 1647 |
+
textstr = (
|
| 1648 |
+
"StudyId: "
|
| 1649 |
+
+ probs_target.iloc[id]["StudyId"]
|
| 1650 |
+
+ "\n"
|
| 1651 |
+
+ "Probability: "
|
| 1652 |
+
+ str(probability)
|
| 1653 |
+
+ "\n"
|
| 1654 |
+
+ "Threshold: "
|
| 1655 |
+
+ str(threshold)
|
| 1656 |
+
+ "\n"
|
| 1657 |
+
+ "Ground truth: "
|
| 1658 |
+
+ str(probs_target.iloc[id]["Target"])
|
| 1659 |
+
+ "\n"
|
| 1660 |
+
+ "Prediction: "
|
| 1661 |
+
+ str(probs_target.iloc[id]["Predicted"])
|
| 1662 |
+
)
|
| 1663 |
+
|
| 1664 |
+
# Place a text box in upper left of figure
|
| 1665 |
+
props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
|
| 1666 |
+
ax.text(
|
| 1667 |
+
0.02,
|
| 1668 |
+
0.98,
|
| 1669 |
+
textstr,
|
| 1670 |
+
fontsize=12,
|
| 1671 |
+
transform=fig.transFigure,
|
| 1672 |
+
verticalalignment="top",
|
| 1673 |
+
bbox=props,
|
| 1674 |
+
)
|
| 1675 |
+
plt.title(probs_target.iloc[id]["Explanation"])
|
| 1676 |
+
|
| 1677 |
+
# Get the current Axes object
|
| 1678 |
+
ax = plt.gca()
|
| 1679 |
+
|
| 1680 |
+
# Get the input values for the patient
|
| 1681 |
+
data_id = test_data.iloc[id].T.to_frame()
|
| 1682 |
+
data_id = data_id.rename(columns={id: "values"})
|
| 1683 |
+
|
| 1684 |
+
# Order the columns by importance and merge
|
| 1685 |
+
df_shap_values = pd.DataFrame(
|
| 1686 |
+
data=[shap_values[id]], columns=test_data.columns
|
| 1687 |
+
).T
|
| 1688 |
+
df_shap_values = df_shap_values.abs()
|
| 1689 |
+
df_shap_values = df_shap_values.rename(columns={0: "shap_values"})
|
| 1690 |
+
df_feat_importance = df_shap_values.sort_values(
|
| 1691 |
+
by="shap_values", ascending=False
|
| 1692 |
+
)
|
| 1693 |
+
df_feat_importance = df_feat_importance.merge(
|
| 1694 |
+
median_values, right_index=True, left_index=True, how="left"
|
| 1695 |
+
)
|
| 1696 |
+
df_feat_importance = df_feat_importance.merge(
|
| 1697 |
+
data_id, right_index=True, left_index=True, how="left"
|
| 1698 |
+
)
|
| 1699 |
+
|
| 1700 |
+
# Customise the colors of the bars based on whether the value is higher/lower
|
| 1701 |
+
# compared to the median
|
| 1702 |
+
colors = [
|
| 1703 |
+
(
|
| 1704 |
+
"lightgreen"
|
| 1705 |
+
if pd.isna(median)
|
| 1706 |
+
else "#ff0051" if value >= median else "#008bfb"
|
| 1707 |
+
)
|
| 1708 |
+
for value, median in zip(
|
| 1709 |
+
df_feat_importance["values"], df_feat_importance["median"]
|
| 1710 |
+
)
|
| 1711 |
+
]
|
| 1712 |
+
for bar, color in zip(ax.patches, colors):
|
| 1713 |
+
bar.set_color(color)
|
| 1714 |
+
|
| 1715 |
+
# Customize the colors of the SHAP values displayed on the bars
|
| 1716 |
+
for text, color in zip(ax.texts, ["black"] * len(ax.texts)):
|
| 1717 |
+
text.set_color(color)
|
| 1718 |
+
|
| 1719 |
+
# Save figure
|
| 1720 |
+
plt.tight_layout()
|
| 1721 |
+
plt.savefig(
|
| 1722 |
+
os.path.join(
|
| 1723 |
+
artifact_dir,
|
| 1724 |
+
"shap_local_plots",
|
| 1725 |
+
i,
|
| 1726 |
+
model_name
|
| 1727 |
+
+ "_shap_local_id_"
|
| 1728 |
+
+ str(id)
|
| 1729 |
+
+ "_"
|
| 1730 |
+
+ calib_name
|
| 1731 |
+
+ "_"
|
| 1732 |
+
+ file_suffix
|
| 1733 |
+
+ ".png",
|
| 1734 |
+
),
|
| 1735 |
+
bbox_inches="tight",
|
| 1736 |
+
)
|
| 1737 |
+
plt.close()
|
| 1738 |
+
if return_enc_converted_df is True:
|
| 1739 |
+
return data_enc_conv
|
| 1740 |
+
|
| 1741 |
+
|
| 1742 |
+
def plot_averaged_summary_plot(
|
| 1743 |
+
avg_shap_values_train, train_data, model_name, calib_type, file_suffix
|
| 1744 |
+
):
|
| 1745 |
+
"""Plots summary SHAP plot using the shap values averaged across cross validation folds.
|
| 1746 |
+
|
| 1747 |
+
Args:
|
| 1748 |
+
avg_shap_values_train (array): SHAP values averaged across CV folds.
|
| 1749 |
+
train_data (dataframe): dataframe containing feature names and values.
|
| 1750 |
+
model_name (str): name of model.
|
| 1751 |
+
calib_type (str): type of calibration performed.
|
| 1752 |
+
file_suffix (str): type of model.
|
| 1753 |
+
|
| 1754 |
+
Returns:
|
| 1755 |
+
None.
|
| 1756 |
+
|
| 1757 |
+
"""
|
| 1758 |
+
# Create folder to contain summary plots
|
| 1759 |
+
os.makedirs("./tmp/shap_summary_plots", exist_ok=True)
|
| 1760 |
+
|
| 1761 |
+
shap.summary_plot(
|
| 1762 |
+
np.array(avg_shap_values_train), train_data, max_display=None, show=False
|
| 1763 |
+
)
|
| 1764 |
+
plt.savefig(
|
| 1765 |
+
os.path.join(
|
| 1766 |
+
"./tmp/shap_summary_plots/",
|
| 1767 |
+
model_name + "_shap_" + calib_type + "_" + file_suffix + ".png",
|
| 1768 |
+
)
|
| 1769 |
+
)
|
| 1770 |
+
plt.close()
|
| 1771 |
+
|
| 1772 |
+
|
| 1773 |
+
def plot_shap_decision_plot(
|
| 1774 |
+
explainer, shap_values, test_df, link, row_ids_to_plot, artifact_dir
|
| 1775 |
+
):
|
| 1776 |
+
"""Plots SHAP decision plots for specified row numbers.
|
| 1777 |
+
|
| 1778 |
+
Args:
|
| 1779 |
+
explainer (object): explainer used for explaining model output.
|
| 1780 |
+
shap_values (array): matrix of SHAP values (# samples x # features).
|
| 1781 |
+
test_df (dataframe): test df which provides the values and feature names.
|
| 1782 |
+
link (str): specifies transformation for the x-axis. Use "logit" to transform
|
| 1783 |
+
log-odds into probabilities.
|
| 1784 |
+
row_ids_to_plot (list): list of ids for which to plot the decision plot for.
|
| 1785 |
+
artifact_dir (str): output directory where plot is saved.
|
| 1786 |
+
|
| 1787 |
+
Returns:
|
| 1788 |
+
None.
|
| 1789 |
+
|
| 1790 |
+
"""
|
| 1791 |
+
# Create folders to contain decision plots
|
| 1792 |
+
os.makedirs(os.path.join(artifact_dir, "decision_plots"), exist_ok=True)
|
| 1793 |
+
|
| 1794 |
+
# Check that all rows specified are present in the cv fold. If rows not in the same cv
|
| 1795 |
+
# fold, descision plot not plotted.
|
| 1796 |
+
if all(x in test_df.index.tolist() for x in row_ids_to_plot):
|
| 1797 |
+
# If expected value provided for both classes, get the value for class 1
|
| 1798 |
+
try:
|
| 1799 |
+
expected_val = explainer.expected_value[1]
|
| 1800 |
+
except:
|
| 1801 |
+
expected_val = explainer.expected_value
|
| 1802 |
+
else:
|
| 1803 |
+
raise IndexError("All rows specified are not present in the same CV fold.")
|
| 1804 |
+
|
| 1805 |
+
# Create decision plot with all samples
|
| 1806 |
+
shap_return = shap.decision_plot(
|
| 1807 |
+
expected_val, shap_values, test_df, link=link, show=False, return_objects=True
|
| 1808 |
+
)
|
| 1809 |
+
plt.tight_layout()
|
| 1810 |
+
plt.savefig(os.path.join(artifact_dir, "decision_plots", "cv_decision_plot.png"))
|
| 1811 |
+
plt.close()
|
| 1812 |
+
|
| 1813 |
+
# Create decision plot for row ids specified keeping the same feature order as in
|
| 1814 |
+
# the decision plot with all samples
|
| 1815 |
+
for i in row_ids_to_plot:
|
| 1816 |
+
shap.decision_plot(
|
| 1817 |
+
expected_val,
|
| 1818 |
+
shap_values[i],
|
| 1819 |
+
test_df.iloc[i],
|
| 1820 |
+
feature_order=shap_return.feature_idx,
|
| 1821 |
+
link=link,
|
| 1822 |
+
show=False,
|
| 1823 |
+
)
|
| 1824 |
+
plt.tight_layout()
|
| 1825 |
+
plt.savefig(
|
| 1826 |
+
os.path.join(
|
| 1827 |
+
artifact_dir, "decision_plots", "decision_plot_rownum" + str(i) + ".png"
|
| 1828 |
+
)
|
| 1829 |
+
)
|
| 1830 |
+
plt.close()
|
| 1831 |
+
|
| 1832 |
+
|
| 1833 |
+
def plot_shap_summary_plot_per_cv_fold(
|
| 1834 |
+
shap_values_train, X_train, calib_type, fold_num, model_name, file_suffix
|
| 1835 |
+
):
|
| 1836 |
+
"""Plots SHAP summary plot for each CV fold.
|
| 1837 |
+
|
| 1838 |
+
Args:
|
| 1839 |
+
shap_values_train (array): shap values for the train set.
|
| 1840 |
+
X_train (dataframe): contains values for the train set.
|
| 1841 |
+
calib_type (str): type of calibration performed.
|
| 1842 |
+
fold_num (str): CV fold number.
|
| 1843 |
+
model_name (str): name of model.
|
| 1844 |
+
file_suffix (str): type of model run.
|
| 1845 |
+
|
| 1846 |
+
Returns:
|
| 1847 |
+
None.
|
| 1848 |
+
|
| 1849 |
+
"""
|
| 1850 |
+
# Create folders to contain shap summary plots for cv folds for each calibration type
|
| 1851 |
+
os.makedirs("./tmp/shap_cv_folds_" + calib_type, exist_ok=True)
|
| 1852 |
+
|
| 1853 |
+
shap.summary_plot(
|
| 1854 |
+
np.array(shap_values_train), X_train, max_display=None, show=False
|
| 1855 |
+
)
|
| 1856 |
+
plt.savefig(
|
| 1857 |
+
os.path.join(
|
| 1858 |
+
"./tmp/",
|
| 1859 |
+
"shap_cv_folds_" + calib_type,
|
| 1860 |
+
model_name
|
| 1861 |
+
+ "_shap_"
|
| 1862 |
+
+ calib_type
|
| 1863 |
+
+ "_cv_fold_"
|
| 1864 |
+
+ str(fold_num)
|
| 1865 |
+
+ "_"
|
| 1866 |
+
+ file_suffix
|
| 1867 |
+
+ ".png",
|
| 1868 |
+
)
|
| 1869 |
+
)
|
| 1870 |
+
plt.close()
|
| 1871 |
+
|
| 1872 |
+
|
| 1873 |
+
def get_calibrated_shap_by_classifier(
|
| 1874 |
+
calib_model, x_test, x_train, features, calib_type, model_name, file_suffix
|
| 1875 |
+
):
|
| 1876 |
+
"""
|
| 1877 |
+
Iterated over base classifier in calibrated model and averages the shap
|
| 1878 |
+
values for both test and train.
|
| 1879 |
+
|
| 1880 |
+
Parameters
|
| 1881 |
+
----------
|
| 1882 |
+
calib_model : calibrated sklearn model
|
| 1883 |
+
Trained calibrated sklearn model.
|
| 1884 |
+
x_test : pandas dataframe
|
| 1885 |
+
test features dataframe.
|
| 1886 |
+
x_train : pandas dataframe
|
| 1887 |
+
training features dataframe.
|
| 1888 |
+
features : list
|
| 1889 |
+
name of columns.
|
| 1890 |
+
calib_type : str
|
| 1891 |
+
type of calibration.
|
| 1892 |
+
model_name : str
|
| 1893 |
+
name of model.
|
| 1894 |
+
file_suffix : str
|
| 1895 |
+
type of model run.
|
| 1896 |
+
|
| 1897 |
+
Returns
|
| 1898 |
+
-------
|
| 1899 |
+
shap_values_v : array
|
| 1900 |
+
test shap values.
|
| 1901 |
+
shap_values_t : array
|
| 1902 |
+
train shap values.
|
| 1903 |
+
|
| 1904 |
+
"""
|
| 1905 |
+
# https://github.com/slundberg/shap/issues/899
|
| 1906 |
+
shap_values_list_val = []
|
| 1907 |
+
shap_values_list_train = []
|
| 1908 |
+
base_list = []
|
| 1909 |
+
fold_num = 0
|
| 1910 |
+
|
| 1911 |
+
for calibrated_classifier in calib_model.calibrated_classifiers_:
|
| 1912 |
+
fold_num = fold_num + 1
|
| 1913 |
+
explainer = shap.TreeExplainer(calibrated_classifier.estimator)
|
| 1914 |
+
shap_values_val = explainer.shap_values(x_test)
|
| 1915 |
+
shap_values_train = explainer.shap_values(x_train)
|
| 1916 |
+
if len(np.shape(shap_values_train)) == 3:
|
| 1917 |
+
shap_values_train = shap_values_train[1]
|
| 1918 |
+
shap_values_val = shap_values_val[1]
|
| 1919 |
+
shap_values_list_val.append(shap_values_val)
|
| 1920 |
+
shap_values_list_train.append(shap_values_train)
|
| 1921 |
+
base_list.append(explainer.expected_value)
|
| 1922 |
+
plot_shap_summary_plot_per_cv_fold(
|
| 1923 |
+
shap_values_train, x_train, calib_type, fold_num, model_name, file_suffix
|
| 1924 |
+
)
|
| 1925 |
+
shap_values_v = np.array(shap_values_list_val).sum(axis=0) / len(
|
| 1926 |
+
shap_values_list_val
|
| 1927 |
+
)
|
| 1928 |
+
shap_values_t = np.array(shap_values_list_train).sum(axis=0) / len(
|
| 1929 |
+
shap_values_list_train
|
| 1930 |
+
)
|
| 1931 |
+
|
| 1932 |
+
shap_values_global = pd.DataFrame(np.abs(shap_values_t), columns=features)
|
| 1933 |
+
global_shap = shap_values_global.mean(axis=0)
|
| 1934 |
+
global_shap_round = global_shap.round(3, None)
|
| 1935 |
+
x = global_shap_round.to_json(double_precision=3)
|
| 1936 |
+
x = x.replace("'", '"')
|
| 1937 |
+
mean_shap = json.loads(x)
|
| 1938 |
+
return shap_values_v, shap_values_t
|
| 1939 |
+
|
| 1940 |
+
|
| 1941 |
+
def get_uncalibrated_shap(
|
| 1942 |
+
uncal_model_estimators, x_test, x_train, features, model_name, file_suffix
|
| 1943 |
+
):
|
| 1944 |
+
"""
|
| 1945 |
+
Iterated over base classifier in calibrated model and averages the shap
|
| 1946 |
+
values for both test and train.
|
| 1947 |
+
|
| 1948 |
+
Parameters
|
| 1949 |
+
----------
|
| 1950 |
+
uncal_model_estimators : uncalibrated model
|
| 1951 |
+
Trained uncalibrated model.
|
| 1952 |
+
x_test : pandas dataframe
|
| 1953 |
+
test features dataframe.
|
| 1954 |
+
x_train : pandas dataframe
|
| 1955 |
+
training features dataframe.
|
| 1956 |
+
features : list
|
| 1957 |
+
name of columns.
|
| 1958 |
+
model_name : str
|
| 1959 |
+
name of model.
|
| 1960 |
+
file_suffix : str
|
| 1961 |
+
type of model run.
|
| 1962 |
+
|
| 1963 |
+
Returns
|
| 1964 |
+
-------
|
| 1965 |
+
shap_values_v : array
|
| 1966 |
+
test shap values.
|
| 1967 |
+
shap_values_t : array
|
| 1968 |
+
train shap values.
|
| 1969 |
+
|
| 1970 |
+
"""
|
| 1971 |
+
# https://github.com/slundberg/shap/issues/899
|
| 1972 |
+
shap_values_list_val = []
|
| 1973 |
+
shap_values_list_train = []
|
| 1974 |
+
base_list = []
|
| 1975 |
+
fold_num = 0
|
| 1976 |
+
for estimator in uncal_model_estimators:
|
| 1977 |
+
explainer = shap.TreeExplainer(estimator)
|
| 1978 |
+
shap_values_val = explainer.shap_values(x_test)
|
| 1979 |
+
shap_values_train = explainer.shap_values(x_train)
|
| 1980 |
+
if len(np.shape(shap_values_train)) == 3:
|
| 1981 |
+
shap_values_train = shap_values_train[1]
|
| 1982 |
+
shap_values_val = shap_values_val[1]
|
| 1983 |
+
shap_values_list_val.append(shap_values_val)
|
| 1984 |
+
shap_values_list_train.append(shap_values_train)
|
| 1985 |
+
base_list.append(explainer.expected_value)
|
| 1986 |
+
fold_num = fold_num + 1
|
| 1987 |
+
plot_shap_summary_plot_per_cv_fold(
|
| 1988 |
+
shap_values_train, x_train, "uncalib", fold_num, model_name, file_suffix
|
| 1989 |
+
)
|
| 1990 |
+
shap_values_v = np.array(shap_values_list_val).sum(axis=0) / len(
|
| 1991 |
+
shap_values_list_val
|
| 1992 |
+
)
|
| 1993 |
+
shap_values_t = np.array(shap_values_list_train).sum(axis=0) / len(
|
| 1994 |
+
shap_values_list_train
|
| 1995 |
+
)
|
| 1996 |
+
shap_values_global = pd.DataFrame(np.abs(shap_values_t), columns=features)
|
| 1997 |
+
global_shap = shap_values_global.mean(axis=0)
|
| 1998 |
+
global_shap_round = global_shap.round(3, None)
|
| 1999 |
+
x = global_shap_round.to_json(double_precision=3)
|
| 2000 |
+
x = x.replace("'", '"')
|
| 2001 |
+
mean_shap = json.loads(x)
|
| 2002 |
+
return shap_values_v, shap_values_t
|
| 2003 |
+
|
| 2004 |
+
|
| 2005 |
+
def plot_shap_interaction_value_heatmap(
|
| 2006 |
+
estimators, train_features, column_names, model_name, file_suffix
|
| 2007 |
+
):
|
| 2008 |
+
"""Calculate SHAP interaction values and plot on heatmap.
|
| 2009 |
+
|
| 2010 |
+
Args:
|
| 2011 |
+
estimators (list): list of estimators used during cross validation.
|
| 2012 |
+
train_features (array): train set values.
|
| 2013 |
+
column_names (list): names of columns.
|
| 2014 |
+
model_name (str): name of model.
|
| 2015 |
+
file_suffix (str): name of model run.
|
| 2016 |
+
|
| 2017 |
+
|
| 2018 |
+
Returns:
|
| 2019 |
+
None.
|
| 2020 |
+
|
| 2021 |
+
"""
|
| 2022 |
+
os.makedirs("./tmp/interaction_plot", exist_ok=True)
|
| 2023 |
+
fold_num = 0
|
| 2024 |
+
for estimator in estimators:
|
| 2025 |
+
fold_num = fold_num + 1
|
| 2026 |
+
explainer = shap.TreeExplainer(estimator)
|
| 2027 |
+
shap_interaction = explainer.shap_interaction_values(train_features)
|
| 2028 |
+
|
| 2029 |
+
# Some values come in the shape (#class, #samples, #features, #features). Subset
|
| 2030 |
+
# these cases to class 1.
|
| 2031 |
+
if len(np.shape(shap_interaction)) == 4:
|
| 2032 |
+
shap_interaction = shap_interaction[1]
|
| 2033 |
+
|
| 2034 |
+
# Plot heatmap
|
| 2035 |
+
mean_shap = np.abs(shap_interaction).mean(0)
|
| 2036 |
+
df = pd.DataFrame(mean_shap, index=column_names, columns=column_names)
|
| 2037 |
+
df.where(df.values == np.diagonal(df), df.values * 2, inplace=True)
|
| 2038 |
+
fig = plt.figure(figsize=(35, 20), facecolor="#002637", edgecolor="r")
|
| 2039 |
+
ax = fig.add_subplot()
|
| 2040 |
+
sns.heatmap(
|
| 2041 |
+
df.round(decimals=3),
|
| 2042 |
+
cmap="coolwarm",
|
| 2043 |
+
annot=True,
|
| 2044 |
+
fmt=".6g",
|
| 2045 |
+
cbar=False,
|
| 2046 |
+
ax=ax,
|
| 2047 |
+
)
|
| 2048 |
+
ax.tick_params(axis="x", colors="w", labelsize=15, rotation=90)
|
| 2049 |
+
ax.tick_params(axis="y", colors="w", labelsize=15)
|
| 2050 |
+
plt.suptitle("SHAP interaction values", color="white", fontsize=60, y=0.97)
|
| 2051 |
+
plt.yticks(rotation=0)
|
| 2052 |
+
plt.savefig(
|
| 2053 |
+
"./tmp/interaction_plot/shap_interaction_heatmap_cv_"
|
| 2054 |
+
+ str(fold_num)
|
| 2055 |
+
+ "_"
|
| 2056 |
+
+ model_name
|
| 2057 |
+
+ "_"
|
| 2058 |
+
+ file_suffix
|
| 2059 |
+
+ ".png"
|
| 2060 |
+
)
|
| 2061 |
+
plt.close()
|
training/perform_forward_validation.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import pickle
|
| 4 |
+
import model_h
|
| 5 |
+
import mlflow
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import sys
|
| 10 |
+
import scipy
|
| 11 |
+
import yaml
|
| 12 |
+
|
| 13 |
+
with open("./training/config.yaml", "r") as config:
|
| 14 |
+
config = yaml.safe_load(config)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def perform_ks_test(train_data, forward_val_data):
|
| 18 |
+
"""Perform Kolmogorov-Smirnov test.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
train_data (pd.DataFrame): data used to train model.
|
| 22 |
+
forward_val_data (pd.DataFrame): data used for the forward validation.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
pd.DataFrame: dataframe containing the results of the K-S test.
|
| 26 |
+
"""
|
| 27 |
+
for num, feature_name in enumerate(train_data.columns.tolist()):
|
| 28 |
+
statistic, pvalue = scipy.stats.ks_2samp(
|
| 29 |
+
train_data[feature_name], forward_val_data[feature_name]
|
| 30 |
+
)
|
| 31 |
+
pvalue = round(pvalue, 4)
|
| 32 |
+
if num == 0:
|
| 33 |
+
df_ks = pd.DataFrame(
|
| 34 |
+
{
|
| 35 |
+
"FeatureName": feature_name,
|
| 36 |
+
"KS_PValue": pvalue,
|
| 37 |
+
"KS_TestStatistic": statistic,
|
| 38 |
+
},
|
| 39 |
+
index=[num],
|
| 40 |
+
)
|
| 41 |
+
else:
|
| 42 |
+
df_ks_feat = pd.DataFrame(
|
| 43 |
+
{
|
| 44 |
+
"FeatureName": feature_name,
|
| 45 |
+
"KS_PValue": pvalue,
|
| 46 |
+
"KS_TestStatistic": statistic,
|
| 47 |
+
},
|
| 48 |
+
index=[num],
|
| 49 |
+
)
|
| 50 |
+
df_ks = pd.concat([df_ks, df_ks_feat])
|
| 51 |
+
df_ks["KS_DistributionsIdentical"] = np.where(df_ks["KS_PValue"] < 0.05, 0, 1)
|
| 52 |
+
return df_ks
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def compute_wasserstein_distance(train_data, forward_val_data):
|
| 56 |
+
"""Calculate the wasserstein distance.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
train_data (pd.DataFrame): data used to train model.
|
| 60 |
+
forward_val_data (pd.DataFrame): data used for the forward validation.
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
pd.DataFrame: dataframe containing the wasserstein distance results.
|
| 64 |
+
"""
|
| 65 |
+
for num, feature_name in enumerate(train_data.columns.tolist()):
|
| 66 |
+
w_distance = scipy.stats.wasserstein_distance(
|
| 67 |
+
train_data[feature_name], forward_val_data[feature_name]
|
| 68 |
+
)
|
| 69 |
+
if num == 0:
|
| 70 |
+
df_wd = pd.DataFrame(
|
| 71 |
+
{"FeatureName": feature_name, "WassersteinDistance": w_distance},
|
| 72 |
+
index=[num],
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
df_wd_feat = pd.DataFrame(
|
| 76 |
+
{"FeatureName": feature_name, "WassersteinDistance": w_distance},
|
| 77 |
+
index=[num],
|
| 78 |
+
)
|
| 79 |
+
df_wd = pd.concat([df_wd, df_wd_feat])
|
| 80 |
+
df_wd = df_wd.sort_values(by="WassersteinDistance", ascending=True)
|
| 81 |
+
return df_wd
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
##############################################################
|
| 85 |
+
# Load data
|
| 86 |
+
##############################################################
|
| 87 |
+
model_type = config["model_settings"]["model_type"]
|
| 88 |
+
|
| 89 |
+
# Setup log file
|
| 90 |
+
log = open(
|
| 91 |
+
os.path.join(
|
| 92 |
+
config["outputs"]["logging_dir"], "run_forward_val_" + model_type + ".log"
|
| 93 |
+
),
|
| 94 |
+
"w",
|
| 95 |
+
)
|
| 96 |
+
sys.stdout = log
|
| 97 |
+
|
| 98 |
+
# Load test data
|
| 99 |
+
forward_val_data_imputed = pd.read_pickle(
|
| 100 |
+
os.path.join(
|
| 101 |
+
config["outputs"]["model_input_data_dir"],
|
| 102 |
+
"forward_val_imputed_{}.pkl".format(model_type),
|
| 103 |
+
)
|
| 104 |
+
)
|
| 105 |
+
forward_val_data_not_imputed = pd.read_pickle(
|
| 106 |
+
os.path.join(
|
| 107 |
+
config["outputs"]["model_input_data_dir"],
|
| 108 |
+
"forward_val_not_imputed_{}.pkl".format(model_type),
|
| 109 |
+
)
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Load exac event type data
|
| 113 |
+
#test_exac_data = pd.read_pickle("./data/forward_val_exac_data.pkl")
|
| 114 |
+
|
| 115 |
+
# Load data the model was trained on
|
| 116 |
+
train_data = model_h.load_data_for_modelling(
|
| 117 |
+
os.path.join(
|
| 118 |
+
config["outputs"]["model_input_data_dir"],
|
| 119 |
+
"crossval_imputed_{}.pkl".format(model_type),
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
##############################################################
|
| 124 |
+
# Check for data drift
|
| 125 |
+
##############################################################
|
| 126 |
+
train_data_for_data_drift = train_data.drop(columns=["StudyId", "IndexDate"])
|
| 127 |
+
forward_val_data_for_data_drift = forward_val_data_imputed.drop(columns=["StudyId", "IndexDate"])
|
| 128 |
+
|
| 129 |
+
df_ks = perform_ks_test(train_data_for_data_drift, forward_val_data_for_data_drift)
|
| 130 |
+
df_wd = compute_wasserstein_distance(
|
| 131 |
+
train_data_for_data_drift, forward_val_data_for_data_drift
|
| 132 |
+
)
|
| 133 |
+
df_data_drift = df_wd.merge(df_ks, on="FeatureName", how="left")
|
| 134 |
+
print(df_data_drift)
|
| 135 |
+
|
| 136 |
+
##############################################################
|
| 137 |
+
# Prepare data for running model
|
| 138 |
+
##############################################################
|
| 139 |
+
# Value counts for hospital and community exacerbations
|
| 140 |
+
print(forward_val_data_imputed["ExacWithin3Months"].value_counts())
|
| 141 |
+
print(
|
| 142 |
+
forward_val_data_imputed[forward_val_data_imputed["ExacWithin3Months"] == 1][
|
| 143 |
+
"HospExacWithin3Months"
|
| 144 |
+
].value_counts()
|
| 145 |
+
)
|
| 146 |
+
print(
|
| 147 |
+
forward_val_data_imputed[forward_val_data_imputed["ExacWithin3Months"] == 1][
|
| 148 |
+
"CommExacWithin3Months"
|
| 149 |
+
].value_counts()
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# Separate features and target
|
| 153 |
+
forward_val_features_imp = forward_val_data_imputed.drop(
|
| 154 |
+
columns=["StudyId", "IndexDate", "ExacWithin3Months", 'HospExacWithin3Months',
|
| 155 |
+
'CommExacWithin3Months']
|
| 156 |
+
)
|
| 157 |
+
forward_val_target_imp = forward_val_data_imputed["ExacWithin3Months"]
|
| 158 |
+
forward_val_features_no_imp = forward_val_data_not_imputed.drop(
|
| 159 |
+
columns=["StudyId", "IndexDate", "ExacWithin3Months", 'HospExacWithin3Months',
|
| 160 |
+
'CommExacWithin3Months']
|
| 161 |
+
)
|
| 162 |
+
forward_val_target_no_imp = forward_val_data_not_imputed["ExacWithin3Months"]
|
| 163 |
+
|
| 164 |
+
# Check that the target in imputed and not imputed datasets are the same. If not,
|
| 165 |
+
# raise an error
|
| 166 |
+
if not forward_val_target_no_imp.equals(forward_val_target_imp):
|
| 167 |
+
raise ValueError(
|
| 168 |
+
"Target variable is not the same in imputed and non imputed datasets in the test set."
|
| 169 |
+
)
|
| 170 |
+
test_target = forward_val_target_no_imp
|
| 171 |
+
|
| 172 |
+
# Make sure all features are numeric
|
| 173 |
+
for features in [forward_val_features_imp, forward_val_features_no_imp]:
|
| 174 |
+
for col in features:
|
| 175 |
+
features[col] = pd.to_numeric(features[col], errors="coerce")
|
| 176 |
+
|
| 177 |
+
# Make a list of models to perform forward validation on. Contains model name, whether
|
| 178 |
+
# imputation was performed, and the threshold used in the original model
|
| 179 |
+
models = [
|
| 180 |
+
("balanced_random_forest", "imputed", 0.27),
|
| 181 |
+
# ("xgb", "not_imputed", 0.44),
|
| 182 |
+
# ("random_forest", "imputed", 0.30),
|
| 183 |
+
]
|
| 184 |
+
|
| 185 |
+
##############################################################
|
| 186 |
+
# Run models
|
| 187 |
+
##############################################################
|
| 188 |
+
mlflow.set_tracking_uri("sqlite:///mlruns.db")
|
| 189 |
+
mlflow.set_experiment("model_h_drop_1_hosp_comm")
|
| 190 |
+
|
| 191 |
+
with mlflow.start_run(run_name="sig_forward_val_models_10_2023"):
|
| 192 |
+
for model_info in models:
|
| 193 |
+
print(model_info[0])
|
| 194 |
+
with mlflow.start_run(run_name=model_info[0], nested=True):
|
| 195 |
+
# Create the artifacts directory if it doesn't exist
|
| 196 |
+
os.makedirs(config["outputs"]["artifact_dir"], exist_ok=True)
|
| 197 |
+
# Remove existing directory contents to not mix files between different runs
|
| 198 |
+
shutil.rmtree(config["outputs"]["artifact_dir"])
|
| 199 |
+
|
| 200 |
+
#### Load model ####
|
| 201 |
+
with open("./data/model/trained_iso_" + model_info[0] + "_pkl", "rb") as f:
|
| 202 |
+
model = pickle.load(f)
|
| 203 |
+
|
| 204 |
+
# Select the correct data based on model used
|
| 205 |
+
if model_info[1] == "imputed":
|
| 206 |
+
test_features = forward_val_features_imp
|
| 207 |
+
else:
|
| 208 |
+
test_features = forward_val_features_no_imp
|
| 209 |
+
|
| 210 |
+
#### Run model and get predictions for forward validation data ####
|
| 211 |
+
test_probs = model.predict_proba(test_features)[:, 1]
|
| 212 |
+
test_preds = model.predict(test_features)
|
| 213 |
+
|
| 214 |
+
#### Calculate metrics ####
|
| 215 |
+
metrics = model_h.calc_eval_metrics_for_model(
|
| 216 |
+
test_target,
|
| 217 |
+
test_preds,
|
| 218 |
+
test_probs,
|
| 219 |
+
"forward_val",
|
| 220 |
+
best_threshold=model_info[2],
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
#### Plot confusion matrix ####
|
| 224 |
+
model_h.plot_confusion_matrix(
|
| 225 |
+
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, model_info[2]],
|
| 226 |
+
test_probs,
|
| 227 |
+
test_target,
|
| 228 |
+
model_info[0],
|
| 229 |
+
model_type,
|
| 230 |
+
"forward_val",
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
#### Plot calibration curves ####
|
| 234 |
+
for bins in [6, 10]:
|
| 235 |
+
plt.figure(figsize=(8, 8))
|
| 236 |
+
plt.plot([0, 1], [0, 1], linestyle="--")
|
| 237 |
+
model_h.plot_calibration_curve(
|
| 238 |
+
test_target, test_probs, bins, "quantile", "Forward Validation"
|
| 239 |
+
)
|
| 240 |
+
plt.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
|
| 241 |
+
plt.title(model_info[0])
|
| 242 |
+
plt.tight_layout()
|
| 243 |
+
plt.savefig(
|
| 244 |
+
os.path.join(
|
| 245 |
+
config["outputs"]["artifact_dir"],
|
| 246 |
+
model_info[0]
|
| 247 |
+
+ "_"
|
| 248 |
+
+ "quantile"
|
| 249 |
+
+ "_bins"
|
| 250 |
+
+ str(bins)
|
| 251 |
+
+ model_type
|
| 252 |
+
+ ".png",
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
plt.close()
|
| 256 |
+
|
| 257 |
+
#### Calculate model performance by event type ####
|
| 258 |
+
# Create df to contain prediction data and event type data
|
| 259 |
+
preds_events_df_forward_val = model_h.create_df_probabilities_and_predictions(
|
| 260 |
+
test_probs,
|
| 261 |
+
model_info[2],
|
| 262 |
+
forward_val_data_imputed["StudyId"].tolist(),
|
| 263 |
+
test_target,
|
| 264 |
+
forward_val_data_imputed[["ExacWithin3Months", 'HospExacWithin3Months',
|
| 265 |
+
'CommExacWithin3Months']],
|
| 266 |
+
model_info[0],
|
| 267 |
+
model_type,
|
| 268 |
+
output_dir="./data/prediction_and_events/",
|
| 269 |
+
calib_type="forward_val",
|
| 270 |
+
)
|
| 271 |
+
# Subset to each event type and calculate metrics
|
| 272 |
+
metrics_by_event_type_forward_val = model_h.calc_metrics_by_event_type(
|
| 273 |
+
preds_events_df_forward_val, calib_type="forward_val"
|
| 274 |
+
)
|
| 275 |
+
# Subset to each event type and plot ROC curve
|
| 276 |
+
model_h.plot_roc_curve_by_event_type(
|
| 277 |
+
preds_events_df_forward_val, model_info[0], "forward_val"
|
| 278 |
+
)
|
| 279 |
+
# Subset to each event type and plot PR curve
|
| 280 |
+
model_h.plot_prec_recall_by_event_type(
|
| 281 |
+
preds_events_df_forward_val, model_info[0], "forward_val"
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
#### Plot distribution of model scores for uncalibrated model ####
|
| 285 |
+
model_h.plot_score_distribution(
|
| 286 |
+
test_target,
|
| 287 |
+
test_probs,
|
| 288 |
+
config["outputs"]["artifact_dir"],
|
| 289 |
+
model_info[0],
|
| 290 |
+
model_type,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
#### Log to MLFlow ####
|
| 294 |
+
mlflow.log_metrics(metrics)
|
| 295 |
+
mlflow.log_artifacts(config["outputs"]["artifact_dir"])
|
| 296 |
+
mlflow.end_run()
|
training/perform_hyper_param_tuning.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import mlflow
|
| 5 |
+
import shutil
|
| 6 |
+
import model_h
|
| 7 |
+
|
| 8 |
+
# Model training and evaluation
|
| 9 |
+
from sklearn.linear_model import LogisticRegression
|
| 10 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 11 |
+
from imblearn.ensemble import BalancedRandomForestClassifier
|
| 12 |
+
import xgboost as xgb
|
| 13 |
+
from skopt import BayesSearchCV
|
| 14 |
+
from skopt.space import Real, Integer
|
| 15 |
+
|
| 16 |
+
##############################################################
|
| 17 |
+
# Specify which model to perform cross validation on
|
| 18 |
+
##############################################################
|
| 19 |
+
model_only_hosp = False
|
| 20 |
+
if model_only_hosp is True:
|
| 21 |
+
file_suffix = "_only_hosp"
|
| 22 |
+
else:
|
| 23 |
+
file_suffix = "_hosp_comm"
|
| 24 |
+
|
| 25 |
+
##############################################################
|
| 26 |
+
# Load data
|
| 27 |
+
##############################################################
|
| 28 |
+
# Load CV folds
|
| 29 |
+
fold_patients = np.load(
|
| 30 |
+
'./data/cohort_info/fold_patients' + file_suffix + '.npy', allow_pickle=True)
|
| 31 |
+
|
| 32 |
+
# Load imputed train data
|
| 33 |
+
train_data_imp = model_h.load_data_for_modelling(
|
| 34 |
+
'./data/model_data/train_data_cv_imp' + file_suffix + '.pkl')
|
| 35 |
+
|
| 36 |
+
# Load not imputed train data
|
| 37 |
+
train_data_no_imp = model_h.load_data_for_modelling(
|
| 38 |
+
'./data/model_data/train_data_cv_no_imp' + file_suffix + '.pkl')
|
| 39 |
+
|
| 40 |
+
# Load imputed test data
|
| 41 |
+
test_data_imp = model_h.load_data_for_modelling(
|
| 42 |
+
'./data/model_data/test_data_imp' + file_suffix + '.pkl')
|
| 43 |
+
# Load not imputed test data
|
| 44 |
+
test_data_no_imp = model_h.load_data_for_modelling(
|
| 45 |
+
'./data/model_data/test_data_no_imp' + file_suffix + '.pkl')
|
| 46 |
+
|
| 47 |
+
# Create a tuple with training and validation indicies for each fold. Can be done with
|
| 48 |
+
# either imputed or not imputed data as both have same patients
|
| 49 |
+
cross_val_fold_indices = []
|
| 50 |
+
for fold in fold_patients:
|
| 51 |
+
fold_val_ids = train_data_no_imp[train_data_no_imp.StudyId.isin(fold)]
|
| 52 |
+
fold_train_ids = train_data_no_imp[~(
|
| 53 |
+
train_data_no_imp.StudyId.isin(fold_val_ids.StudyId))]
|
| 54 |
+
|
| 55 |
+
# Get index of rows in val and train
|
| 56 |
+
fold_val_index = fold_val_ids.index
|
| 57 |
+
fold_train_index = fold_train_ids.index
|
| 58 |
+
|
| 59 |
+
# Append tuple of training and val indices
|
| 60 |
+
cross_val_fold_indices.append((fold_train_index, fold_val_index))
|
| 61 |
+
|
| 62 |
+
# Create list of model features
|
| 63 |
+
cols_to_drop = ['StudyId', 'ExacWithin3Months', 'IndexDate']
|
| 64 |
+
features_list = [col for col in train_data_no_imp.columns if col not in cols_to_drop]
|
| 65 |
+
|
| 66 |
+
# Train data
|
| 67 |
+
# Separate features from target for data with no imputation performed
|
| 68 |
+
train_features_no_imp = train_data_no_imp[features_list].astype('float')
|
| 69 |
+
train_target_no_imp = train_data_no_imp.ExacWithin3Months.astype('float')
|
| 70 |
+
# Separate features from target for data with no imputation performed
|
| 71 |
+
train_features_imp = train_data_imp[features_list].astype('float')
|
| 72 |
+
train_target_imp = train_data_imp.ExacWithin3Months.astype('float')
|
| 73 |
+
|
| 74 |
+
# Test data
|
| 75 |
+
# Separate features from target for data with no imputation performed
|
| 76 |
+
test_features_no_imp = test_data_no_imp[features_list].astype('float')
|
| 77 |
+
test_target_no_imp = test_data_no_imp.ExacWithin3Months.astype('float')
|
| 78 |
+
# Separate features from target for data with no imputation performed
|
| 79 |
+
test_features_imp = test_data_imp[features_list].astype('float')
|
| 80 |
+
test_target_imp = test_data_imp.ExacWithin3Months.astype('float')
|
| 81 |
+
|
| 82 |
+
# Check that the target in imputed and not imputed datasets are the same. If not,
|
| 83 |
+
# raise an error
|
| 84 |
+
if not train_target_no_imp.equals(train_target_imp):
|
| 85 |
+
raise ValueError(
|
| 86 |
+
'Target variable is not the same in imputed and non imputed datasets in the train set.')
|
| 87 |
+
if not test_target_no_imp.equals(test_target_imp):
|
| 88 |
+
raise ValueError(
|
| 89 |
+
'Target variable is not the same in imputed and non imputed datasets in the test set.')
|
| 90 |
+
train_target = train_target_no_imp
|
| 91 |
+
test_target = test_target_no_imp
|
| 92 |
+
|
| 93 |
+
# Make sure all features are numeric
|
| 94 |
+
for features in [train_features_no_imp, train_features_imp,
|
| 95 |
+
test_features_no_imp, test_features_imp]:
|
| 96 |
+
for col in features:
|
| 97 |
+
features[col] = pd.to_numeric(features[col], errors='coerce')
|
| 98 |
+
|
| 99 |
+
##############################################################
|
| 100 |
+
# Specify which models to evaluate
|
| 101 |
+
##############################################################
|
| 102 |
+
# Set up MLflow
|
| 103 |
+
mlflow.set_tracking_uri("sqlite:///mlruns.db")
|
| 104 |
+
mlflow.set_experiment('model_h_drop_1' + file_suffix)
|
| 105 |
+
|
| 106 |
+
# Set CV scoring strategies and any model parameters
|
| 107 |
+
scoring_methods = ['average_precision']
|
| 108 |
+
scale_pos_weight = train_target.value_counts()[0] / train_target.value_counts()[1]
|
| 109 |
+
|
| 110 |
+
# Set up models, each tuple contains 4 elements: model, model name, imputation status,
|
| 111 |
+
# type of model
|
| 112 |
+
models = []
|
| 113 |
+
# Run different models depending on which parallel model is being used.
|
| 114 |
+
if model_only_hosp is True:
|
| 115 |
+
# Logistic regression
|
| 116 |
+
models.append((LogisticRegression(),
|
| 117 |
+
'logistic_regression', 'imputed', 'linear'))
|
| 118 |
+
# Balanced random forest
|
| 119 |
+
models.append((BalancedRandomForestClassifier(),
|
| 120 |
+
'balanced_random_forest', 'imputed', 'tree'))
|
| 121 |
+
# XGBoost
|
| 122 |
+
models.append((xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss'),
|
| 123 |
+
'xgb', 'not_imputed', 'tree'))
|
| 124 |
+
if model_only_hosp is False:
|
| 125 |
+
# Logistic regression
|
| 126 |
+
models.append((LogisticRegression(),
|
| 127 |
+
'logistic_regression', 'imputed', 'linear'))
|
| 128 |
+
# Random forest
|
| 129 |
+
models.append((RandomForestClassifier(),
|
| 130 |
+
'random_forest', 'imputed', 'tree'))
|
| 131 |
+
# Balanced random forest
|
| 132 |
+
models.append((BalancedRandomForestClassifier(),
|
| 133 |
+
'balanced_random_forest', 'imputed', 'tree'))
|
| 134 |
+
# XGBoost
|
| 135 |
+
models.append((xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss'),
|
| 136 |
+
'xgb', 'not_imputed', 'tree'))
|
| 137 |
+
|
| 138 |
+
# Define search spaces
|
| 139 |
+
log_reg_search_spaces = {'penalty': ['l2', None],
|
| 140 |
+
'class_weight': ['balanced', None],
|
| 141 |
+
'max_iter': Integer(50, 300),
|
| 142 |
+
'C': Real(0.001, 10),
|
| 143 |
+
}
|
| 144 |
+
rf_search_spaces = {'max_depth': Integer(4, 10),
|
| 145 |
+
'n_estimators': Integer(70, 850),
|
| 146 |
+
'min_samples_split': Integer(2, 10),
|
| 147 |
+
'class_weight': ['balanced', None],
|
| 148 |
+
}
|
| 149 |
+
xgb_search_spaces = {'max_depth': Integer(4, 10),
|
| 150 |
+
'n_estimators': Integer(70, 850),
|
| 151 |
+
'subsample': Real(0.55, 0.95),
|
| 152 |
+
'colsample_bytree': Real(0.55, 0.95),
|
| 153 |
+
'learning_rate': Real(0.05, 0.14),
|
| 154 |
+
'scale_pos_weight': Real(1, scale_pos_weight),
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
##############################################################
|
| 158 |
+
# Run models
|
| 159 |
+
##############################################################
|
| 160 |
+
#In MLflow run, perform K-fold cross validation and capture mean score across folds.
|
| 161 |
+
with mlflow.start_run(run_name='hyperparameter_tuning_2023_tot_length'):
|
| 162 |
+
for scoring_method in scoring_methods:
|
| 163 |
+
for model in models:
|
| 164 |
+
with mlflow.start_run(run_name=model[1], nested=True):
|
| 165 |
+
print(model[1])
|
| 166 |
+
# Create the artifacts directory if it doesn't exist
|
| 167 |
+
artifact_dir = './tmp'
|
| 168 |
+
os.makedirs(artifact_dir, exist_ok=True)
|
| 169 |
+
# Remove existing directory contents to not mix files between different runs
|
| 170 |
+
shutil.rmtree(artifact_dir)
|
| 171 |
+
|
| 172 |
+
# Run hyperparameter tuning
|
| 173 |
+
if (model[1] == 'balanced_random_forest') | (model[1] == 'random_forest'):
|
| 174 |
+
opt = BayesSearchCV(model[0],
|
| 175 |
+
search_spaces= rf_search_spaces,
|
| 176 |
+
n_iter=200,
|
| 177 |
+
random_state=0,
|
| 178 |
+
cv=cross_val_fold_indices,
|
| 179 |
+
scoring=scoring_method,
|
| 180 |
+
)
|
| 181 |
+
# Execute bayesian optimization
|
| 182 |
+
np.int = int
|
| 183 |
+
opt.fit(train_features_imp, train_target)
|
| 184 |
+
|
| 185 |
+
if model[1] == 'logistic_regression':
|
| 186 |
+
opt = BayesSearchCV(model[0],
|
| 187 |
+
search_spaces= log_reg_search_spaces,
|
| 188 |
+
n_iter=200,
|
| 189 |
+
random_state=0,
|
| 190 |
+
cv=cross_val_fold_indices,
|
| 191 |
+
scoring=scoring_method,
|
| 192 |
+
)
|
| 193 |
+
np.int = int
|
| 194 |
+
opt.fit(train_features_imp, train_target)
|
| 195 |
+
|
| 196 |
+
if model[1] == 'xgb':
|
| 197 |
+
opt = BayesSearchCV(model[0],
|
| 198 |
+
search_spaces= xgb_search_spaces,
|
| 199 |
+
n_iter=200,
|
| 200 |
+
random_state=0,
|
| 201 |
+
cv=cross_val_fold_indices,
|
| 202 |
+
scoring=scoring_method,
|
| 203 |
+
)
|
| 204 |
+
np.int = int
|
| 205 |
+
opt.fit(train_features_no_imp, train_target)
|
| 206 |
+
|
| 207 |
+
# Get scores from hyperparameter tuning
|
| 208 |
+
print(opt.best_params_)
|
| 209 |
+
print(opt.best_score_)
|
| 210 |
+
|
| 211 |
+
# Log scores from hyperparameter tuning
|
| 212 |
+
mlflow.log_param('opt_scorer', scoring_method)
|
| 213 |
+
mlflow.log_params(opt.best_params_)
|
| 214 |
+
mlflow.log_metric("opt_best_score", opt.best_score_)
|
| 215 |
+
mlflow.end_run()
|
training/process_comorbidities.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Derive features from comorbidities dataset for 2 models:
|
| 3 |
+
Parallel model 1: uses both hospital and community exacerbation events
|
| 4 |
+
Parallel model 2: uses only hospital exacerbation events
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import yaml
|
| 12 |
+
import model_h
|
| 13 |
+
|
| 14 |
+
with open("./training/config.yaml", "r") as config:
|
| 15 |
+
config = yaml.safe_load(config)
|
| 16 |
+
|
| 17 |
+
# Specify which model to generate features for
|
| 18 |
+
model_type = config["model_settings"]["model_type"]
|
| 19 |
+
|
| 20 |
+
# Setup log file
|
| 21 |
+
log = open("./training/logging/process_comorbidities_" + model_type + ".log", "w")
|
| 22 |
+
sys.stdout = log
|
| 23 |
+
|
| 24 |
+
# Dataset to process - set through config file
|
| 25 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 26 |
+
|
| 27 |
+
# Load cohort data
|
| 28 |
+
if data_to_process == "forward_val":
|
| 29 |
+
exac_data = pd.read_pickle("./data/patient_labels_forward_val_hosp_comm.pkl")
|
| 30 |
+
patient_details = pd.read_pickle("./data/patient_details_forward_val.pkl")
|
| 31 |
+
else:
|
| 32 |
+
exac_data = pd.read_pickle("./data/patient_labels_" + model_type + ".pkl")
|
| 33 |
+
patient_details = pd.read_pickle("./data/patient_details.pkl")
|
| 34 |
+
exac_data = exac_data[["StudyId", "IndexDate"]]
|
| 35 |
+
patient_details = exac_data.merge(
|
| 36 |
+
patient_details[["StudyId", "PatientId"]],
|
| 37 |
+
on="StudyId",
|
| 38 |
+
how="left",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
comorbidities = pd.read_csv(
|
| 42 |
+
config["inputs"]["raw_data_paths"]["comorbidities"], delimiter="|"
|
| 43 |
+
)
|
| 44 |
+
comorbidities = patient_details.merge(comorbidities, on="PatientId", how="left")
|
| 45 |
+
|
| 46 |
+
# Only keep records submitted before index date
|
| 47 |
+
comorbidities["Created"] = pd.to_datetime(comorbidities["Created"], utc=True)
|
| 48 |
+
comorbidities["TimeSinceSubmission"] = (
|
| 49 |
+
comorbidities["IndexDate"] - comorbidities["Created"]
|
| 50 |
+
).dt.days
|
| 51 |
+
comorbidities = comorbidities[comorbidities["TimeSinceSubmission"] > 0]
|
| 52 |
+
|
| 53 |
+
# If multiple records submitted for same patient keep the most recent record (in relation
|
| 54 |
+
# to index date)
|
| 55 |
+
comorbidities = comorbidities.sort_values(
|
| 56 |
+
by=["StudyId", "IndexDate", "TimeSinceSubmission"]
|
| 57 |
+
)
|
| 58 |
+
comorbidities = comorbidities.drop_duplicates(
|
| 59 |
+
subset=["StudyId", "IndexDate"], keep="first"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Get list of comorbidities captured in the service
|
| 63 |
+
comorbidity_list = list(comorbidities)
|
| 64 |
+
comorbidity_list = [
|
| 65 |
+
e
|
| 66 |
+
for e in comorbidity_list
|
| 67 |
+
if e
|
| 68 |
+
not in ("PatientId", "Id", "StudyId", "IndexDate", "TimeSinceSubmission", "Created")
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
# Map True/False values to integers
|
| 72 |
+
bool_mapping = {True: 1, False: 0}
|
| 73 |
+
comorbidities[comorbidity_list] = (
|
| 74 |
+
comorbidities[comorbidity_list].replace(bool_mapping).fillna(0)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Get comorbidity counts for each patient
|
| 78 |
+
comorbidities["Comorbidities"] = comorbidities[comorbidity_list].sum(axis=1)
|
| 79 |
+
|
| 80 |
+
# Drop comorbidities columns from train data but retain AsthmaOverlap
|
| 81 |
+
comorbidity_list.remove("AsthmaOverlap")
|
| 82 |
+
comorbidities = comorbidities.drop(columns=comorbidity_list)
|
| 83 |
+
comorbidities = comorbidities.drop(columns=["Id", "Created", "TimeSinceSubmission"])
|
| 84 |
+
|
| 85 |
+
# Bin number of comorbidities
|
| 86 |
+
comorb_bins = [0, 1, 3, np.inf]
|
| 87 |
+
comorb_labels = ["No comorbidities", "1-2", "3+"]
|
| 88 |
+
comorbidities["Comorbidities"] = model_h.bin_numeric_column(
|
| 89 |
+
col=comorbidities["Comorbidities"], bins=comorb_bins, labels=comorb_labels
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
comorbidities = comorbidities.drop(columns=["PatientId"])
|
| 93 |
+
|
| 94 |
+
# Save data
|
| 95 |
+
os.makedirs(config["outputs"]["processed_data_dir"], exist_ok=True)
|
| 96 |
+
if data_to_process == "forward_val":
|
| 97 |
+
comorbidities.to_pickle(
|
| 98 |
+
os.path.join(
|
| 99 |
+
config["outputs"]["processed_data_dir"],
|
| 100 |
+
"comorbidities_forward_val_" + model_type + ".pkl",
|
| 101 |
+
)
|
| 102 |
+
)
|
| 103 |
+
else:
|
| 104 |
+
comorbidities.to_pickle(
|
| 105 |
+
os.path.join(
|
| 106 |
+
config["outputs"]["processed_data_dir"],
|
| 107 |
+
"comorbidities_" + model_type + ".pkl",
|
| 108 |
+
)
|
| 109 |
+
)
|
training/process_demographics.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Derive features from demographics for 2 models:
|
| 3 |
+
Parallel model 1: uses both hospital and community exacerbation events
|
| 4 |
+
Parallel model 2: uses only hospital exacerbation events
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import model_h
|
| 12 |
+
import yaml
|
| 13 |
+
|
| 14 |
+
with open("./training/config.yaml", "r") as config:
|
| 15 |
+
config = yaml.safe_load(config)
|
| 16 |
+
|
| 17 |
+
# Specify which model to generate features for
|
| 18 |
+
model_type = config["model_settings"]["model_type"]
|
| 19 |
+
|
| 20 |
+
# Setup log file
|
| 21 |
+
log = open("./training/logging/process_demographics_" + model_type + ".log", "w")
|
| 22 |
+
sys.stdout = log
|
| 23 |
+
|
| 24 |
+
# Dataset to process - set through config file
|
| 25 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 26 |
+
|
| 27 |
+
# Load cohort data
|
| 28 |
+
if data_to_process == "forward_val":
|
| 29 |
+
data = pd.read_pickle("./data/patient_labels_forward_val_hosp_comm.pkl")
|
| 30 |
+
patient_details = pd.read_pickle("./data/patient_details_forward_val.pkl")
|
| 31 |
+
else:
|
| 32 |
+
data = pd.read_pickle("./data/patient_labels_" + model_type + ".pkl")
|
| 33 |
+
patient_details = pd.read_pickle("./data/patient_details.pkl")
|
| 34 |
+
data = data.merge(
|
| 35 |
+
patient_details[["StudyId"]],
|
| 36 |
+
on="StudyId",
|
| 37 |
+
how="left",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Calculate age
|
| 41 |
+
data["DateOfBirth"] = pd.to_datetime(data["DateOfBirth"], utc=True)
|
| 42 |
+
data["Age"] = (data["IndexDate"] - data["DateOfBirth"]).dt.days
|
| 43 |
+
data["Age"] = np.floor(data["Age"] / 365)
|
| 44 |
+
data = data.drop(columns="DateOfBirth")
|
| 45 |
+
|
| 46 |
+
# Bin patient age
|
| 47 |
+
age_bins = [0, 50, 60, 70, 80, np.inf]
|
| 48 |
+
age_labels = ["<50", "50-59", "60-69", "70-79", "80+"]
|
| 49 |
+
data["AgeBinned"] = model_h.bin_numeric_column(
|
| 50 |
+
col=data["Age"], bins=age_bins, labels=age_labels
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Smoking status: TODO
|
| 54 |
+
|
| 55 |
+
# Map the M and F sex column to binary (1=F)
|
| 56 |
+
sex_mapping = {"F": 1, "M": 0}
|
| 57 |
+
data["Sex_F"] = data.Sex.map(sex_mapping)
|
| 58 |
+
data = data.drop(columns=["Sex"])
|
| 59 |
+
|
| 60 |
+
# Save data
|
| 61 |
+
os.makedirs(config["outputs"]["processed_data_dir"], exist_ok=True)
|
| 62 |
+
if data_to_process == "forward_val":
|
| 63 |
+
data.to_pickle(
|
| 64 |
+
os.path.join(
|
| 65 |
+
config["outputs"]["processed_data_dir"],
|
| 66 |
+
"demographics_forward_val_" + model_type + ".pkl",
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
else:
|
| 70 |
+
data.to_pickle(
|
| 71 |
+
os.path.join(
|
| 72 |
+
config["outputs"]["processed_data_dir"],
|
| 73 |
+
"demographics_" + model_type + ".pkl",
|
| 74 |
+
)
|
| 75 |
+
)
|
training/process_exacerbation_history.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Derive features from exacerbation event history for 2 models:
|
| 3 |
+
Parallel model 1: uses both hospital and community exacerbation events
|
| 4 |
+
Parallel model 2: uses only hospital exacerbation events
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import sys
|
| 10 |
+
import re
|
| 11 |
+
import model_h
|
| 12 |
+
import os
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
with open("./training/config.yaml", "r") as config:
|
| 16 |
+
config = yaml.safe_load(config)
|
| 17 |
+
|
| 18 |
+
# Specify which model to generate features for
|
| 19 |
+
model_type = config["model_settings"]["model_type"]
|
| 20 |
+
if model_type == "only_hosp":
|
| 21 |
+
cols_required = ["IsHospExac", "IsHospAdmission"]
|
| 22 |
+
pharmacy_prescriptions_req = False
|
| 23 |
+
if model_type == "hosp_comm":
|
| 24 |
+
cols_required = ["IsExac", "IsHospExac", "IsCommExac", "IsHospAdmission"]
|
| 25 |
+
pharmacy_prescriptions_req = True
|
| 26 |
+
|
| 27 |
+
# Setup log file
|
| 28 |
+
log = open("./training/logging/process_exacerbation_history_" + model_type + ".log", "w")
|
| 29 |
+
sys.stdout = log
|
| 30 |
+
|
| 31 |
+
# Dataset to process - set through config file
|
| 32 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 33 |
+
|
| 34 |
+
# Load cohort data
|
| 35 |
+
if data_to_process == "forward_val":
|
| 36 |
+
data = pd.read_pickle("./data/patient_labels_forward_val_hosp_comm.pkl")
|
| 37 |
+
patient_details = pd.read_pickle("./data/patient_details_forward_val.pkl")
|
| 38 |
+
else:
|
| 39 |
+
data = pd.read_pickle("./data/patient_labels_" + model_type + ".pkl")
|
| 40 |
+
patient_details = pd.read_pickle("./data/patient_details.pkl")
|
| 41 |
+
data = data[["StudyId", "IndexDate"]]
|
| 42 |
+
data = data.merge(
|
| 43 |
+
patient_details[["StudyId", "PatientId", "FirstSubmissionDate"]],
|
| 44 |
+
on="StudyId",
|
| 45 |
+
how="left",
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Read mapping between StudyId and SafeHavenID
|
| 49 |
+
id_mapping = pd.read_pickle("./data/sh_to_studyid_mapping.pkl")
|
| 50 |
+
|
| 51 |
+
# Remove mapping for patient SU125 as the mapping for this patient is incorrect
|
| 52 |
+
id_mapping["SafeHavenID"] = np.where(
|
| 53 |
+
id_mapping["StudyId"] == "SU125", np.NaN, id_mapping["SafeHavenID"]
|
| 54 |
+
)
|
| 55 |
+
id_mapping = id_mapping.merge(
|
| 56 |
+
data[["StudyId"]], on="StudyId", how="inner"
|
| 57 |
+
).drop_duplicates()
|
| 58 |
+
print(
|
| 59 |
+
"Num patients with SafeHaven mapping: {} of {}".format(
|
| 60 |
+
len(id_mapping), data.StudyId.nunique()
|
| 61 |
+
)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Add column with SafeHavenID to main df
|
| 65 |
+
data = data.merge(id_mapping, on="StudyId", how="left")
|
| 66 |
+
|
| 67 |
+
# Calculate the lookback start date. Will need this to aggreggate data for model
|
| 68 |
+
# features
|
| 69 |
+
data["LookbackStartDate"] = data["IndexDate"] - pd.DateOffset(
|
| 70 |
+
days=config["model_settings"]["lookback_period"]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
############################################################################
|
| 74 |
+
# Derive features from patient history
|
| 75 |
+
############################################################################
|
| 76 |
+
#########################################
|
| 77 |
+
# Num previous exacerbations/admissions
|
| 78 |
+
#########################################
|
| 79 |
+
exacs = pd.read_pickle("./data/{}_exacs.pkl".format(model_type))
|
| 80 |
+
exacs = exacs.fillna(0)
|
| 81 |
+
print(exacs.columns)
|
| 82 |
+
print(data.columns)
|
| 83 |
+
exacs = data[["StudyId", "PatientId", "LookbackStartDate", "IndexDate"]].merge(
|
| 84 |
+
exacs, on=["StudyId", "PatientId"], how="left"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Calculate total number of exacerbations prior to index date and normalise by length in
|
| 88 |
+
# service
|
| 89 |
+
exac_counts_tot = exacs[exacs["DateOfEvent"] < exacs["IndexDate"]]
|
| 90 |
+
exac_counts_tot = exac_counts_tot.groupby(["StudyId", "IndexDate"])["IsExac"].sum()
|
| 91 |
+
exac_counts_tot = pd.DataFrame(exac_counts_tot).reset_index()
|
| 92 |
+
exac_counts_tot = exac_counts_tot.merge(
|
| 93 |
+
patient_details[["StudyId", "FirstSubmissionDate"]], on="StudyId", how="left"
|
| 94 |
+
)
|
| 95 |
+
exac_counts_tot["LengthInServiceBeforeIndex"] = (
|
| 96 |
+
exac_counts_tot["IndexDate"] - exac_counts_tot["FirstSubmissionDate"]
|
| 97 |
+
).dt.days
|
| 98 |
+
exac_counts_tot["LengthInServiceBeforeIndex"] = (
|
| 99 |
+
exac_counts_tot["LengthInServiceBeforeIndex"] / 30
|
| 100 |
+
)
|
| 101 |
+
exac_counts_tot["TotExacPerMonthBeforeIndex"] = (
|
| 102 |
+
exac_counts_tot["IsExac"] / exac_counts_tot["LengthInServiceBeforeIndex"]
|
| 103 |
+
)
|
| 104 |
+
exac_counts_tot = exac_counts_tot.drop(
|
| 105 |
+
columns=["IsExac", "FirstSubmissionDate", "LengthInServiceBeforeIndex"]
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Calculate number of previous exacerbations in 6 months before index date
|
| 109 |
+
exac_counts_6mo = exacs[
|
| 110 |
+
(exacs["DateOfEvent"] >= exacs["LookbackStartDate"])
|
| 111 |
+
& (exacs["DateOfEvent"] < exacs["IndexDate"])
|
| 112 |
+
]
|
| 113 |
+
exac_counts_6mo = exac_counts_6mo.groupby(["StudyId", "IndexDate"])[cols_required].sum()
|
| 114 |
+
|
| 115 |
+
# Remove 'Is' prefix and add 'Num' prefix and 'Prior6mo' suffix
|
| 116 |
+
new_col_names = []
|
| 117 |
+
for col in cols_required:
|
| 118 |
+
base_col_name = re.findall(r"[A-Z][^A-Z]*", col)
|
| 119 |
+
base_col_name.pop(0)
|
| 120 |
+
base_col_name = "".join(base_col_name)
|
| 121 |
+
new_col_names.append("Num" + base_col_name + "Prior6mo")
|
| 122 |
+
|
| 123 |
+
# Rename columns and merge to main df
|
| 124 |
+
exac_counts_6mo = exac_counts_6mo.rename(
|
| 125 |
+
columns=dict(zip(cols_required, new_col_names))
|
| 126 |
+
).reset_index()
|
| 127 |
+
data = data.merge(exac_counts_6mo, on=["StudyId", "IndexDate"], how="left")
|
| 128 |
+
data = data.merge(exac_counts_tot, on=["StudyId", "IndexDate"], how="left")
|
| 129 |
+
data = data.fillna(0)
|
| 130 |
+
|
| 131 |
+
#########################################
|
| 132 |
+
# Days since previous exacerbation
|
| 133 |
+
#########################################
|
| 134 |
+
# Calculate the number of days since last exacerbation before index date
|
| 135 |
+
days_since_exac = exacs[exacs["DateOfEvent"] < exacs["IndexDate"]]
|
| 136 |
+
days_since_exac = days_since_exac[days_since_exac[cols_required[0]] == 1]
|
| 137 |
+
days_since_exac = days_since_exac.sort_values(
|
| 138 |
+
by=["StudyId", "IndexDate", "DateOfEvent"], ascending=False
|
| 139 |
+
)
|
| 140 |
+
days_since_exac = days_since_exac.drop_duplicates(
|
| 141 |
+
subset=["StudyId", "IndexDate"], keep="first"
|
| 142 |
+
)
|
| 143 |
+
days_since_exac["DaysSinceLastExac"] = (
|
| 144 |
+
days_since_exac["IndexDate"] - days_since_exac["DateOfEvent"]
|
| 145 |
+
).dt.days
|
| 146 |
+
data = data.merge(
|
| 147 |
+
days_since_exac[["StudyId", "IndexDate", "DaysSinceLastExac"]],
|
| 148 |
+
on=["StudyId", "IndexDate"],
|
| 149 |
+
how="left",
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# If patient have missing values in DaysSinceLastExac, find exacerbations in SafeHaven
|
| 153 |
+
# data as service data is only manually entered up to one year before onboarding
|
| 154 |
+
missing_data_to_lookup = data[data["DaysSinceLastExac"].isna()].drop_duplicates(
|
| 155 |
+
subset="StudyId", keep="first"
|
| 156 |
+
)
|
| 157 |
+
missing_data_to_lookup = missing_data_to_lookup[
|
| 158 |
+
["StudyId", "SafeHavenID", "IndexDate", "FirstSubmissionDate"]
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
# For both parallel models 1 and 2 find hospital exacerbations in SMR data for patients
|
| 162 |
+
# who have missing data in the DaysSinceLastExac column
|
| 163 |
+
smr = pd.read_csv(config["inputs"]["raw_data_paths"]["admissions"])
|
| 164 |
+
smr = missing_data_to_lookup.merge(smr, on="SafeHavenID", how="left")
|
| 165 |
+
|
| 166 |
+
# Only find exacerbations prior to onboarding
|
| 167 |
+
smr = smr.rename(columns={"ADMDATE": "DateOfEvent"})
|
| 168 |
+
smr["DateOfEvent"] = pd.to_datetime(smr["DateOfEvent"], utc=True)
|
| 169 |
+
smr = smr[smr["DateOfEvent"] < smr["FirstSubmissionDate"]]
|
| 170 |
+
|
| 171 |
+
# COPD admission defined as: admission with J40-J44 under principal diagnosis or J20 in
|
| 172 |
+
# the principal diagnosis with one of J41-J44 in secondary diagnosis field
|
| 173 |
+
principal_diag = [
|
| 174 |
+
"Bronchitis, not specified as acute or chronic",
|
| 175 |
+
"chronic bronchitis",
|
| 176 |
+
"Emphysema",
|
| 177 |
+
"MacLeod syndrome",
|
| 178 |
+
"chronic obstructive pulmonary disease",
|
| 179 |
+
]
|
| 180 |
+
principal_diag_alt = ["Acute bronchitis"]
|
| 181 |
+
secondary_diag_alt = [
|
| 182 |
+
"chronic bronchitis",
|
| 183 |
+
"Emphysema",
|
| 184 |
+
"MacLeod syndrome",
|
| 185 |
+
"chronic obstructive pulmonary disease",
|
| 186 |
+
]
|
| 187 |
+
condition_primary = smr["DIAG1Desc"].str.contains(
|
| 188 |
+
r"\b(?:" + "|".join(principal_diag) + r")\b", case=False, regex=True
|
| 189 |
+
)
|
| 190 |
+
condition_secondary = (
|
| 191 |
+
smr["DIAG1Desc"].str.contains(
|
| 192 |
+
r"\b(?:" + "|".join(principal_diag_alt) + r")\b", case=False, regex=True
|
| 193 |
+
)
|
| 194 |
+
) & (
|
| 195 |
+
smr["DIAG2Desc"].str.contains(
|
| 196 |
+
r"\b(?:" + "|".join(secondary_diag_alt) + r")\b", case=False, regex=True
|
| 197 |
+
)
|
| 198 |
+
)
|
| 199 |
+
smr["COPD_admission_smr"] = np.where(condition_primary, True, False)
|
| 200 |
+
smr["COPD_admission_smr"] = np.where(
|
| 201 |
+
condition_secondary, True, smr["COPD_admission_smr"]
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
smr = smr[smr["COPD_admission_smr"] == True]
|
| 205 |
+
|
| 206 |
+
# Find rescue med prescriptions prior to onboarding for parallel model 1 (where both
|
| 207 |
+
# hospital and community exacerbations are used)
|
| 208 |
+
if pharmacy_prescriptions_req is True:
|
| 209 |
+
# Read pharmacy data and filter for model cohort
|
| 210 |
+
pharmacy = pd.read_csv(config["inputs"]["raw_data_paths"]["prescribing"])
|
| 211 |
+
pharmacy = missing_data_to_lookup.merge(pharmacy, on="SafeHavenID", how="left")
|
| 212 |
+
|
| 213 |
+
# Pull out rescue med prescriptions only
|
| 214 |
+
steroid_codes = [
|
| 215 |
+
"0603020T0AAACAC",
|
| 216 |
+
"0603020T0AABKBK",
|
| 217 |
+
"0603020T0AAAXAX",
|
| 218 |
+
"0603020T0AAAGAG",
|
| 219 |
+
"0603020T0AABHBH",
|
| 220 |
+
"0603020T0AAACAC",
|
| 221 |
+
"0603020T0AABKBK",
|
| 222 |
+
"0603020T0AABNBN",
|
| 223 |
+
"0603020T0AAAGAG",
|
| 224 |
+
"0603020T0AABHBH",
|
| 225 |
+
]
|
| 226 |
+
antibiotic_codes = [
|
| 227 |
+
"0501013B0AAAAAA",
|
| 228 |
+
"0501013B0AAABAB",
|
| 229 |
+
"0501030I0AAABAB",
|
| 230 |
+
"0501030I0AAAAAA",
|
| 231 |
+
"0501050B0AAAAAA",
|
| 232 |
+
"0501050B0AAADAD",
|
| 233 |
+
"0501013K0AAAJAJ",
|
| 234 |
+
]
|
| 235 |
+
rescue_med_bnf_codes = steroid_codes + antibiotic_codes
|
| 236 |
+
pharmacy = pharmacy[pharmacy.PI_BNF_Item_Code.isin(rescue_med_bnf_codes)]
|
| 237 |
+
|
| 238 |
+
# Only keep rescue meds before patient onboarding
|
| 239 |
+
pharmacy = pharmacy.rename(columns={"PRESC_DATE": "DateOfEvent"})
|
| 240 |
+
pharmacy["DateOfEvent"] = pd.to_datetime(
|
| 241 |
+
pharmacy["DateOfEvent"], utc=True
|
| 242 |
+
).dt.normalize()
|
| 243 |
+
pharmacy = pharmacy[pharmacy["DateOfEvent"] < pharmacy["FirstSubmissionDate"]]
|
| 244 |
+
|
| 245 |
+
# Combine pharmacy data with smr admissions data
|
| 246 |
+
smr = pd.concat([smr, pharmacy])
|
| 247 |
+
|
| 248 |
+
# Calculate the days since last exacerbation
|
| 249 |
+
smr["DaysSinceLastExac"] = (smr["IndexDate"] - smr["DateOfEvent"]).dt.days
|
| 250 |
+
smr = smr.sort_values(by=["StudyId", "IndexDate", "DaysSinceLastExac"], ascending=True)
|
| 251 |
+
smr = smr.drop_duplicates(subset=["StudyId", "IndexDate"], keep="first")
|
| 252 |
+
|
| 253 |
+
# Merge back to main df
|
| 254 |
+
data = data.merge(
|
| 255 |
+
smr[["StudyId", "IndexDate", "DaysSinceLastExac"]],
|
| 256 |
+
on=["StudyId", "IndexDate"],
|
| 257 |
+
how="left",
|
| 258 |
+
)
|
| 259 |
+
data["DaysSinceLastExac"] = np.where(
|
| 260 |
+
data["DaysSinceLastExac_x"].notnull(),
|
| 261 |
+
data["DaysSinceLastExac_x"],
|
| 262 |
+
data["DaysSinceLastExac_y"],
|
| 263 |
+
)
|
| 264 |
+
data = data.drop(columns=["DaysSinceLastExac_x", "DaysSinceLastExac_y"])
|
| 265 |
+
print(
|
| 266 |
+
"Number of patients with missing DaysSinceLastExac:{}".format(
|
| 267 |
+
data[data["DaysSinceLastExac"].isna()].StudyId.nunique()
|
| 268 |
+
)
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
# Bin days since last exacerbation
|
| 272 |
+
exac_bins = [0, 21, 90, 180, np.inf]
|
| 273 |
+
exac_labels = ["<21 days", "21 - 89 days", "90 - 179 days", ">= 180 days"]
|
| 274 |
+
data["DaysSinceLastExac"] = model_h.bin_numeric_column(
|
| 275 |
+
col=data["DaysSinceLastExac"], bins=exac_bins, labels=exac_labels
|
| 276 |
+
)
|
| 277 |
+
# If DaysSinceLastExac is nan, put into >= 180 category
|
| 278 |
+
data["DaysSinceLastExac"] = data["DaysSinceLastExac"].replace("nan", ">= 180 days")
|
| 279 |
+
|
| 280 |
+
data = data.drop(columns=["FirstSubmissionDate", "LookbackStartDate", "PatientId"])
|
| 281 |
+
|
| 282 |
+
# Save data
|
| 283 |
+
os.makedirs(config["outputs"]["processed_data_dir"], exist_ok=True)
|
| 284 |
+
if data_to_process == "forward_val":
|
| 285 |
+
data.to_pickle(
|
| 286 |
+
os.path.join(
|
| 287 |
+
config["outputs"]["processed_data_dir"],
|
| 288 |
+
"exac_history_forward_val_" + model_type + ".pkl",
|
| 289 |
+
)
|
| 290 |
+
)
|
| 291 |
+
else:
|
| 292 |
+
data.to_pickle(
|
| 293 |
+
os.path.join(
|
| 294 |
+
config["outputs"]["processed_data_dir"],
|
| 295 |
+
"exac_history_" + model_type + ".pkl",
|
| 296 |
+
)
|
| 297 |
+
)
|
training/process_labs.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Derive features from lab tests for 2 models:
|
| 3 |
+
Parallel model 1: uses both hospital and community exacerbation events
|
| 4 |
+
Parallel model 2: uses only hospital exacerbation events
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import model_h
|
| 12 |
+
import ggc.preprocessing.labs as labs_preprocessing
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def calc_lab_metric(lab_df, data, lab_name, metric, weigh_data_by_recency=False):
|
| 17 |
+
"""
|
| 18 |
+
Calculate metrics on laboratory data.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
lab_df (pd.DataFrame): dataframe containing labs to be used in calculations.
|
| 22 |
+
data (pd.DataFrame): main dataframe to which columns containing the results from
|
| 23 |
+
the lab calculations are merged onto.
|
| 24 |
+
lab_name (list): name of labs required for metric calculations.
|
| 25 |
+
metric (str): name of metric to be calculated. The possible metrics are:
|
| 26 |
+
'MaxLifetime': calculates the maximum value of lab for patient within
|
| 27 |
+
entire dataset before their index date.
|
| 28 |
+
'MinLifetime': calculates the minimum value of lab for patient within
|
| 29 |
+
entire dataset before their index date.
|
| 30 |
+
'Max1Year': calculates the maximum value of lab for patient within 1
|
| 31 |
+
year prior to index date.
|
| 32 |
+
'Min1Year': calculates the maximum value of lab for patient within 1
|
| 33 |
+
year prior to index date.
|
| 34 |
+
'Latest': finds the closest lab value prior to index date.
|
| 35 |
+
weigh_data_by_recency (bool): option to weigh data based on how recent it is. Older
|
| 36 |
+
observations are decreased or increased towards the median. Defaults to False.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
pd.DataFrame: the input dataframe with additional columns with calculated
|
| 40 |
+
metrics.
|
| 41 |
+
"""
|
| 42 |
+
# Subset labs to only those specified in lab_names
|
| 43 |
+
cols_to_keep = ["StudyId", "IndexDate", "TimeSinceLab"]
|
| 44 |
+
cols_to_keep.append(lab_name)
|
| 45 |
+
labs_calc = lab_df[cols_to_keep]
|
| 46 |
+
|
| 47 |
+
# Subset labs to correct time frames and calculate metrics
|
| 48 |
+
if (metric == "Max1Year") | (metric == "Min1Year"):
|
| 49 |
+
labs_calc = labs_calc[labs_calc["TimeSinceLab"] <= 365]
|
| 50 |
+
if (metric == "MaxLifetime") | (metric == "Max1Year"):
|
| 51 |
+
labs_calc = labs_calc.groupby(["StudyId", "IndexDate"]).max()
|
| 52 |
+
if (metric == "MinLifetime") | (metric == "Min1Year"):
|
| 53 |
+
labs_calc = labs_calc.groupby(["StudyId", "IndexDate"]).min()
|
| 54 |
+
labs_calc = labs_calc.drop(columns=["TimeSinceLab"])
|
| 55 |
+
if metric == "Latest":
|
| 56 |
+
labs_calc = labs_calc[labs_calc["TimeSinceLab"] <= 365]
|
| 57 |
+
labs_calc = labs_calc.sort_values(
|
| 58 |
+
by=["StudyId", "IndexDate", "TimeSinceLab"], ascending=True
|
| 59 |
+
)
|
| 60 |
+
labs_calc["TimeSinceLab"] = np.where(
|
| 61 |
+
labs_calc[lab_name].isna(), np.NaN, labs_calc["TimeSinceLab"]
|
| 62 |
+
)
|
| 63 |
+
labs_calc = labs_calc.bfill()
|
| 64 |
+
labs_calc = labs_calc.drop_duplicates(
|
| 65 |
+
subset=["StudyId", "IndexDate"], keep="first"
|
| 66 |
+
)
|
| 67 |
+
if weigh_data_by_recency is True:
|
| 68 |
+
median_val = labs_calc[lab_name].median()
|
| 69 |
+
labs_calc = model_h.weigh_features_by_recency(
|
| 70 |
+
df=labs_calc,
|
| 71 |
+
feature=lab_name,
|
| 72 |
+
feature_recency_days="TimeSinceLab",
|
| 73 |
+
median_value=median_val,
|
| 74 |
+
decay_rate=0.001,
|
| 75 |
+
)
|
| 76 |
+
labs_calc = labs_calc.set_index(["StudyId", "IndexDate"])
|
| 77 |
+
|
| 78 |
+
# Add prefix to lab names and merge with main df
|
| 79 |
+
labs_calc = labs_calc.add_prefix(metric)
|
| 80 |
+
labs_calc = labs_calc.reset_index()
|
| 81 |
+
data = data.merge(labs_calc, on=["StudyId", "IndexDate"], how="left")
|
| 82 |
+
return data
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
with open("./training/config.yaml", "r") as config:
|
| 86 |
+
config = yaml.safe_load(config)
|
| 87 |
+
|
| 88 |
+
# Specify which model to generate features for
|
| 89 |
+
model_type = config["model_settings"]["model_type"]
|
| 90 |
+
|
| 91 |
+
# Setup log file
|
| 92 |
+
log = open("./training/logging/process_labs_" + model_type + ".log", "w")
|
| 93 |
+
sys.stdout = log
|
| 94 |
+
|
| 95 |
+
# Dataset to process - set through config file
|
| 96 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 97 |
+
|
| 98 |
+
# Load cohort data
|
| 99 |
+
if data_to_process == "forward_val":
|
| 100 |
+
data = pd.read_pickle("./data/patient_labels_forward_val_hosp_comm.pkl")
|
| 101 |
+
patient_details = pd.read_pickle("./data/patient_details_forward_val.pkl")
|
| 102 |
+
else:
|
| 103 |
+
data = pd.read_pickle("./data/patient_labels_" + model_type + ".pkl")
|
| 104 |
+
patient_details = pd.read_pickle("./data/patient_details.pkl")
|
| 105 |
+
data = data[["StudyId", "IndexDate"]]
|
| 106 |
+
patient_details = data.merge(
|
| 107 |
+
patient_details[["StudyId", "PatientId"]],
|
| 108 |
+
on="StudyId",
|
| 109 |
+
how="left",
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Read mapping between StudyId and SafeHavenID
|
| 113 |
+
id_mapping = pd.read_pickle("./data/sh_to_studyid_mapping.pkl")
|
| 114 |
+
|
| 115 |
+
# Remove mapping for patient SU125 as the mapping for this patient is incorrect
|
| 116 |
+
id_mapping["SafeHavenID"] = np.where(
|
| 117 |
+
id_mapping["StudyId"] == "SU125", np.NaN, id_mapping["SafeHavenID"]
|
| 118 |
+
)
|
| 119 |
+
id_mapping = id_mapping.merge(
|
| 120 |
+
data[["StudyId"]], on="StudyId", how="inner"
|
| 121 |
+
).drop_duplicates()
|
| 122 |
+
print(
|
| 123 |
+
"Num patients with SafeHaven mapping: {} of {}".format(
|
| 124 |
+
len(id_mapping), data.StudyId.nunique()
|
| 125 |
+
)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Add column with SafeHavenID to main df
|
| 129 |
+
patient_details = patient_details.merge(id_mapping, on="StudyId", how="left")
|
| 130 |
+
|
| 131 |
+
# Calculate the lookback start date. Will need this to aggreggate data for model
|
| 132 |
+
# features
|
| 133 |
+
patient_details["LookbackStartDate"] = patient_details["IndexDate"] - pd.DateOffset(
|
| 134 |
+
days=config["model_settings"]["lookback_period"]
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
############################################################################
|
| 138 |
+
# Derive features from labs
|
| 139 |
+
############################################################################
|
| 140 |
+
# Convert column names into format required for labs processing using the ggc package
|
| 141 |
+
cols_to_use = [
|
| 142 |
+
"SafeHavenID",
|
| 143 |
+
"ClinicalCodeDescription",
|
| 144 |
+
"QuantityUnit",
|
| 145 |
+
"RangeHighValue",
|
| 146 |
+
"RangeLowValue",
|
| 147 |
+
"QuantityValue",
|
| 148 |
+
"SampleDate",
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
labs = pd.read_csv(config["inputs"]["raw_data_paths"]["labs"], usecols=cols_to_use)
|
| 152 |
+
|
| 153 |
+
# Subset labs table to only patients of interest
|
| 154 |
+
labs = labs[labs.SafeHavenID.isin(patient_details.SafeHavenID)]
|
| 155 |
+
|
| 156 |
+
# Process labs
|
| 157 |
+
lookup_table = pd.read_csv(config["inputs"]["raw_data_paths"]["labs_lookup_table"])
|
| 158 |
+
tests_of_interest = [
|
| 159 |
+
"Eosinophils",
|
| 160 |
+
"Albumin",
|
| 161 |
+
"Neutrophils",
|
| 162 |
+
"White Blood Count",
|
| 163 |
+
"Lymphocytes",
|
| 164 |
+
]
|
| 165 |
+
labs_processed = labs_preprocessing.clean_labs_data(
|
| 166 |
+
df=labs,
|
| 167 |
+
tests_of_interest=tests_of_interest,
|
| 168 |
+
units_lookup=lookup_table,
|
| 169 |
+
print_log=True,
|
| 170 |
+
)
|
| 171 |
+
labs_processed = patient_details[["StudyId", "IndexDate", "SafeHavenID"]].merge(
|
| 172 |
+
labs_processed, on="SafeHavenID", how="left"
|
| 173 |
+
)
|
| 174 |
+
labs_processed["SampleDate"] = pd.to_datetime(labs_processed["SampleDate"], utc=True)
|
| 175 |
+
labs_processed["TimeSinceLab"] = (
|
| 176 |
+
labs_processed["IndexDate"] - labs_processed["SampleDate"]
|
| 177 |
+
).dt.days
|
| 178 |
+
|
| 179 |
+
# Only keep labs performed before IndexDate
|
| 180 |
+
labs_processed = labs_processed[labs_processed["TimeSinceLab"] >= 0]
|
| 181 |
+
|
| 182 |
+
# Convert lab names to columns
|
| 183 |
+
labs_processed = pd.pivot_table(
|
| 184 |
+
labs_processed,
|
| 185 |
+
values="QuantityValue",
|
| 186 |
+
index=["StudyId", "IndexDate", "TimeSinceLab"],
|
| 187 |
+
columns=["ClinicalCodeDescription"],
|
| 188 |
+
)
|
| 189 |
+
labs_processed = labs_processed.reset_index()
|
| 190 |
+
|
| 191 |
+
# Calculate neutrophil/lymphocyte ratio
|
| 192 |
+
labs_processed["NeutLymphRatio"] = (
|
| 193 |
+
labs_processed["Neutrophils"] / labs_processed["Lymphocytes"]
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Calculate lowest albumin in past year
|
| 197 |
+
data = calc_lab_metric(labs_processed, data, lab_name="Albumin", metric="Min1Year")
|
| 198 |
+
|
| 199 |
+
# Calculate the latest lab value
|
| 200 |
+
lab_names = [
|
| 201 |
+
"NeutLymphRatio",
|
| 202 |
+
"Albumin",
|
| 203 |
+
"Eosinophils",
|
| 204 |
+
"Neutrophils",
|
| 205 |
+
"White Blood Count",
|
| 206 |
+
]
|
| 207 |
+
|
| 208 |
+
for lab_name in lab_names:
|
| 209 |
+
data = calc_lab_metric(
|
| 210 |
+
labs_processed, data, lab_name, metric="Latest", weigh_data_by_recency=True
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Save data
|
| 214 |
+
os.makedirs(config["outputs"]["processed_data_dir"], exist_ok=True)
|
| 215 |
+
if data_to_process == "forward_val":
|
| 216 |
+
data.to_pickle(
|
| 217 |
+
os.path.join(
|
| 218 |
+
config["outputs"]["processed_data_dir"],
|
| 219 |
+
"labs_forward_val_" + model_type + ".pkl",
|
| 220 |
+
)
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
data.to_pickle(
|
| 224 |
+
os.path.join(
|
| 225 |
+
config["outputs"]["processed_data_dir"],
|
| 226 |
+
"labs_" + model_type + ".pkl",
|
| 227 |
+
)
|
| 228 |
+
)
|
training/process_pros.py
ADDED
|
@@ -0,0 +1,1031 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Derive features PRO responses for 2 models:
|
| 3 |
+
Parallel model 1: uses both hospital and community exacerbation events
|
| 4 |
+
Parallel model 2: uses only hospital exacerbation events
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def calc_total_pro_engagement(pro_df, pro_name):
|
| 17 |
+
"""
|
| 18 |
+
Calculates PRO engagement per patient across their entire time within the service.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
pro_df (pd.DataFrame): dataframe containing the onboarding date and the latest
|
| 22 |
+
prediction date.
|
| 23 |
+
pro_name (str): name of the PRO.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
pd.DataFrame: the input dateframe with an additional column stating the total
|
| 27 |
+
engagement for each patient across the service.
|
| 28 |
+
"""
|
| 29 |
+
# Calculate time in service according to type of PRO
|
| 30 |
+
if pro_name == "EQ5D":
|
| 31 |
+
date_unit = "M"
|
| 32 |
+
if pro_name == "MRC":
|
| 33 |
+
date_unit = "W"
|
| 34 |
+
if (pro_name == "CAT") | (pro_name == "SymptomDiary"):
|
| 35 |
+
date_unit = "D"
|
| 36 |
+
pro_df["TimeInService"] = np.floor(
|
| 37 |
+
(
|
| 38 |
+
(pro_df.LatestPredictionDate - pro_df.FirstSubmissionDate)
|
| 39 |
+
/ np.timedelta64(1, date_unit)
|
| 40 |
+
)
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# PRO engagement for the total time in service
|
| 44 |
+
pro_response_count = pro_df.groupby("StudyId").count()[["PatientId"]].reset_index()
|
| 45 |
+
pro_response_count = pro_response_count.rename(
|
| 46 |
+
columns={"PatientId": "Response" + pro_name}
|
| 47 |
+
)
|
| 48 |
+
pro_df = pro_df.merge(pro_response_count, on="StudyId", how="left")
|
| 49 |
+
pro_df["TotalEngagement" + pro_name] = round(
|
| 50 |
+
pro_df["Response" + pro_name] / pro_df["TimeInService"], 2
|
| 51 |
+
)
|
| 52 |
+
return pro_df
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def calc_pro_engagement_in_time_window(pro_df, pro_name, time_window, data):
|
| 56 |
+
"""
|
| 57 |
+
Calculates PRO engagement per patient across a specified time window. The time
|
| 58 |
+
window is in format 'months', and consists of the specified time period prior to
|
| 59 |
+
IndexDate.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
pro_df (pd.DataFrame): dataframe containing the index dates and PRO response
|
| 63 |
+
submission dates.
|
| 64 |
+
pro_name (str): name of the PRO.
|
| 65 |
+
time_window (int): number of months in which to calculate PRO engagement.
|
| 66 |
+
data (pd.DataFrame): main dataframe.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
pd.DataFrame: a dataframe containing the calculated PRO engagement.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
# Calculate time in service according to type of PRO.
|
| 73 |
+
if pro_name == "EQ5D":
|
| 74 |
+
unit_val = 1
|
| 75 |
+
if pro_name == "MRC":
|
| 76 |
+
unit_val = 4
|
| 77 |
+
if (pro_name == "CAT") | (pro_name == "SymptomDiary"):
|
| 78 |
+
unit_val = 30
|
| 79 |
+
|
| 80 |
+
pro_df["SubmissionTime"] = pd.to_datetime(pro_df["SubmissionTime"], utc=True)
|
| 81 |
+
pro_engagement_6mo = pro_df.copy()
|
| 82 |
+
pro_engagement_6mo["TimeSinceSubmission"] = (
|
| 83 |
+
pro_engagement_6mo["IndexDate"] - pro_engagement_6mo["SubmissionTime"]
|
| 84 |
+
).dt.days
|
| 85 |
+
|
| 86 |
+
# Only include PRO responses within the specified time window
|
| 87 |
+
pro_engagement_6mo = pro_engagement_6mo[
|
| 88 |
+
pro_engagement_6mo["TimeSinceSubmission"].between(
|
| 89 |
+
0, (time_window * 30), inclusive="both"
|
| 90 |
+
)
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
# Calculate number of PRO responses within specified time window
|
| 94 |
+
pro_engagement_6mo = (
|
| 95 |
+
pro_engagement_6mo.groupby(["StudyId", "IndexDate"])
|
| 96 |
+
.count()[["PatientId"]]
|
| 97 |
+
.reset_index()
|
| 98 |
+
)
|
| 99 |
+
pro_engagement_6mo = pro_engagement_6mo.rename(
|
| 100 |
+
columns={"PatientId": "ResponseCountTW" + str(time_window)}
|
| 101 |
+
)
|
| 102 |
+
pro_engagement_6mo["Engagement" + pro_name + "TW" + str(time_window)] = round(
|
| 103 |
+
pro_engagement_6mo["ResponseCountTW" + str(time_window)]
|
| 104 |
+
/ (time_window * unit_val),
|
| 105 |
+
2,
|
| 106 |
+
)
|
| 107 |
+
pro_engagement_6mo = data[["StudyId", "IndexDate"]].merge(
|
| 108 |
+
pro_engagement_6mo, on=["StudyId", "IndexDate"], how="left"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Fill N/As with 0 as no engagement was observed for those patients
|
| 112 |
+
pro_engagement_6mo = pro_engagement_6mo.fillna(0)
|
| 113 |
+
return pro_engagement_6mo
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def calc_pro_engagement_at_specific_month(pro_df, pro_name, month_num, data):
|
| 117 |
+
# Calculate time in service according to type of PRO.
|
| 118 |
+
if pro_name == "EQ5D":
|
| 119 |
+
unit_val = 1
|
| 120 |
+
if pro_name == "MRC":
|
| 121 |
+
unit_val = 4
|
| 122 |
+
if (pro_name == "CAT") | (pro_name == "SymptomDiary"):
|
| 123 |
+
unit_val = 30
|
| 124 |
+
|
| 125 |
+
pro_df["SubmissionTime"] = pd.to_datetime(pro_df["SubmissionTime"], utc=True)
|
| 126 |
+
pro_engagement = pro_df.copy()
|
| 127 |
+
pro_engagement["TimeSinceSubmission"] = (
|
| 128 |
+
pro_engagement["IndexDate"] - pro_engagement["SubmissionTime"]
|
| 129 |
+
).dt.days
|
| 130 |
+
|
| 131 |
+
# Only include PRO responses for the month specified
|
| 132 |
+
# Calculate the number of months between index date and specified month
|
| 133 |
+
months_between_index_and_specified = month_num - 1
|
| 134 |
+
pro_engagement = pro_engagement[
|
| 135 |
+
pro_engagement["TimeSinceSubmission"].between(
|
| 136 |
+
(months_between_index_and_specified * 30),
|
| 137 |
+
(month_num * 30),
|
| 138 |
+
inclusive="both",
|
| 139 |
+
)
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
# Calculate number of PRO responses within specified time window
|
| 143 |
+
pro_engagement = (
|
| 144 |
+
pro_engagement.groupby(["StudyId", "IndexDate"])
|
| 145 |
+
.count()[["PatientId"]]
|
| 146 |
+
.reset_index()
|
| 147 |
+
)
|
| 148 |
+
pro_engagement = pro_engagement.rename(
|
| 149 |
+
columns={"PatientId": "ResponseCountMonth" + str(month_num)}
|
| 150 |
+
)
|
| 151 |
+
pro_engagement["Engagement" + pro_name + "Month" + str(month_num)] = round(
|
| 152 |
+
pro_engagement["ResponseCountMonth" + str(month_num)] / (1 * unit_val),
|
| 153 |
+
2,
|
| 154 |
+
)
|
| 155 |
+
pro_engagement = data[["StudyId", "IndexDate"]].merge(
|
| 156 |
+
pro_engagement, on=["StudyId", "IndexDate"], how="left"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Fill N/As with 0 as no engagement was observed for those patients
|
| 160 |
+
pro_engagement = pro_engagement.fillna(0)
|
| 161 |
+
return pro_engagement
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def calc_last_pro_score(pro_df, pro_name):
|
| 165 |
+
"""
|
| 166 |
+
Calculates the most recent PRO response. The latest PRO score is set to be within 2
|
| 167 |
+
months of the index date to allow recency of data without having many missing
|
| 168 |
+
values.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
pro_df (pd.DataFrame): dataframe containing the index dates and PRO response
|
| 172 |
+
submission dates.
|
| 173 |
+
pro_name (str): name of the PRO.
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
pd.DataFrame: the input dateframe with additional columns stating the latest PRO
|
| 177 |
+
score for each PRO question.
|
| 178 |
+
"""
|
| 179 |
+
# Calculate last PRO score
|
| 180 |
+
pro_df["TimeSinceSubmission"] = (
|
| 181 |
+
pro_df["IndexDate"] - pro_df["SubmissionTime"]
|
| 182 |
+
).dt.days
|
| 183 |
+
pro_df = pro_df[pro_df["TimeSinceSubmission"] > 0]
|
| 184 |
+
pro_df = pro_df.sort_values(
|
| 185 |
+
by=["StudyId", "IndexDate", "TimeSinceSubmission"], ascending=True
|
| 186 |
+
)
|
| 187 |
+
latest_pro = pro_df.drop_duplicates(subset=["StudyId", "IndexDate"], keep="first")
|
| 188 |
+
|
| 189 |
+
# Ensure that the latest PRO Score is within 2 months of the index date
|
| 190 |
+
latest_pro = latest_pro[latest_pro["TimeSinceSubmission"] <= 365]
|
| 191 |
+
|
| 192 |
+
# Select specific columns
|
| 193 |
+
question_cols = latest_pro.columns[
|
| 194 |
+
latest_pro.columns.str.startswith(pro_name)
|
| 195 |
+
].tolist()
|
| 196 |
+
question_cols.extend(
|
| 197 |
+
["StudyId", "IndexDate", "Score", "SubmissionTime", "TimeSinceSubmission"]
|
| 198 |
+
)
|
| 199 |
+
latest_pro = latest_pro[question_cols]
|
| 200 |
+
|
| 201 |
+
# if pro_name == "EQ5D":
|
| 202 |
+
# median_val_q1 = latest_pro["EQ5DScoreWithoutQ6"].median()
|
| 203 |
+
# print(median_val_q1)
|
| 204 |
+
# latest_pro = weigh_features_by_recency(
|
| 205 |
+
# df=latest_pro,
|
| 206 |
+
# feature="EQ5DScoreWithoutQ6",
|
| 207 |
+
# feature_recency_days="TimeSinceSubmission",
|
| 208 |
+
# median_value=median_val_q1,
|
| 209 |
+
# decay_rate=0.001,
|
| 210 |
+
# )
|
| 211 |
+
# print(latest_pro.columns)
|
| 212 |
+
#
|
| 213 |
+
# # Add prefix to question columns
|
| 214 |
+
# cols_to_rename = latest_pro.columns[
|
| 215 |
+
# ~latest_pro.columns.isin(
|
| 216 |
+
# ["StudyId", "IndexDate", "Score", "SubmissionTime"]
|
| 217 |
+
# )
|
| 218 |
+
# ]
|
| 219 |
+
# latest_pro = latest_pro.rename(
|
| 220 |
+
# columns=dict(zip(cols_to_rename, "Latest" + cols_to_rename))
|
| 221 |
+
# )
|
| 222 |
+
#
|
| 223 |
+
# # Rename columns where prefix not added
|
| 224 |
+
# latest_pro = latest_pro.rename(
|
| 225 |
+
# columns={
|
| 226 |
+
# "Score": "Latest" + pro_name + "Score",
|
| 227 |
+
# "SubmissionTime": "LatestPRODate",
|
| 228 |
+
# }
|
| 229 |
+
# )
|
| 230 |
+
#
|
| 231 |
+
# elif pro_name == "MRC":
|
| 232 |
+
# median_val_q1 = latest_pro["Score"].median()
|
| 233 |
+
# print(median_val_q1)
|
| 234 |
+
# latest_pro = weigh_features_by_recency(
|
| 235 |
+
# df=latest_pro,
|
| 236 |
+
# feature="Score",
|
| 237 |
+
# feature_recency_days="TimeSinceSubmission",
|
| 238 |
+
# median_value=median_val_q1,
|
| 239 |
+
# decay_rate=0.001,
|
| 240 |
+
# )
|
| 241 |
+
# print(latest_pro.columns)
|
| 242 |
+
|
| 243 |
+
# # Add prefix to question columns
|
| 244 |
+
# cols_to_rename = latest_pro.columns[
|
| 245 |
+
# ~latest_pro.columns.isin(
|
| 246 |
+
# ["StudyId", "IndexDate", "Score", "SubmissionTime", "ScoreWeighted"]
|
| 247 |
+
# )
|
| 248 |
+
# ]
|
| 249 |
+
# latest_pro = latest_pro.rename(
|
| 250 |
+
# columns=dict(zip(cols_to_rename, "Latest" + cols_to_rename))
|
| 251 |
+
# )
|
| 252 |
+
|
| 253 |
+
# # Rename columns where prefix not added
|
| 254 |
+
# latest_pro = latest_pro.rename(
|
| 255 |
+
# columns={
|
| 256 |
+
# "Score": "Latest" + pro_name + "Score",
|
| 257 |
+
# "ScoreWeighted": "Latest" + pro_name + "ScoreWeighted",
|
| 258 |
+
# "SubmissionTime": "LatestPRODate",
|
| 259 |
+
# }
|
| 260 |
+
# )
|
| 261 |
+
|
| 262 |
+
# else:
|
| 263 |
+
# Add prefix to question columns
|
| 264 |
+
cols_to_rename = latest_pro.columns[
|
| 265 |
+
~latest_pro.columns.isin(["StudyId", "IndexDate", "Score", "SubmissionTime"])
|
| 266 |
+
]
|
| 267 |
+
latest_pro = latest_pro.rename(
|
| 268 |
+
columns=dict(zip(cols_to_rename, "Latest" + cols_to_rename))
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
# Rename columns where prefix not added
|
| 272 |
+
latest_pro = latest_pro.rename(
|
| 273 |
+
columns={
|
| 274 |
+
"Score": "Latest" + pro_name + "Score",
|
| 275 |
+
"SubmissionTime": "LatestPRODate",
|
| 276 |
+
}
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
pro_df = pro_df.merge(latest_pro, on=["StudyId", "IndexDate"], how="left")
|
| 280 |
+
return pro_df
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def calc_pro_score_prior_to_latest(pro_df, pro_name, time_prior_to_latest=60):
|
| 284 |
+
"""
|
| 285 |
+
Finds the PRO score prior to the latest PRO score before index date.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
pro_df (pd.DataFrame): dataframe containing the latest PRO score and PRO
|
| 289 |
+
response submission dates.
|
| 290 |
+
pro_name (str): name of the PRO.
|
| 291 |
+
time_prior_to_latest (int, optional): time period before latest PRO score in
|
| 292 |
+
days. Default time frame set to 60 days (two months).
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
pd.DataFrame: the input dateframe with additional columns stating the previous
|
| 296 |
+
score closest to the latest PRO score for each PRO question.
|
| 297 |
+
"""
|
| 298 |
+
pro_previous = pro_df.copy()
|
| 299 |
+
pro_previous = pro_previous[
|
| 300 |
+
pro_previous["SubmissionTime"] < pro_previous["LatestPRODate"]
|
| 301 |
+
]
|
| 302 |
+
pro_previous = pro_previous.sort_values(
|
| 303 |
+
by=["StudyId", "IndexDate", "SubmissionTime"], ascending=[True, True, False]
|
| 304 |
+
)
|
| 305 |
+
pro_previous = pro_previous.drop_duplicates(
|
| 306 |
+
subset=["StudyId", "IndexDate"], keep="first"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# Make sure that previous score is within two months of the LatestPRODate
|
| 310 |
+
pro_previous["TimeSinceLatestPRODate"] = (
|
| 311 |
+
pro_previous["LatestPRODate"] - pro_previous["SubmissionTime"]
|
| 312 |
+
).dt.days
|
| 313 |
+
pro_previous = pro_previous[
|
| 314 |
+
pro_previous["TimeSinceLatestPRODate"] <= time_prior_to_latest
|
| 315 |
+
]
|
| 316 |
+
|
| 317 |
+
# Add prefix to question columns
|
| 318 |
+
cols_to_rename = [col for col in pro_previous if col.startswith(pro_name)]
|
| 319 |
+
cols_to_rename = pro_previous[cols_to_rename].columns
|
| 320 |
+
pro_previous = pro_previous.rename(
|
| 321 |
+
columns=dict(zip(cols_to_rename, "Prev" + cols_to_rename))
|
| 322 |
+
)
|
| 323 |
+
pro_previous = pro_previous[["StudyId", "IndexDate", "Score"]].join(
|
| 324 |
+
pro_previous.filter(regex="^Prev")
|
| 325 |
+
)
|
| 326 |
+
pro_previous = pro_previous.rename(columns={"Score": "Prev" + pro_name + "Score"})
|
| 327 |
+
pro_df = pro_df.merge(pro_previous, on=["StudyId", "IndexDate"], how="left")
|
| 328 |
+
return pro_df
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def define_mapping_for_calcs(pro_name, questions, prefixes):
|
| 332 |
+
"""
|
| 333 |
+
Defines the mapping for calculations between PRO responses.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
pro_name (str): name of the PRO.
|
| 337 |
+
questions (list): question names of PRO.
|
| 338 |
+
prefixes (list): prefixes to identify which columns to use in calculations. The
|
| 339 |
+
possible prefixes are: 'Avg', 'Prev', 'LongerAvg', 'WeekPrevAvg'.
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
dict: mapping that maps columns for performing calculations.
|
| 343 |
+
"""
|
| 344 |
+
# Create empty dictionary to append questions
|
| 345 |
+
mapping = defaultdict(list)
|
| 346 |
+
|
| 347 |
+
# Iterate through questions and create mapping for calculations
|
| 348 |
+
for question in questions:
|
| 349 |
+
if (pro_name == "EQ5D") | (pro_name == "MRC"):
|
| 350 |
+
map_key = "Latest" + pro_name + question
|
| 351 |
+
if (pro_name == "CAT") | (pro_name == "SymptomDiary"):
|
| 352 |
+
map_key = "WeekAvg" + pro_name + question
|
| 353 |
+
for prefix in prefixes:
|
| 354 |
+
mapping[map_key].append(prefix + pro_name + question)
|
| 355 |
+
return mapping
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def calc_pro_average(pro_df, pro_name, time_window=None, avg_period=None):
|
| 359 |
+
"""
|
| 360 |
+
Calculate the PRO average before the latest PRO score and within a specified time
|
| 361 |
+
window.
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
pro_df (pd.DataFrame): dataframe containing index dates and PRO submission
|
| 365 |
+
dates.
|
| 366 |
+
pro_name (str): name of the PRO.
|
| 367 |
+
time_window (int, optional): time window (in months) used for calculating the
|
| 368 |
+
average of PRO responses. Defaults to None.
|
| 369 |
+
avg_period (str, optional): identifies which prefix to add to output columns.
|
| 370 |
+
Defaults to None.
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
pd.Dataframe: the input dateframe with additional columns with the calculated
|
| 374 |
+
averages.
|
| 375 |
+
"""
|
| 376 |
+
# Calculate average in PRO responses for the time window specified prior to the
|
| 377 |
+
# index date
|
| 378 |
+
pro_df = pro_df.loc[
|
| 379 |
+
:,
|
| 380 |
+
~(
|
| 381 |
+
pro_df.columns.str.startswith("Avg")
|
| 382 |
+
| pro_df.columns.str.startswith("Longer")
|
| 383 |
+
),
|
| 384 |
+
]
|
| 385 |
+
|
| 386 |
+
if avg_period is None:
|
| 387 |
+
prefix = "Avg"
|
| 388 |
+
pro_df["AvgStartDate"] = pro_df["IndexDate"] - pd.DateOffset(months=time_window)
|
| 389 |
+
avg_pro = pro_df[
|
| 390 |
+
(pro_df["SubmissionTime"] >= pro_df["AvgStartDate"])
|
| 391 |
+
& (pro_df["SubmissionTime"] < pro_df["LatestPRODate"])
|
| 392 |
+
]
|
| 393 |
+
else:
|
| 394 |
+
pro_df["WeekStartDate"] = pro_df["IndexDate"] - pd.DateOffset(weeks=1)
|
| 395 |
+
pro_df["WeekPrevStartDate"] = pro_df["WeekStartDate"] - pd.DateOffset(weeks=1)
|
| 396 |
+
|
| 397 |
+
# When looking at daily PROs, three averages are calculated:
|
| 398 |
+
# The weekly average is the average of PRO scores in the week prior to IndexDate
|
| 399 |
+
if avg_period == "WeeklyAvg":
|
| 400 |
+
prefix = "WeekAvg"
|
| 401 |
+
avg_pro = pro_df[
|
| 402 |
+
(pro_df["SubmissionTime"] >= pro_df["WeekStartDate"])
|
| 403 |
+
& (pro_df["SubmissionTime"] <= pro_df["IndexDate"])
|
| 404 |
+
]
|
| 405 |
+
# The weekly previous average is the average of PRO scores in the week prior to the
|
| 406 |
+
# WeeklyAvg. This is needed to calculate the difference of scores between the most
|
| 407 |
+
# recent week and the week before that
|
| 408 |
+
elif avg_period == "WeekPrevAvg":
|
| 409 |
+
prefix = "WeekPrevAvg"
|
| 410 |
+
avg_pro = pro_df[
|
| 411 |
+
(pro_df["SubmissionTime"] >= pro_df["WeekPrevStartDate"])
|
| 412 |
+
& (pro_df["SubmissionTime"] < pro_df["WeekStartDate"])
|
| 413 |
+
]
|
| 414 |
+
# Longer average calculated is the time window specified prior to the WeekStartDate
|
| 415 |
+
elif avg_period == "LongerAvg":
|
| 416 |
+
prefix = "LongerAvg"
|
| 417 |
+
pro_df["AvgStartDate"] = pro_df["IndexDate"] - pd.DateOffset(months=time_window)
|
| 418 |
+
avg_pro = pro_df[
|
| 419 |
+
(pro_df["SubmissionTime"] >= pro_df["AvgStartDate"])
|
| 420 |
+
& (pro_df["SubmissionTime"] < pro_df["WeekStartDate"])
|
| 421 |
+
]
|
| 422 |
+
|
| 423 |
+
# Select specific columns
|
| 424 |
+
cols_required = avg_pro.columns[avg_pro.columns.str.startswith(pro_name)].tolist()
|
| 425 |
+
cols_required.extend(["StudyId", "IndexDate", "Score"])
|
| 426 |
+
avg_pro = avg_pro[cols_required]
|
| 427 |
+
|
| 428 |
+
# Calculate average pro scores
|
| 429 |
+
avg_pro = avg_pro.groupby(["StudyId", "IndexDate"]).mean().reset_index()
|
| 430 |
+
|
| 431 |
+
# Add prefix to question columns
|
| 432 |
+
cols_to_rename = avg_pro.columns[
|
| 433 |
+
~avg_pro.columns.isin(["StudyId", "IndexDate", "Score"])
|
| 434 |
+
]
|
| 435 |
+
avg_pro = avg_pro.rename(columns=dict(zip(cols_to_rename, prefix + cols_to_rename)))
|
| 436 |
+
|
| 437 |
+
# Rename columns where prefix not added
|
| 438 |
+
avg_pro = avg_pro.rename(columns={"Score": prefix + pro_name + "Score"})
|
| 439 |
+
|
| 440 |
+
# Merge average PRO with rest of the df
|
| 441 |
+
pro_df = pro_df.merge(avg_pro, on=["StudyId", "IndexDate"], how="left")
|
| 442 |
+
return pro_df
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def calc_diff_pro_scores(pro_df, pro_name, latest_pro, other_pro, time_window=None):
|
| 446 |
+
"""
|
| 447 |
+
Calculate the difference between PRO scores.
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
pro_df (pd.DataFrame): dataframe containing columns required for calculations.
|
| 451 |
+
pro_name (str): name of the PRO.
|
| 452 |
+
latest_pro (str): column name containing the latest PRO response for PROs EQ5D
|
| 453 |
+
and MRC, and the latest week average for PROs CAT and SymptomDiary.
|
| 454 |
+
other_pro (str): column name containing the other variable for calculating
|
| 455 |
+
difference.
|
| 456 |
+
time_window (int, optional): time window (in months) used to specify which
|
| 457 |
+
column to use when calculating differences.
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
pd.Dataframe: the input dateframe with additional columns with the calculated
|
| 461 |
+
differences.
|
| 462 |
+
"""
|
| 463 |
+
# Remove prefix of score
|
| 464 |
+
split_feat_name = re.findall(r"[A-Z][^A-Z]*", latest_pro)
|
| 465 |
+
|
| 466 |
+
# Remove first element of list to get the base name of feature
|
| 467 |
+
split_feat_name.pop(0)
|
| 468 |
+
|
| 469 |
+
# Remove the second element in list if PRO is CAT or SymptomDiary
|
| 470 |
+
if pro_name in ["CAT", "SymptomDiary"]:
|
| 471 |
+
split_feat_name.pop(0)
|
| 472 |
+
|
| 473 |
+
# Combine remaining elements of list
|
| 474 |
+
stripped_feat_name = "".join(split_feat_name)
|
| 475 |
+
|
| 476 |
+
if time_window is None:
|
| 477 |
+
pro_df["DiffLatestPrev" + stripped_feat_name] = (
|
| 478 |
+
pro_df[latest_pro] - pro_df[other_pro]
|
| 479 |
+
)
|
| 480 |
+
else:
|
| 481 |
+
pro_df["DiffLatestAvg" + stripped_feat_name + "TW" + str(time_window)] = (
|
| 482 |
+
pro_df[latest_pro] - pro_df[other_pro]
|
| 483 |
+
)
|
| 484 |
+
return pro_df
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def calc_variation(pro_df, pro_name):
|
| 488 |
+
"""
|
| 489 |
+
Calculate the variation (standard deviation) of PRO responses for a time window of
|
| 490 |
+
1 month.
|
| 491 |
+
|
| 492 |
+
Args:
|
| 493 |
+
pro_df (pd.DataFrame): dataframe containing index dates and PRO submission
|
| 494 |
+
dates.
|
| 495 |
+
pro_name (str): name of the PRO.
|
| 496 |
+
|
| 497 |
+
Returns:
|
| 498 |
+
pd.Dataframe: the input dateframe with additional columns with the calculated
|
| 499 |
+
variance.
|
| 500 |
+
"""
|
| 501 |
+
# Only calculate variation in the scores within 1 month before the IndexDate
|
| 502 |
+
if "TimeSinceSubmission" not in pro_df:
|
| 503 |
+
pro_df["TimeSinceSubmission"] = (
|
| 504 |
+
pro_df["IndexDate"] - pro_df["SubmissionTime"]
|
| 505 |
+
).dt.days
|
| 506 |
+
pro_var = pro_df[
|
| 507 |
+
(pro_df["TimeSinceSubmission"] > 0) & (pro_df["TimeSinceSubmission"] <= 30)
|
| 508 |
+
]
|
| 509 |
+
|
| 510 |
+
# Select specific columns
|
| 511 |
+
cols_required = pro_var.columns[pro_var.columns.str.startswith(pro_name)].tolist()
|
| 512 |
+
cols_required.extend(["StudyId", "IndexDate", "Score"])
|
| 513 |
+
pro_var = pro_var[cols_required]
|
| 514 |
+
|
| 515 |
+
# Calculate variation
|
| 516 |
+
pro_var = pro_var.groupby(["StudyId", "IndexDate"]).std().reset_index()
|
| 517 |
+
|
| 518 |
+
# Add prefix to question columns
|
| 519 |
+
cols_to_rename = pro_var.columns[
|
| 520 |
+
~pro_var.columns.isin(["StudyId", "IndexDate", "Score"])
|
| 521 |
+
]
|
| 522 |
+
pro_var = pro_var.rename(columns=dict(zip(cols_to_rename, "Var" + cols_to_rename)))
|
| 523 |
+
|
| 524 |
+
# Rename columns where prefix not added
|
| 525 |
+
pro_var = pro_var.rename(columns={"Score": "Var" + pro_name + "Score"})
|
| 526 |
+
|
| 527 |
+
# Merge back to main df
|
| 528 |
+
pro_df = pro_df.merge(pro_var, on=["StudyId", "IndexDate"], how="left")
|
| 529 |
+
return pro_df
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def calc_sum_binary_vals(pro_df, binary_cols, time_window=1):
|
| 533 |
+
"""
|
| 534 |
+
For SymptomDiary questions that contain binary values, calculate the sum of the
|
| 535 |
+
binary values for a specified time window.
|
| 536 |
+
|
| 537 |
+
Args:
|
| 538 |
+
pro_df (pd.DataFrame): dataframe containing index dates and PRO submission
|
| 539 |
+
dates.
|
| 540 |
+
binary_cols (list): column names for which sum of binary values is to be
|
| 541 |
+
calculated for.
|
| 542 |
+
time_window (int, optional): time window (in months) for which the sum of the
|
| 543 |
+
binary values is calculated for. Defaults to 1.
|
| 544 |
+
|
| 545 |
+
Returns:
|
| 546 |
+
pd.Dataframe: a dataframe containing the sum of the binary values.
|
| 547 |
+
"""
|
| 548 |
+
# Make sure only entries before the index date and after the time window start date
|
| 549 |
+
# are used
|
| 550 |
+
pro_df["TimeWindowStartDate"] = pro_df["IndexDate"] - pd.DateOffset(
|
| 551 |
+
months=time_window
|
| 552 |
+
)
|
| 553 |
+
pro_df = pro_df[
|
| 554 |
+
(pro_df["SubmissionTime"] >= pro_df["TimeWindowStartDate"])
|
| 555 |
+
& (pro_df["SubmissionTime"] <= pro_df["IndexDate"])
|
| 556 |
+
]
|
| 557 |
+
sum_df = pro_df.groupby(["StudyId", "IndexDate"])[binary_cols].sum()
|
| 558 |
+
|
| 559 |
+
# Rename columns
|
| 560 |
+
sum_df = sum_df.add_prefix("Sum")
|
| 561 |
+
sum_df = sum_df.add_suffix("TW" + str(time_window))
|
| 562 |
+
sum_df = sum_df.reset_index()
|
| 563 |
+
return sum_df
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def scale_sum_to_response_rate(pro_df, sum, engagement_rate):
|
| 567 |
+
"""
|
| 568 |
+
Scale the sum calculated using copd.calc_sum_binary_vals() to the response
|
| 569 |
+
rate to obtain a feature that is comparable between patients.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
pro_df (pd.DataFrame): dataframe containing the columns for scaling the sum by
|
| 573 |
+
the engagement rate.
|
| 574 |
+
sum (str): column name that contains the data for the sum of the binary values.
|
| 575 |
+
engagement_rate (str): column name that contains the data for the response rate.
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
pd.Dataframe: the input dateframe with additional columns with the scaled sum.
|
| 579 |
+
"""
|
| 580 |
+
pro_df["Scaled" + sum] = pro_df[sum] / pro_df[engagement_rate]
|
| 581 |
+
return pro_df
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
with open("./training/config.yaml", "r") as config:
|
| 585 |
+
config = yaml.safe_load(config)
|
| 586 |
+
|
| 587 |
+
# Specify which model to generate features for
|
| 588 |
+
model_type = config["model_settings"]["model_type"]
|
| 589 |
+
|
| 590 |
+
# Setup log file
|
| 591 |
+
log = open("./training/logging/process_pros_" + model_type + ".log", "w")
|
| 592 |
+
sys.stdout = log
|
| 593 |
+
|
| 594 |
+
# Dataset to process - set through config file
|
| 595 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 596 |
+
|
| 597 |
+
# Load cohort data
|
| 598 |
+
if data_to_process == "forward_val":
|
| 599 |
+
data = pd.read_pickle("./data/patient_labels_forward_val_hosp_comm.pkl")
|
| 600 |
+
patient_details = pd.read_pickle("./data/patient_details_forward_val.pkl")
|
| 601 |
+
else:
|
| 602 |
+
data = pd.read_pickle("./data/patient_labels_" + model_type + ".pkl")
|
| 603 |
+
patient_details = pd.read_pickle("./data/patient_details.pkl")
|
| 604 |
+
data = data[["StudyId", "IndexDate"]]
|
| 605 |
+
patient_details = data.merge(
|
| 606 |
+
patient_details[["StudyId", "FirstSubmissionDate", "LatestPredictionDate"]],
|
| 607 |
+
on="StudyId",
|
| 608 |
+
how="left",
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
# Calculate the lookback start date. Will need this to aggreggate data for model
|
| 612 |
+
# features
|
| 613 |
+
data["LookbackStartDate"] = data["IndexDate"] - pd.DateOffset(
|
| 614 |
+
days=config["model_settings"]["lookback_period"]
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
############################################
|
| 618 |
+
# Monthly PROs - EQ5D
|
| 619 |
+
############################################
|
| 620 |
+
pro_eq5d = pd.read_csv(config["inputs"]["raw_data_paths"]["pro_eq5d"], delimiter="|")
|
| 621 |
+
pro_eq5d = pro_eq5d.merge(
|
| 622 |
+
patient_details,
|
| 623 |
+
on="StudyId",
|
| 624 |
+
how="inner",
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
# EQ5DQ6 is a much less structured question compared to the other questions in EQ5D.
|
| 628 |
+
# A new score will be calculated using only EQ5DQ1-Q5 to prevent Q6 affecting the score
|
| 629 |
+
pro_eq5d["EQ5DScoreWithoutQ6"] = pro_eq5d.loc[:, "EQ5DQ1":"EQ5DQ5"].sum(axis=1)
|
| 630 |
+
|
| 631 |
+
# Calculate engagement over service
|
| 632 |
+
pro_eq5d = calc_total_pro_engagement(pro_eq5d, "EQ5D")
|
| 633 |
+
|
| 634 |
+
# Calculate engagement for a time window of 1 month (time window chosen based on signal
|
| 635 |
+
# output observed from results of feature_eng_multiple_testing)
|
| 636 |
+
pro_eq5d_engagement = calc_pro_engagement_in_time_window(
|
| 637 |
+
pro_eq5d, "EQ5D", time_window=1, data=data
|
| 638 |
+
)
|
| 639 |
+
pro_eq5d = pro_eq5d.merge(pro_eq5d_engagement, on=["StudyId", "IndexDate"], how="left")
|
| 640 |
+
|
| 641 |
+
# Calculate last PRO score
|
| 642 |
+
pro_eq5d = calc_last_pro_score(pro_eq5d, "EQ5D")
|
| 643 |
+
|
| 644 |
+
# Mapping to calculate the difference between the latest PRO scores and the average
|
| 645 |
+
# PRO score
|
| 646 |
+
question_names_eq5d = ["Q1", "Q2", "Q3", "Q4", "Q5", "Q6", "Score", "ScoreWithoutQ6"]
|
| 647 |
+
mapping_eq5d = define_mapping_for_calcs("EQ5D", question_names_eq5d, prefixes=["Avg"])
|
| 648 |
+
|
| 649 |
+
# Calculate average PRO score for a time window of 1 month prior to IndexDate,
|
| 650 |
+
# ignoring the latest PRO score
|
| 651 |
+
pro_eq5d = calc_pro_average(pro_eq5d, "EQ5D", time_window=1)
|
| 652 |
+
for key in mapping_eq5d:
|
| 653 |
+
calc_diff_pro_scores(pro_eq5d, "EQ5D", key, mapping_eq5d[key][0], time_window=1)
|
| 654 |
+
|
| 655 |
+
# Calculate variation of scores across 1 month
|
| 656 |
+
pro_eq5d = calc_variation(pro_eq5d, "EQ5D")
|
| 657 |
+
|
| 658 |
+
# Remove unwanted columns and duplicates
|
| 659 |
+
pro_eq5d = pro_eq5d.loc[
|
| 660 |
+
:,
|
| 661 |
+
~(
|
| 662 |
+
pro_eq5d.columns.str.startswith("Avg")
|
| 663 |
+
| pro_eq5d.columns.str.startswith("EQ5D")
|
| 664 |
+
| pro_eq5d.columns.str.startswith("Response")
|
| 665 |
+
),
|
| 666 |
+
]
|
| 667 |
+
pro_eq5d = pro_eq5d.drop(
|
| 668 |
+
columns=[
|
| 669 |
+
"Score",
|
| 670 |
+
"SubmissionTime",
|
| 671 |
+
"FirstSubmissionDate",
|
| 672 |
+
"TimeInService",
|
| 673 |
+
"TimeSinceSubmission",
|
| 674 |
+
"LatestPredictionDate",
|
| 675 |
+
"LatestPRODate",
|
| 676 |
+
]
|
| 677 |
+
)
|
| 678 |
+
pro_eq5d = pro_eq5d.drop_duplicates()
|
| 679 |
+
|
| 680 |
+
############################################
|
| 681 |
+
# Weekly PROs - MRC
|
| 682 |
+
############################################
|
| 683 |
+
pro_mrc = pd.read_csv(config["inputs"]["raw_data_paths"]["pro_mrc"], delimiter="|")
|
| 684 |
+
pro_mrc = pro_mrc.merge(
|
| 685 |
+
patient_details,
|
| 686 |
+
on="StudyId",
|
| 687 |
+
how="inner",
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
# Calculate engagement over service
|
| 691 |
+
pro_mrc = calc_total_pro_engagement(pro_mrc, "MRC")
|
| 692 |
+
|
| 693 |
+
# Calculate engagement for a time window of 1 month
|
| 694 |
+
pro_mrc_engagement = calc_pro_engagement_in_time_window(
|
| 695 |
+
pro_mrc, "MRC", time_window=1, data=data
|
| 696 |
+
)
|
| 697 |
+
pro_mrc = pro_mrc.merge(pro_mrc_engagement, on=["StudyId", "IndexDate"], how="left")
|
| 698 |
+
|
| 699 |
+
# Calculate last PRO score
|
| 700 |
+
pro_mrc = calc_last_pro_score(pro_mrc, "MRC")
|
| 701 |
+
|
| 702 |
+
# Mapping to calculate the difference between the latest PRO scores and the average
|
| 703 |
+
# PRO score
|
| 704 |
+
question_names_mrc = ["Q1"]
|
| 705 |
+
mapping_mrc = define_mapping_for_calcs("MRC", question_names_mrc, prefixes=["Avg"])
|
| 706 |
+
|
| 707 |
+
# Calculate average PRO score for a time window of 1 month prior to IndexDate,
|
| 708 |
+
# ignoring the latest PRO score
|
| 709 |
+
pro_mrc = calc_pro_average(pro_mrc, "MRC", time_window=1)
|
| 710 |
+
for key in mapping_mrc:
|
| 711 |
+
calc_diff_pro_scores(pro_mrc, "MRC", key, mapping_mrc[key][0], time_window=1)
|
| 712 |
+
|
| 713 |
+
# Calculate variation of scores across 1 month
|
| 714 |
+
pro_mrc = calc_variation(pro_mrc, "MRC")
|
| 715 |
+
|
| 716 |
+
# Remove unwanted columns and duplicates
|
| 717 |
+
pro_mrc = pro_mrc.loc[
|
| 718 |
+
:,
|
| 719 |
+
~(
|
| 720 |
+
pro_mrc.columns.str.startswith("Avg")
|
| 721 |
+
| pro_mrc.columns.str.startswith("MRC")
|
| 722 |
+
| pro_mrc.columns.str.startswith("Response")
|
| 723 |
+
),
|
| 724 |
+
]
|
| 725 |
+
pro_mrc = pro_mrc.drop(
|
| 726 |
+
columns=[
|
| 727 |
+
"SubmissionTime",
|
| 728 |
+
"Score",
|
| 729 |
+
"FirstSubmissionDate",
|
| 730 |
+
"TimeInService",
|
| 731 |
+
"TimeSinceSubmission",
|
| 732 |
+
"LatestPredictionDate",
|
| 733 |
+
"LatestPRODate",
|
| 734 |
+
]
|
| 735 |
+
)
|
| 736 |
+
pro_mrc = pro_mrc.drop_duplicates()
|
| 737 |
+
|
| 738 |
+
############################################
|
| 739 |
+
# Daily PROs - CAT
|
| 740 |
+
############################################
|
| 741 |
+
pro_cat_full = pd.read_csv(config["inputs"]["raw_data_paths"]["pro_cat"], delimiter="|")
|
| 742 |
+
pro_cat = pro_cat_full.merge(
|
| 743 |
+
patient_details,
|
| 744 |
+
on="StudyId",
|
| 745 |
+
how="inner",
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
# Calculate engagement over service
|
| 749 |
+
pro_cat = calc_total_pro_engagement(pro_cat, "CAT")
|
| 750 |
+
|
| 751 |
+
# Calculate engagement for a time window of 1 month
|
| 752 |
+
pro_cat_engagement = calc_pro_engagement_in_time_window(
|
| 753 |
+
pro_cat, "CAT", time_window=1, data=data
|
| 754 |
+
)
|
| 755 |
+
pro_cat = pro_cat.merge(pro_cat_engagement, on=["StudyId", "IndexDate"], how="left")
|
| 756 |
+
|
| 757 |
+
# Calculate engagement in the month prior to the most recent month to index date
|
| 758 |
+
pro_cat_month1 = calc_pro_engagement_at_specific_month(
|
| 759 |
+
pro_cat, "CAT", month_num=1, data=data
|
| 760 |
+
)
|
| 761 |
+
pro_cat_month2 = calc_pro_engagement_at_specific_month(
|
| 762 |
+
pro_cat, "CAT", month_num=2, data=data
|
| 763 |
+
)
|
| 764 |
+
pro_cat_month3 = calc_pro_engagement_at_specific_month(
|
| 765 |
+
pro_cat, "CAT", month_num=3, data=data
|
| 766 |
+
)
|
| 767 |
+
pro_cat = pro_cat.merge(pro_cat_month1, on=["StudyId", "IndexDate"], how="left")
|
| 768 |
+
pro_cat = pro_cat.merge(pro_cat_month2, on=["StudyId", "IndexDate"], how="left")
|
| 769 |
+
pro_cat = pro_cat.merge(pro_cat_month3, on=["StudyId", "IndexDate"], how="left")
|
| 770 |
+
pro_cat["EngagementDiffMonth1and2"] = (
|
| 771 |
+
pro_cat["EngagementCATMonth1"] - pro_cat["EngagementCATMonth2"]
|
| 772 |
+
)
|
| 773 |
+
pro_cat["EngagementDiffMonth1and3"] = (
|
| 774 |
+
pro_cat["EngagementCATMonth1"] - pro_cat["EngagementCATMonth3"]
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
# Calculate PRO average for the week before the index date
|
| 778 |
+
pro_cat = calc_pro_average(pro_cat, "CAT", avg_period="WeeklyAvg")
|
| 779 |
+
|
| 780 |
+
# Calculate variation of scores across 1 month
|
| 781 |
+
pro_cat = calc_variation(pro_cat, "CAT")
|
| 782 |
+
|
| 783 |
+
# Remove unwanted columns and duplicates
|
| 784 |
+
pro_cat = pro_cat.loc[
|
| 785 |
+
:,
|
| 786 |
+
~(
|
| 787 |
+
pro_cat.columns.str.startswith("CAT")
|
| 788 |
+
| pro_cat.columns.str.startswith("Response")
|
| 789 |
+
),
|
| 790 |
+
]
|
| 791 |
+
pro_cat = pro_cat.drop(
|
| 792 |
+
columns=[
|
| 793 |
+
"Score",
|
| 794 |
+
"SubmissionTime",
|
| 795 |
+
"FirstSubmissionDate",
|
| 796 |
+
"TimeSinceSubmission",
|
| 797 |
+
"LatestPredictionDate",
|
| 798 |
+
"TimeInService",
|
| 799 |
+
"WeekStartDate",
|
| 800 |
+
"WeekPrevStartDate",
|
| 801 |
+
]
|
| 802 |
+
)
|
| 803 |
+
pro_cat = pro_cat.drop_duplicates()
|
| 804 |
+
|
| 805 |
+
############################################
|
| 806 |
+
# Daily PROs - Symptom Diary
|
| 807 |
+
############################################
|
| 808 |
+
|
| 809 |
+
# Symptom diary have some questions that are numeric and some that are categorical
|
| 810 |
+
pro_sd_full = pd.read_csv(
|
| 811 |
+
config["inputs"]["raw_data_paths"]["pro_symptom_diary"], delimiter="|"
|
| 812 |
+
)
|
| 813 |
+
pro_sd = pro_sd_full.merge(
|
| 814 |
+
patient_details,
|
| 815 |
+
on="StudyId",
|
| 816 |
+
how="inner",
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
# Calculate engagement over service
|
| 820 |
+
pro_sd = calc_total_pro_engagement(pro_sd, "SymptomDiary")
|
| 821 |
+
pro_sd_engagement = pro_sd[
|
| 822 |
+
["StudyId", "PatientId", "IndexDate", "TotalEngagementSymptomDiary"]
|
| 823 |
+
]
|
| 824 |
+
|
| 825 |
+
# Calculate engagement for 1 month prior to IndexDate
|
| 826 |
+
pro_sd_engagement_tw = calc_pro_engagement_in_time_window(
|
| 827 |
+
pro_sd, "SymptomDiary", time_window=1, data=data
|
| 828 |
+
)
|
| 829 |
+
pro_sd_engagement = pro_sd_engagement.merge(
|
| 830 |
+
pro_sd_engagement_tw, on=["StudyId", "IndexDate"], how="left"
|
| 831 |
+
)
|
| 832 |
+
pro_sd_engagement = pro_sd_engagement.drop_duplicates()
|
| 833 |
+
|
| 834 |
+
###############################
|
| 835 |
+
# Categorical questions
|
| 836 |
+
# (Q8, Q9, Q10)
|
| 837 |
+
###############################
|
| 838 |
+
pro_cat_q5 = pro_cat_full[["StudyId", "SubmissionTime", "CATQ5"]]
|
| 839 |
+
pro_sd_categ = pro_sd_full[
|
| 840 |
+
[
|
| 841 |
+
"StudyId",
|
| 842 |
+
"SubmissionTime",
|
| 843 |
+
"SymptomDiaryQ8",
|
| 844 |
+
"SymptomDiaryQ9",
|
| 845 |
+
"SymptomDiaryQ10",
|
| 846 |
+
"Score",
|
| 847 |
+
]
|
| 848 |
+
]
|
| 849 |
+
|
| 850 |
+
# Split timestamp column into separate date and time columns as same day entries in CAT
|
| 851 |
+
# and SymptomDiary have different timestamps
|
| 852 |
+
for df in [pro_cat_q5, pro_sd_categ]:
|
| 853 |
+
df["Date"] = (pd.to_datetime(df["SubmissionTime"], utc=True)).dt.date
|
| 854 |
+
pro_sd_cat = pro_sd_categ.merge(pro_cat_q5, on=["StudyId", "Date"], how="outer")
|
| 855 |
+
|
| 856 |
+
# If CATQ5 is a 0, then Symptom Diary questions 8, 9 and 10 don't get asked. Add this as
|
| 857 |
+
# an option to the columns. There are some cases where patients have a 0 in CATQ5 but
|
| 858 |
+
# have also answered Symptom Diary questions 8, 9, and 10 - keep these answers as is.
|
| 859 |
+
for col in ["SymptomDiaryQ8", "SymptomDiaryQ9", "SymptomDiaryQ10"]:
|
| 860 |
+
pro_sd_cat[col] = np.where(
|
| 861 |
+
(pro_sd_cat["CATQ5"] == 0) & (pro_sd_cat[col].isna()),
|
| 862 |
+
"Question Not Asked",
|
| 863 |
+
pro_sd_cat[col],
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
# Calculate the most recent score for SymptomDiary categorical questions
|
| 867 |
+
pro_sd_cat = pro_sd_cat.merge(data[["StudyId", "IndexDate"]], on="StudyId", how="inner")
|
| 868 |
+
pro_sd_cat = pro_sd_cat.rename(columns={"SubmissionTime_x": "SubmissionTime"})
|
| 869 |
+
pro_sd_cat["SubmissionTime"] = pd.to_datetime(pro_sd_cat["SubmissionTime"], utc=True)
|
| 870 |
+
pro_sd_cat = calc_last_pro_score(pro_sd_cat, "SymptomDiary")
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
pro_sd_cat = pro_sd_cat.drop(
|
| 874 |
+
columns=[
|
| 875 |
+
"SubmissionTime",
|
| 876 |
+
"SubmissionTime_y",
|
| 877 |
+
"CATQ5",
|
| 878 |
+
"SymptomDiaryQ8",
|
| 879 |
+
"SymptomDiaryQ9",
|
| 880 |
+
"Date",
|
| 881 |
+
"SymptomDiaryQ10",
|
| 882 |
+
"Score",
|
| 883 |
+
"LatestSymptomDiaryScore",
|
| 884 |
+
"LatestPRODate",
|
| 885 |
+
"TimeSinceSubmission",
|
| 886 |
+
]
|
| 887 |
+
)
|
| 888 |
+
pro_sd_cat = pro_sd_cat.drop_duplicates()
|
| 889 |
+
|
| 890 |
+
###############################
|
| 891 |
+
# Numeric questions
|
| 892 |
+
# (Q1, Q2)
|
| 893 |
+
# Q3 included for comparison
|
| 894 |
+
###############################
|
| 895 |
+
# Calculate PRO average for the week before the index date
|
| 896 |
+
pro_sd_numeric = pro_sd[
|
| 897 |
+
[
|
| 898 |
+
"StudyId",
|
| 899 |
+
"PatientId",
|
| 900 |
+
"IndexDate",
|
| 901 |
+
"SubmissionTime",
|
| 902 |
+
"Score",
|
| 903 |
+
"SymptomDiaryQ1",
|
| 904 |
+
"SymptomDiaryQ2",
|
| 905 |
+
"SymptomDiaryQ3",
|
| 906 |
+
]
|
| 907 |
+
]
|
| 908 |
+
pro_sd_numeric = calc_pro_average(
|
| 909 |
+
pro_sd_numeric, "SymptomDiary", avg_period="WeeklyAvg"
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
# Calculate variation of scores across 1 month
|
| 913 |
+
pro_sd_numeric = calc_variation(pro_sd_numeric, "SymptomDiary")
|
| 914 |
+
|
| 915 |
+
###############################
|
| 916 |
+
# Binary questions
|
| 917 |
+
# (Q3)
|
| 918 |
+
###############################
|
| 919 |
+
# Calculate sum of binary values for a time window of 1 months
|
| 920 |
+
sd_sum_all = pro_sd_numeric[["StudyId", "IndexDate"]]
|
| 921 |
+
sd_sum_all = sd_sum_all.drop_duplicates()
|
| 922 |
+
sd_sum = calc_sum_binary_vals(
|
| 923 |
+
pro_sd_numeric, binary_cols=["SymptomDiaryQ3"], time_window=1
|
| 924 |
+
)
|
| 925 |
+
sd_sum_all = sd_sum_all.merge(sd_sum, on=["StudyId", "IndexDate"], how="left")
|
| 926 |
+
|
| 927 |
+
# Scale sums by how often patients responded
|
| 928 |
+
sd_sum_all = sd_sum_all.merge(
|
| 929 |
+
pro_sd_engagement, on=["StudyId", "IndexDate"], how="left"
|
| 930 |
+
)
|
| 931 |
+
mapping_scaling = {"SumSymptomDiaryQ3TW1": "EngagementSymptomDiaryTW1"}
|
| 932 |
+
for key in mapping_scaling:
|
| 933 |
+
scale_sum_to_response_rate(sd_sum_all, key, mapping_scaling[key])
|
| 934 |
+
|
| 935 |
+
# Combine numeric, categorical and binary dfs
|
| 936 |
+
pro_sd_all = pro_sd_numeric.merge(
|
| 937 |
+
sd_sum_all, on=["StudyId", "PatientId", "IndexDate"], how="left"
|
| 938 |
+
)
|
| 939 |
+
pro_sd_all = pro_sd_all.merge(pro_sd_cat, on=["StudyId", "IndexDate"], how="left")
|
| 940 |
+
|
| 941 |
+
# Remove unwanted columns from numeric df
|
| 942 |
+
pro_sd_all = pro_sd_all.loc[
|
| 943 |
+
:,
|
| 944 |
+
~(
|
| 945 |
+
pro_sd_all.columns.str.startswith("Symptom")
|
| 946 |
+
| pro_sd_all.columns.str.startswith("Sum")
|
| 947 |
+
| pro_sd_all.columns.str.startswith("Response")
|
| 948 |
+
),
|
| 949 |
+
]
|
| 950 |
+
pro_sd_all = pro_sd_all.drop(
|
| 951 |
+
columns=[
|
| 952 |
+
"Score",
|
| 953 |
+
"SubmissionTime",
|
| 954 |
+
"TimeWindowStartDate",
|
| 955 |
+
"WeekStartDate",
|
| 956 |
+
"WeekPrevStartDate",
|
| 957 |
+
"TimeSinceSubmission",
|
| 958 |
+
]
|
| 959 |
+
)
|
| 960 |
+
pro_sd_all = pro_sd_all.drop_duplicates()
|
| 961 |
+
|
| 962 |
+
# Combine pros into one df
|
| 963 |
+
pro_df = pro_eq5d.merge(pro_mrc, on=["StudyId", "PatientId", "IndexDate"], how="left")
|
| 964 |
+
pro_df = pro_df.merge(pro_cat, on=["StudyId", "PatientId", "IndexDate"], how="left")
|
| 965 |
+
pro_df = pro_df.merge(pro_sd_all, on=["StudyId", "PatientId", "IndexDate"], how="left")
|
| 966 |
+
|
| 967 |
+
###############################
|
| 968 |
+
# Map some categorical features
|
| 969 |
+
###############################
|
| 970 |
+
|
| 971 |
+
# Replace SDQ8 with strings for phlegm difficulty
|
| 972 |
+
q8_dict = {
|
| 973 |
+
"1.0": "Not difficult",
|
| 974 |
+
"2.0": "A little difficult",
|
| 975 |
+
"3.0": "Quite difficult",
|
| 976 |
+
"4.0": "Very difficult",
|
| 977 |
+
}
|
| 978 |
+
for key in q8_dict:
|
| 979 |
+
pro_df["LatestSymptomDiaryQ8"] = pro_df["LatestSymptomDiaryQ8"].str.replace(
|
| 980 |
+
key, q8_dict[key]
|
| 981 |
+
)
|
| 982 |
+
|
| 983 |
+
# Replace SDQ9 with strings for phlegm consistency
|
| 984 |
+
q9_dict = {
|
| 985 |
+
"1.0": "Watery",
|
| 986 |
+
"2.0": "Sticky liquid",
|
| 987 |
+
"3.0": "Semi-solid",
|
| 988 |
+
"4.0": "Solid",
|
| 989 |
+
}
|
| 990 |
+
for key in q9_dict:
|
| 991 |
+
pro_df["LatestSymptomDiaryQ9"] = pro_df["LatestSymptomDiaryQ9"].str.replace(
|
| 992 |
+
key, q9_dict[key]
|
| 993 |
+
)
|
| 994 |
+
|
| 995 |
+
# Replace SDQ10 with strings for phlegm colour
|
| 996 |
+
q10_dict = {
|
| 997 |
+
"1.0": "White",
|
| 998 |
+
"2.0": "Yellow",
|
| 999 |
+
"3.0": "Green",
|
| 1000 |
+
"4.0": "Dark green",
|
| 1001 |
+
}
|
| 1002 |
+
for key in q10_dict:
|
| 1003 |
+
pro_df["LatestSymptomDiaryQ10"] = pro_df["LatestSymptomDiaryQ10"].str.replace(
|
| 1004 |
+
key, q10_dict[key]
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
pro_df = pro_df.drop(
|
| 1008 |
+
columns=[
|
| 1009 |
+
"PatientId",
|
| 1010 |
+
"LatestTimeSinceSubmission",
|
| 1011 |
+
"LatestTimeSinceSubmission_x",
|
| 1012 |
+
"LatestTimeSinceSubmission_y",
|
| 1013 |
+
]
|
| 1014 |
+
)
|
| 1015 |
+
|
| 1016 |
+
# Save data
|
| 1017 |
+
os.makedirs(config["outputs"]["processed_data_dir"], exist_ok=True)
|
| 1018 |
+
if data_to_process == "forward_val":
|
| 1019 |
+
pro_df.to_pickle(
|
| 1020 |
+
os.path.join(
|
| 1021 |
+
config["outputs"]["processed_data_dir"],
|
| 1022 |
+
"pros_forward_val_" + model_type + ".pkl",
|
| 1023 |
+
)
|
| 1024 |
+
)
|
| 1025 |
+
else:
|
| 1026 |
+
pro_df.to_pickle(
|
| 1027 |
+
os.path.join(
|
| 1028 |
+
config["outputs"]["processed_data_dir"],
|
| 1029 |
+
"pros_" + model_type + ".pkl",
|
| 1030 |
+
)
|
| 1031 |
+
)
|
training/process_spirometry.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Derive features from spirometry for 2 models:
|
| 3 |
+
Parallel model 1: uses both hospital and community exacerbation events
|
| 4 |
+
Parallel model 2: uses only hospital exacerbation events
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import yaml
|
| 12 |
+
import model_h
|
| 13 |
+
|
| 14 |
+
with open("./training/config.yaml", "r") as config:
|
| 15 |
+
config = yaml.safe_load(config)
|
| 16 |
+
|
| 17 |
+
# Specify which model to generate features for
|
| 18 |
+
model_type = config["model_settings"]["model_type"]
|
| 19 |
+
|
| 20 |
+
# Setup log file
|
| 21 |
+
log = open("./training/logging/process_spirometry_" + model_type + ".log", "w")
|
| 22 |
+
sys.stdout = log
|
| 23 |
+
|
| 24 |
+
# Dataset to process - set through config file
|
| 25 |
+
data_to_process = config["model_settings"]["data_to_process"]
|
| 26 |
+
|
| 27 |
+
# Load cohort data
|
| 28 |
+
if data_to_process == "forward_val":
|
| 29 |
+
data = pd.read_pickle("./data/patient_labels_forward_val_hosp_comm.pkl")
|
| 30 |
+
patient_details = pd.read_pickle("./data/patient_details_forward_val.pkl")
|
| 31 |
+
else:
|
| 32 |
+
data = pd.read_pickle("./data/patient_labels_" + model_type + ".pkl")
|
| 33 |
+
patient_details = pd.read_pickle("./data/patient_details.pkl")
|
| 34 |
+
data = data[["StudyId", "IndexDate"]]
|
| 35 |
+
patient_details = data.merge(
|
| 36 |
+
patient_details[["StudyId", "PatientId"]],
|
| 37 |
+
on="StudyId",
|
| 38 |
+
how="left",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
copd_status = pd.read_csv(
|
| 43 |
+
config["inputs"]["raw_data_paths"]["copd_status"], delimiter="|"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
copd_status = patient_details.merge(copd_status, on="PatientId", how="left")
|
| 47 |
+
copd_status["LungFunction_Date"] = pd.to_datetime(
|
| 48 |
+
copd_status["LungFunction_Date"], utc=True
|
| 49 |
+
)
|
| 50 |
+
copd_status["TimeSinceLungFunc"] = (
|
| 51 |
+
copd_status["IndexDate"] - copd_status["LungFunction_Date"]
|
| 52 |
+
).dt.days
|
| 53 |
+
print(
|
| 54 |
+
"COPD Status Details: Number of patients with a lung function date < 1 year \
|
| 55 |
+
from index date: {} of {}".format(
|
| 56 |
+
len(copd_status[copd_status["TimeSinceLungFunc"] < 365]), len(patient_details)
|
| 57 |
+
)
|
| 58 |
+
)
|
| 59 |
+
copd_status = copd_status[
|
| 60 |
+
[
|
| 61 |
+
"StudyId",
|
| 62 |
+
"IndexDate",
|
| 63 |
+
"RequiredAcuteNIV",
|
| 64 |
+
"RequiredICUAdmission",
|
| 65 |
+
"LungFunction_FEV1PercentPredicted",
|
| 66 |
+
"LungFunction_FEV1Litres",
|
| 67 |
+
"LungFunction_FEV1FVCRatio",
|
| 68 |
+
"TimeSinceLungFunc",
|
| 69 |
+
]
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
# Map bool values
|
| 73 |
+
bool_mapping = {True: 1, False: 0}
|
| 74 |
+
copd_status["RequiredAcuteNIV"] = copd_status.RequiredAcuteNIV.map(bool_mapping)
|
| 75 |
+
copd_status["RequiredICUAdmission"] = copd_status.RequiredICUAdmission.map(bool_mapping)
|
| 76 |
+
|
| 77 |
+
# Convert columns in COPD Status to numeric
|
| 78 |
+
copd_status["LungFunction_FEV1PercentPredicted"] = copd_status[
|
| 79 |
+
"LungFunction_FEV1PercentPredicted"
|
| 80 |
+
].str.replace("%", "")
|
| 81 |
+
for col in copd_status.drop(
|
| 82 |
+
columns=["StudyId", "IndexDate", "RequiredAcuteNIV", "RequiredICUAdmission"]
|
| 83 |
+
).columns:
|
| 84 |
+
copd_status[col] = pd.to_numeric(copd_status[col])
|
| 85 |
+
|
| 86 |
+
# Bin patient spirometry at onboarding
|
| 87 |
+
spirometry_bins = [0, 30, 50, 80, np.inf]
|
| 88 |
+
spirometry_labels = ["Very severe", "Severe", "Moderate", "Mild"]
|
| 89 |
+
copd_status["FEV1PercentPredicted"] = model_h.bin_numeric_column(
|
| 90 |
+
col=copd_status["LungFunction_FEV1PercentPredicted"],
|
| 91 |
+
bins=spirometry_bins,
|
| 92 |
+
labels=spirometry_labels,
|
| 93 |
+
)
|
| 94 |
+
copd_status = copd_status.drop(columns=["LungFunction_FEV1PercentPredicted"])
|
| 95 |
+
|
| 96 |
+
# Assign patients without spirometry in service data to the Mild category
|
| 97 |
+
copd_status.loc[
|
| 98 |
+
copd_status["FEV1PercentPredicted"] == "nan", "FEV1PercentPredicted"
|
| 99 |
+
] = "Mild"
|
| 100 |
+
|
| 101 |
+
# Save data
|
| 102 |
+
os.makedirs(config["outputs"]["processed_data_dir"], exist_ok=True)
|
| 103 |
+
if data_to_process == "forward_val":
|
| 104 |
+
copd_status.to_pickle(
|
| 105 |
+
os.path.join(
|
| 106 |
+
config["outputs"]["processed_data_dir"],
|
| 107 |
+
"spirometry_forward_val_" + model_type + ".pkl",
|
| 108 |
+
)
|
| 109 |
+
)
|
| 110 |
+
else:
|
| 111 |
+
copd_status.to_pickle(
|
| 112 |
+
os.path.join(
|
| 113 |
+
config["outputs"]["processed_data_dir"],
|
| 114 |
+
"spirometry_" + model_type + ".pkl",
|
| 115 |
+
)
|
| 116 |
+
)
|
training/pros_multiple_time_windows.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Derive features from PRO responses for multiple time windows and select time window
|
| 3 |
+
that gives the best signal.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import model_h
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_cols_for_plotting(pro_name, question_col_names=None, var_engagement=False):
|
| 14 |
+
"""
|
| 15 |
+
Create a mapping for the PRO questions specified that allows plotting of the results
|
| 16 |
+
from the same question with different time windows on the same grid. The key of the
|
| 17 |
+
dictionary is the PRO question (e.g. 'EQ5DQ1') and the values are a list containing
|
| 18 |
+
column names to be plotted (e.g. ['LatestEQ5DQ1', 'DiffLatestAvgEQ5DQ1TW1']).
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
pro_name (str): name of PRO.
|
| 22 |
+
question_col_names (list, optional): a list of question names required for
|
| 23 |
+
plotting. Defaults to None.
|
| 24 |
+
var_engagement (bool, optional): whether the variable to be plot is engagement.
|
| 25 |
+
Defaults to False.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
dict of (str:list): dictionary containing mapping where each key maps to a list
|
| 29 |
+
of column names.
|
| 30 |
+
"""
|
| 31 |
+
cols_for_plotting = defaultdict(list)
|
| 32 |
+
|
| 33 |
+
if var_engagement is False:
|
| 34 |
+
for question in question_col_names:
|
| 35 |
+
for time_window_num in range(1, 7):
|
| 36 |
+
col_name = (
|
| 37 |
+
"DiffLatestAvg" + pro_name + question + "TW" + str(time_window_num)
|
| 38 |
+
)
|
| 39 |
+
cols_for_plotting[pro_name + question].append(col_name)
|
| 40 |
+
if (pro_name == "SymptomDiary") & (question == "Q3"):
|
| 41 |
+
col_name = (
|
| 42 |
+
"ScaledSum" + pro_name + question + "TW" + str(time_window_num)
|
| 43 |
+
)
|
| 44 |
+
cols_for_plotting["ScaledSum" + pro_name + question].append(
|
| 45 |
+
col_name
|
| 46 |
+
)
|
| 47 |
+
cols_for_plotting[pro_name + question].append(
|
| 48 |
+
"DiffLatestPrev" + pro_name + question
|
| 49 |
+
)
|
| 50 |
+
if (pro_name == "EQ5D") | (pro_name == "MRC"):
|
| 51 |
+
cols_for_plotting[pro_name + question].append(
|
| 52 |
+
"Latest" + pro_name + question
|
| 53 |
+
)
|
| 54 |
+
if (pro_name == "CAT") | (pro_name == "SymptomDiary"):
|
| 55 |
+
cols_for_plotting[pro_name + question].append(
|
| 56 |
+
"WeekAvg" + pro_name + question
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
if var_engagement is True:
|
| 60 |
+
for time_window_num in range(1, 7):
|
| 61 |
+
col_name = "Engagement" + pro_name + "TW" + str(time_window_num)
|
| 62 |
+
cols_for_plotting[pro_name].append(col_name)
|
| 63 |
+
return cols_for_plotting
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def plot_feature_signal(
|
| 67 |
+
data, nrows, ncols, figsize, cols_to_plot, fig_name, outcome="ExacWithin3Months"
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Plot boxplots for each multiple columns onto the same grid if multiple columns are
|
| 71 |
+
specified.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
data (pd.DataFrame): dataframe containing all data to plot and outcome column.
|
| 75 |
+
nrows (int): number of rows for the subplot grid.
|
| 76 |
+
ncols (int): number of columns for the subplot grid.
|
| 77 |
+
figsize (tuple): (width, height) in inches.
|
| 78 |
+
cols_to_plot (list): column names to plot.
|
| 79 |
+
fig_name (str): name of figure required to save figure.
|
| 80 |
+
outcome (str, optional): name of column to group values by for plotting the
|
| 81 |
+
data. Defaults to 'ExacWithinMonths'.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
None.
|
| 85 |
+
"""
|
| 86 |
+
fig, ax = plt.subplots(nrows, ncols, figsize=figsize)
|
| 87 |
+
if (nrows > 1) | (ncols > 1):
|
| 88 |
+
ax = ax.flatten()
|
| 89 |
+
for i, col in enumerate(cols_to_plot):
|
| 90 |
+
data.boxplot(
|
| 91 |
+
col,
|
| 92 |
+
outcome,
|
| 93 |
+
ax=ax[i],
|
| 94 |
+
flierprops={"markersize": 2},
|
| 95 |
+
medianprops={"color": "black"},
|
| 96 |
+
# oxprops={"color": "black"},
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
for i, col in enumerate(cols_to_plot):
|
| 100 |
+
data.boxplot(
|
| 101 |
+
col,
|
| 102 |
+
outcome,
|
| 103 |
+
flierprops={"markersize": 2},
|
| 104 |
+
medianprops={"color": "black"},
|
| 105 |
+
)
|
| 106 |
+
plt.tight_layout()
|
| 107 |
+
plt.savefig("./plots/boxplots/" + fig_name + ".png")
|
| 108 |
+
plt.close()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
data = pd.read_pickle("./data/patient_labels_hosp_comm.pkl")
|
| 112 |
+
patient_details = pd.read_pickle("./data/patient_details.pkl")
|
| 113 |
+
|
| 114 |
+
data = data.merge(
|
| 115 |
+
patient_details[["StudyId", "FirstSubmissionDate", "LatestPredictionDate"]],
|
| 116 |
+
on="StudyId",
|
| 117 |
+
how="left",
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Calculate the lookback start date. Will need this to aggreggate data for model
|
| 121 |
+
# features
|
| 122 |
+
data["LookbackStartDate"] = data["IndexDate"] - pd.DateOffset(days=180)
|
| 123 |
+
|
| 124 |
+
############################################################################
|
| 125 |
+
# Derive features from PRO responses
|
| 126 |
+
############################################################################
|
| 127 |
+
############################################
|
| 128 |
+
# Monthly PROs - EQ5D
|
| 129 |
+
############################################
|
| 130 |
+
pro_eq5d = pd.read_csv("<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProEQ5D.txt", delimiter="|")
|
| 131 |
+
pro_eq5d = pro_eq5d.merge(
|
| 132 |
+
data[["StudyId", "IndexDate", "FirstSubmissionDate", "LatestPredictionDate"]],
|
| 133 |
+
on="StudyId",
|
| 134 |
+
how="inner",
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# EQ5DQ6 is a much less structured question compared to the other questions in EQ5D. A
|
| 138 |
+
# new score will be calculated using only EQ5DQ1-Q5 to prevent Q6 affecting the score.
|
| 139 |
+
pro_eq5d["EQ5DScoreWithoutQ6"] = pro_eq5d.loc[:, "EQ5DQ1":"EQ5DQ5"].sum(axis=1)
|
| 140 |
+
|
| 141 |
+
# Calculate engagement over service
|
| 142 |
+
pro_eq5d = model_h.calc_total_pro_engagement(pro_eq5d, "EQ5D")
|
| 143 |
+
|
| 144 |
+
# Calculate engagement over multiple time windows
|
| 145 |
+
for time_window in range(1, 7):
|
| 146 |
+
pro_eq5d_engagement = model_h.calc_pro_engagement_in_time_window(
|
| 147 |
+
pro_eq5d, "EQ5D", time_window=time_window, data=data
|
| 148 |
+
)
|
| 149 |
+
pro_eq5d = pro_eq5d.merge(
|
| 150 |
+
pro_eq5d_engagement, on=["StudyId", "IndexDate"], how="left"
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Calculate last PRO score
|
| 154 |
+
pro_eq5d = model_h.calc_last_pro_score(pro_eq5d, "EQ5D")
|
| 155 |
+
|
| 156 |
+
# Calculate the PRO score prior to the last PRO score
|
| 157 |
+
pro_eq5d = model_h.calc_pro_score_prior_to_latest(pro_eq5d, "EQ5D")
|
| 158 |
+
|
| 159 |
+
#############################
|
| 160 |
+
# Scores across time windows
|
| 161 |
+
#############################
|
| 162 |
+
# Mapping to calculate the difference between the latest PRO scores and both the average
|
| 163 |
+
# and previous PRO score
|
| 164 |
+
question_names_eq5d = ["Q1", "Q2", "Q3", "Q4", "Q5", "Q6", "Score", "ScoreWithoutQ6"]
|
| 165 |
+
mapping_eq5d = model_h.define_mapping_for_calcs(
|
| 166 |
+
"EQ5D", question_names_eq5d, prefixes=["Avg", "Prev"]
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Calculate average PRO score across various time windows (months) prior to IndexDate,
|
| 170 |
+
# ignoring the latest PRO score
|
| 171 |
+
for time_window in range(1, 7):
|
| 172 |
+
pro_eq5d = model_h.calc_pro_average(pro_eq5d, "EQ5D", time_window=time_window)
|
| 173 |
+
for key in mapping_eq5d:
|
| 174 |
+
model_h.calc_diff_pro_scores(
|
| 175 |
+
pro_eq5d, "EQ5D", key, mapping_eq5d[key][0], time_window=time_window
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Calculate difference between latest PRO score and PRO score prior to the latest
|
| 179 |
+
for key in mapping_eq5d:
|
| 180 |
+
model_h.calc_diff_pro_scores(pro_eq5d, "EQ5D", key, mapping_eq5d[key][1])
|
| 181 |
+
|
| 182 |
+
# Remove unwanted columns and duplicates
|
| 183 |
+
pro_eq5d = pro_eq5d.loc[
|
| 184 |
+
:,
|
| 185 |
+
~(
|
| 186 |
+
pro_eq5d.columns.str.startswith("Avg")
|
| 187 |
+
| pro_eq5d.columns.str.startswith("EQ5D")
|
| 188 |
+
| pro_eq5d.columns.str.startswith("Prev")
|
| 189 |
+
| pro_eq5d.columns.str.startswith("Response")
|
| 190 |
+
),
|
| 191 |
+
]
|
| 192 |
+
pro_eq5d = pro_eq5d.drop(
|
| 193 |
+
columns=[
|
| 194 |
+
"Score",
|
| 195 |
+
"SubmissionTime",
|
| 196 |
+
"FirstSubmissionDate",
|
| 197 |
+
"TimeInService",
|
| 198 |
+
"TimeSinceSubmission",
|
| 199 |
+
"LatestPredictionDate",
|
| 200 |
+
"LatestPRODate",
|
| 201 |
+
]
|
| 202 |
+
)
|
| 203 |
+
pro_eq5d = pro_eq5d.drop_duplicates()
|
| 204 |
+
|
| 205 |
+
############################################
|
| 206 |
+
# Weekly PROs - MRC
|
| 207 |
+
############################################
|
| 208 |
+
pro_mrc = pd.read_csv("<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProMrc.txt", delimiter="|")
|
| 209 |
+
pro_mrc = pro_mrc.merge(
|
| 210 |
+
data[["StudyId", "IndexDate", "FirstSubmissionDate", "LatestPredictionDate"]],
|
| 211 |
+
on="StudyId",
|
| 212 |
+
how="inner",
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# Calculate engagement over service
|
| 216 |
+
pro_mrc = model_h.calc_total_pro_engagement(pro_mrc, "MRC")
|
| 217 |
+
|
| 218 |
+
# Calculate engagement over multiple time windows
|
| 219 |
+
for time_window in range(1, 7):
|
| 220 |
+
pro_mrc_engagement = model_h.calc_pro_engagement_in_time_window(
|
| 221 |
+
pro_mrc, "MRC", time_window=time_window, data=data
|
| 222 |
+
)
|
| 223 |
+
pro_mrc = pro_mrc.merge(pro_mrc_engagement, on=["StudyId", "IndexDate"], how="left")
|
| 224 |
+
|
| 225 |
+
# Calculate last PRO score
|
| 226 |
+
pro_mrc = model_h.calc_last_pro_score(pro_mrc, "MRC")
|
| 227 |
+
|
| 228 |
+
# Calculate the PRO score prior to the last PRO score
|
| 229 |
+
pro_mrc = model_h.calc_pro_score_prior_to_latest(pro_mrc, "MRC")
|
| 230 |
+
|
| 231 |
+
#############################
|
| 232 |
+
# Scores across time windows
|
| 233 |
+
#############################
|
| 234 |
+
# Mapping to calculate the difference between the latest PRO scores and both the average
|
| 235 |
+
# and previous PRO score
|
| 236 |
+
question_names_mrc = ["Q1"]
|
| 237 |
+
mapping_mrc = model_h.define_mapping_for_calcs(
|
| 238 |
+
"MRC", question_names_mrc, prefixes=["Avg", "Prev"]
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# Calculate average PRO score across various time windows (months) prior to IndexDate,
|
| 242 |
+
# ignoring the latest PRO score
|
| 243 |
+
for time_window in range(1, 7):
|
| 244 |
+
pro_mrc = model_h.calc_pro_average(pro_mrc, "MRC", time_window=time_window)
|
| 245 |
+
for key in mapping_mrc:
|
| 246 |
+
model_h.calc_diff_pro_scores(
|
| 247 |
+
pro_mrc, "MRC", key, mapping_mrc[key][0], time_window=time_window
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Calculate difference between latest PRO score and PRO score prior to the latest
|
| 251 |
+
for key in mapping_mrc:
|
| 252 |
+
model_h.calc_diff_pro_scores(pro_mrc, "MRC", key, mapping_mrc[key][1])
|
| 253 |
+
|
| 254 |
+
# Remove unwanted columns and duplicates
|
| 255 |
+
pro_mrc = pro_mrc.loc[
|
| 256 |
+
:,
|
| 257 |
+
~(
|
| 258 |
+
pro_mrc.columns.str.startswith("Avg")
|
| 259 |
+
| pro_mrc.columns.str.startswith("MRC")
|
| 260 |
+
| pro_mrc.columns.str.startswith("Prev")
|
| 261 |
+
| pro_mrc.columns.str.startswith("Response")
|
| 262 |
+
),
|
| 263 |
+
]
|
| 264 |
+
pro_mrc = pro_mrc.drop(
|
| 265 |
+
columns=[
|
| 266 |
+
"SubmissionTime",
|
| 267 |
+
"Score",
|
| 268 |
+
"FirstSubmissionDate",
|
| 269 |
+
"TimeInService",
|
| 270 |
+
"TimeSinceSubmission",
|
| 271 |
+
"LatestPredictionDate",
|
| 272 |
+
"LatestPRODate",
|
| 273 |
+
]
|
| 274 |
+
)
|
| 275 |
+
pro_mrc = pro_mrc.drop_duplicates()
|
| 276 |
+
|
| 277 |
+
############################################
|
| 278 |
+
# Daily PROs - CAT
|
| 279 |
+
############################################
|
| 280 |
+
pro_cat = pd.read_csv("<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProCat.txt", delimiter="|")
|
| 281 |
+
pro_cat = pro_cat.merge(
|
| 282 |
+
data[["StudyId", "IndexDate", "FirstSubmissionDate", "LatestPredictionDate"]],
|
| 283 |
+
on="StudyId",
|
| 284 |
+
how="inner",
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Calculate engagement over service and 1 month prior to index date
|
| 288 |
+
pro_cat = model_h.calc_total_pro_engagement(pro_cat, "CAT")
|
| 289 |
+
|
| 290 |
+
# Calculate engagement over multiple time windows
|
| 291 |
+
for time_window in range(1, 7):
|
| 292 |
+
pro_cat_engagement = model_h.calc_pro_engagement_in_time_window(
|
| 293 |
+
pro_cat, "CAT", time_window=time_window, data=data
|
| 294 |
+
)
|
| 295 |
+
pro_cat = pro_cat.merge(pro_cat_engagement, on=["StudyId", "IndexDate"], how="left")
|
| 296 |
+
|
| 297 |
+
# Calculate PRO average for the week before the index date
|
| 298 |
+
pro_cat = model_h.calc_pro_average(pro_cat, "CAT", avg_period="WeeklyAvg")
|
| 299 |
+
|
| 300 |
+
# Calculate PRO average for the week before most recent week to the index date
|
| 301 |
+
pro_cat = model_h.calc_pro_average(pro_cat, "CAT", avg_period="WeekPrevAvg")
|
| 302 |
+
|
| 303 |
+
#############################
|
| 304 |
+
# Scores across time windows
|
| 305 |
+
#############################
|
| 306 |
+
# Mapping to calculate the difference between the latest PRO scores and both the average
|
| 307 |
+
# and previous PRO score
|
| 308 |
+
question_names_cat = ["Q1", "Q2", "Q3", "Q4", "Q5", "Q6", "Q7", "Q8", "Score"]
|
| 309 |
+
mapping_cat = model_h.define_mapping_for_calcs(
|
| 310 |
+
"CAT", question_names_cat, prefixes=["LongerAvg", "WeekPrevAvg"]
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Calculate average PRO score across various time windows (months) prior to IndexDate,
|
| 314 |
+
# ignoring the latest PRO score
|
| 315 |
+
for time_window in range(1, 7):
|
| 316 |
+
pro_cat = model_h.calc_pro_average(
|
| 317 |
+
pro_cat, "CAT", time_window=time_window, avg_period="LongerAvg"
|
| 318 |
+
)
|
| 319 |
+
for key in mapping_cat:
|
| 320 |
+
model_h.calc_diff_pro_scores(
|
| 321 |
+
pro_cat, "CAT", key, mapping_cat[key][0], time_window=time_window
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# Calculate difference between latest PRO score and PRO score prior to the latest
|
| 325 |
+
for key in mapping_cat:
|
| 326 |
+
model_h.calc_diff_pro_scores(pro_cat, "CAT", key, mapping_cat[key][1])
|
| 327 |
+
|
| 328 |
+
# Remove unwanted columns and duplicates
|
| 329 |
+
pro_cat = pro_cat.loc[
|
| 330 |
+
:,
|
| 331 |
+
~(
|
| 332 |
+
pro_cat.columns.str.startswith("WeekPrev")
|
| 333 |
+
| pro_cat.columns.str.startswith("Longer")
|
| 334 |
+
| pro_cat.columns.str.startswith("CAT")
|
| 335 |
+
| pro_cat.columns.str.startswith("Response")
|
| 336 |
+
),
|
| 337 |
+
]
|
| 338 |
+
pro_cat = pro_cat.drop(
|
| 339 |
+
columns=[
|
| 340 |
+
"Score",
|
| 341 |
+
"SubmissionTime",
|
| 342 |
+
"FirstSubmissionDate",
|
| 343 |
+
"LatestPredictionDate",
|
| 344 |
+
"TimeInService",
|
| 345 |
+
"AvgStartDate",
|
| 346 |
+
"WeekStartDate",
|
| 347 |
+
]
|
| 348 |
+
)
|
| 349 |
+
pro_cat = pro_cat.drop_duplicates()
|
| 350 |
+
|
| 351 |
+
############################################
|
| 352 |
+
# Daily PROs - Symptom Diary
|
| 353 |
+
############################################
|
| 354 |
+
# Symptom diary have some questions that are numeric and some that are categorical
|
| 355 |
+
pro_sd = pd.read_csv(
|
| 356 |
+
"<YOUR_DATA_PATH>/copd-dataset/CopdDatasetProSymptomDiary.txt", delimiter="|"
|
| 357 |
+
)
|
| 358 |
+
pro_sd = pro_sd.merge(
|
| 359 |
+
data[["StudyId", "IndexDate", "FirstSubmissionDate", "LatestPredictionDate"]],
|
| 360 |
+
on="StudyId",
|
| 361 |
+
how="inner",
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# Calculate engagement over service
|
| 365 |
+
pro_sd = model_h.calc_total_pro_engagement(pro_sd, "SymptomDiary")
|
| 366 |
+
pro_sd_engagement = pro_sd[
|
| 367 |
+
["StudyId", "PatientId", "IndexDate", "TotalEngagementSymptomDiary"]
|
| 368 |
+
]
|
| 369 |
+
|
| 370 |
+
# Calculate engagement over multiple time windows
|
| 371 |
+
for time_window in range(1, 7):
|
| 372 |
+
pro_sd_engagement_tw = model_h.calc_pro_engagement_in_time_window(
|
| 373 |
+
pro_sd, "SymptomDiary", time_window=time_window, data=data
|
| 374 |
+
)
|
| 375 |
+
pro_sd_engagement = pro_sd_engagement.merge(
|
| 376 |
+
pro_sd_engagement_tw, on=["StudyId", "IndexDate"], how="left"
|
| 377 |
+
)
|
| 378 |
+
pro_sd_engagement = pro_sd_engagement.drop_duplicates()
|
| 379 |
+
|
| 380 |
+
###############################
|
| 381 |
+
# Numeric questions
|
| 382 |
+
# (Q1, Q2)
|
| 383 |
+
# Q3 included for comparison
|
| 384 |
+
###############################
|
| 385 |
+
# Calculate PRO average for the week before the index date
|
| 386 |
+
pro_sd_numeric = pro_sd[
|
| 387 |
+
[
|
| 388 |
+
"StudyId",
|
| 389 |
+
"PatientId",
|
| 390 |
+
"IndexDate",
|
| 391 |
+
"SubmissionTime",
|
| 392 |
+
"Score",
|
| 393 |
+
"SymptomDiaryQ1",
|
| 394 |
+
"SymptomDiaryQ2",
|
| 395 |
+
"SymptomDiaryQ3",
|
| 396 |
+
]
|
| 397 |
+
]
|
| 398 |
+
pro_sd_numeric = model_h.calc_pro_average(
|
| 399 |
+
pro_sd_numeric, "SymptomDiary", avg_period="WeeklyAvg"
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Calculate PRO average for the week before most recent week to the index date
|
| 403 |
+
pro_sd_numeric = model_h.calc_pro_average(
|
| 404 |
+
pro_sd_numeric, "SymptomDiary", avg_period="WeekPrevAvg"
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
#############################
|
| 408 |
+
# Scores across time windows
|
| 409 |
+
#############################
|
| 410 |
+
# Mapping to calculate the difference between the latest PRO scores and both the average
|
| 411 |
+
# and previous PRO score
|
| 412 |
+
question_names_sd = ["Q1", "Q2", "Q3"]
|
| 413 |
+
mapping_sd = model_h.define_mapping_for_calcs(
|
| 414 |
+
"SymptomDiary", question_names_sd, prefixes=["LongerAvg", "WeekPrevAvg"]
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# Calculate average PRO score across various time windows (months) prior to IndexDate,
|
| 418 |
+
# ignoring the latest PRO score
|
| 419 |
+
for time_window in range(1, 7):
|
| 420 |
+
pro_sd_numeric = model_h.calc_pro_average(
|
| 421 |
+
pro_sd_numeric, "SymptomDiary", time_window=time_window, avg_period="LongerAvg"
|
| 422 |
+
)
|
| 423 |
+
for key in mapping_sd:
|
| 424 |
+
model_h.calc_diff_pro_scores(
|
| 425 |
+
pro_sd_numeric,
|
| 426 |
+
"SymptomDiary",
|
| 427 |
+
key,
|
| 428 |
+
mapping_sd[key][0],
|
| 429 |
+
time_window=time_window,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
# Calculate difference between latest PRO score and PRO score prior to the latest week
|
| 433 |
+
for key in mapping_sd:
|
| 434 |
+
model_h.calc_diff_pro_scores(pro_sd_numeric, "SymptomDiary", key, mapping_sd[key][1])
|
| 435 |
+
|
| 436 |
+
###############################
|
| 437 |
+
# Binary questions
|
| 438 |
+
# (Q3)
|
| 439 |
+
###############################
|
| 440 |
+
# Calculate sum of binary values across previous months
|
| 441 |
+
sd_sum_all = pro_sd_numeric[["StudyId", "IndexDate"]]
|
| 442 |
+
sd_sum_all = sd_sum_all.drop_duplicates()
|
| 443 |
+
for time_window in range(1, 7):
|
| 444 |
+
sd_sum = model_h.calc_sum_binary_vals(
|
| 445 |
+
pro_sd_numeric, binary_cols=["SymptomDiaryQ3"], time_window=time_window
|
| 446 |
+
)
|
| 447 |
+
sd_sum_all = sd_sum_all.merge(sd_sum, on=["StudyId", "IndexDate"], how="left")
|
| 448 |
+
|
| 449 |
+
# Scale sums by how often patients responded
|
| 450 |
+
sd_sum_all = sd_sum_all.merge(
|
| 451 |
+
pro_sd_engagement, on=["StudyId", "IndexDate"], how="left"
|
| 452 |
+
)
|
| 453 |
+
mapping_scaling = {}
|
| 454 |
+
for time_window in range(1, 7):
|
| 455 |
+
mapping_scaling["SumSymptomDiaryQ3TW" + str(time_window)] = (
|
| 456 |
+
"EngagementSymptomDiaryTW" + str(time_window)
|
| 457 |
+
)
|
| 458 |
+
for key in mapping_scaling:
|
| 459 |
+
model_h.scale_sum_to_response_rate(sd_sum_all, key, mapping_scaling[key])
|
| 460 |
+
|
| 461 |
+
# Combine numeric and binary dfs
|
| 462 |
+
pro_sd_full = pro_sd_numeric.merge(
|
| 463 |
+
sd_sum_all, on=["StudyId", "PatientId", "IndexDate"], how="left"
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
# Remove unwanted columns from numeric df
|
| 467 |
+
pro_sd_full = pro_sd_full.loc[
|
| 468 |
+
:,
|
| 469 |
+
~(
|
| 470 |
+
pro_sd_full.columns.str.startswith("WeekPrev")
|
| 471 |
+
| pro_sd_full.columns.str.startswith("Longer")
|
| 472 |
+
| pro_sd_full.columns.str.startswith("Symptom")
|
| 473 |
+
| pro_sd_full.columns.str.startswith("Sum")
|
| 474 |
+
| pro_sd_full.columns.str.startswith("Response")
|
| 475 |
+
),
|
| 476 |
+
]
|
| 477 |
+
pro_sd_full = pro_sd_full.drop(
|
| 478 |
+
columns=[
|
| 479 |
+
"Score",
|
| 480 |
+
"SubmissionTime",
|
| 481 |
+
"AvgStartDate",
|
| 482 |
+
"TimeWindowStartDate",
|
| 483 |
+
"WeekStartDate",
|
| 484 |
+
]
|
| 485 |
+
)
|
| 486 |
+
pro_sd_full = pro_sd_full.drop_duplicates()
|
| 487 |
+
|
| 488 |
+
############################################################################
|
| 489 |
+
# Combine PROs with main df
|
| 490 |
+
############################################################################
|
| 491 |
+
data = data.merge(pro_eq5d, on=["StudyId", "PatientId", "IndexDate"], how="left")
|
| 492 |
+
data = data.merge(pro_mrc, on=["StudyId", "PatientId", "IndexDate"], how="left")
|
| 493 |
+
data = data.merge(pro_cat, on=["StudyId", "PatientId", "IndexDate"], how="left")
|
| 494 |
+
data = data.merge(pro_sd_full, on=["StudyId", "PatientId", "IndexDate"], how="left")
|
| 495 |
+
|
| 496 |
+
# Calculate mean for features grouped by outcome
|
| 497 |
+
feat_to_explore = data.loc[:, "TotalEngagementEQ5D":"ScaledSumSymptomDiaryQ3TW6"]
|
| 498 |
+
feat_to_explore.loc[:, "ExacWithin3Months"] = data.loc[:, "ExacWithin3Months"]
|
| 499 |
+
grouped_data_by_outcome = feat_to_explore.groupby("ExacWithin3Months").mean()
|
| 500 |
+
grouped_data_by_outcome = grouped_data_by_outcome.T
|
| 501 |
+
|
| 502 |
+
############################################################################
|
| 503 |
+
# Plotting
|
| 504 |
+
############################################################################
|
| 505 |
+
##############################
|
| 506 |
+
# EQ5D Boxplots
|
| 507 |
+
##############################
|
| 508 |
+
# Plotting score values
|
| 509 |
+
cols_for_plotting = model_h.create_cols_for_plotting(
|
| 510 |
+
"EQ5D", question_col_names=question_names_eq5d
|
| 511 |
+
)
|
| 512 |
+
for key in cols_for_plotting:
|
| 513 |
+
model_h.plot_feature_signal(
|
| 514 |
+
data,
|
| 515 |
+
nrows=3,
|
| 516 |
+
ncols=3,
|
| 517 |
+
figsize=(12, 12),
|
| 518 |
+
cols_to_plot=cols_for_plotting[key],
|
| 519 |
+
fig_name=key + "_boxplot",
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# Plotting engagement
|
| 523 |
+
cols_for_plotting = model_h.create_cols_for_plotting("EQ5D", var_engagement=True)
|
| 524 |
+
for key in cols_for_plotting:
|
| 525 |
+
model_h.plot_feature_signal(
|
| 526 |
+
data,
|
| 527 |
+
nrows=2,
|
| 528 |
+
ncols=3,
|
| 529 |
+
figsize=(12, 12),
|
| 530 |
+
cols_to_plot=cols_for_plotting[key],
|
| 531 |
+
fig_name=key + "_engagement",
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
##############################
|
| 535 |
+
# MRC Boxplots
|
| 536 |
+
##############################
|
| 537 |
+
# Plotting score values
|
| 538 |
+
cols_for_plotting = model_h.create_cols_for_plotting(
|
| 539 |
+
"MRC", question_col_names=question_names_mrc
|
| 540 |
+
)
|
| 541 |
+
for key in cols_for_plotting:
|
| 542 |
+
model_h.plot_feature_signal(
|
| 543 |
+
data,
|
| 544 |
+
nrows=3,
|
| 545 |
+
ncols=3,
|
| 546 |
+
figsize=(12, 12),
|
| 547 |
+
cols_to_plot=cols_for_plotting[key],
|
| 548 |
+
fig_name=key + "_boxplot",
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
# Plotting engagement
|
| 552 |
+
cols_for_plotting = model_h.create_cols_for_plotting("MRC", var_engagement=True)
|
| 553 |
+
for key in cols_for_plotting:
|
| 554 |
+
model_h.plot_feature_signal(
|
| 555 |
+
data,
|
| 556 |
+
nrows=2,
|
| 557 |
+
ncols=3,
|
| 558 |
+
figsize=(12, 12),
|
| 559 |
+
cols_to_plot=cols_for_plotting[key],
|
| 560 |
+
fig_name=key + "_engagement",
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
##############################
|
| 564 |
+
# CAT Boxplots
|
| 565 |
+
##############################
|
| 566 |
+
# Plotting score values
|
| 567 |
+
cols_for_plotting = model_h.create_cols_for_plotting(
|
| 568 |
+
"CAT", question_col_names=question_names_cat
|
| 569 |
+
)
|
| 570 |
+
for key in cols_for_plotting:
|
| 571 |
+
model_h.plot_feature_signal(
|
| 572 |
+
data,
|
| 573 |
+
nrows=3,
|
| 574 |
+
ncols=3,
|
| 575 |
+
figsize=(12, 12),
|
| 576 |
+
cols_to_plot=cols_for_plotting[key],
|
| 577 |
+
fig_name=key + "_boxplot",
|
| 578 |
+
)
|
| 579 |
+
# Plotting engagement
|
| 580 |
+
cols_for_plotting = model_h.create_cols_for_plotting("CAT", var_engagement=True)
|
| 581 |
+
for key in cols_for_plotting:
|
| 582 |
+
model_h.plot_feature_signal(
|
| 583 |
+
data,
|
| 584 |
+
nrows=2,
|
| 585 |
+
ncols=3,
|
| 586 |
+
figsize=(12, 12),
|
| 587 |
+
cols_to_plot=cols_for_plotting[key],
|
| 588 |
+
fig_name=key + "_engagement",
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
##############################
|
| 592 |
+
# SymptomDiary Boxplots
|
| 593 |
+
##############################
|
| 594 |
+
# Plotting score values
|
| 595 |
+
cols_for_plotting = model_h.create_cols_for_plotting(
|
| 596 |
+
"SymptomDiary", question_col_names=question_names_sd
|
| 597 |
+
)
|
| 598 |
+
for key in cols_for_plotting:
|
| 599 |
+
model_h.plot_feature_signal(
|
| 600 |
+
data,
|
| 601 |
+
nrows=3,
|
| 602 |
+
ncols=3,
|
| 603 |
+
figsize=(12, 12),
|
| 604 |
+
cols_to_plot=cols_for_plotting[key],
|
| 605 |
+
fig_name=key + "_boxplot",
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
# Plotting engagement
|
| 609 |
+
cols_for_plotting = model_h.create_cols_for_plotting("SymptomDiary", var_engagement=True)
|
| 610 |
+
for key in cols_for_plotting:
|
| 611 |
+
model_h.plot_feature_signal(
|
| 612 |
+
data,
|
| 613 |
+
nrows=2,
|
| 614 |
+
ncols=3,
|
| 615 |
+
figsize=(12, 12),
|
| 616 |
+
cols_to_plot=cols_for_plotting[key],
|
| 617 |
+
fig_name=key + "_engagement",
|
| 618 |
+
)
|
training/setup_labels_forward_val.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script uses both hospital and community exacerbation events.
|
| 3 |
+
|
| 4 |
+
Collate all hospital, patient reported events and apply PRO LOGIC to determine the number
|
| 5 |
+
of exacerbation events. Use exacerbation events to determine the number of rows required per
|
| 6 |
+
patient in the data and generate random index dates and setup labels. Data starts at July
|
| 7 |
+
2022 and runs until Dec 2023 and will be used for forward validation of the model.
|
| 8 |
+
"""
|
| 9 |
+
import model_h
|
| 10 |
+
import numpy as np
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from datetime import timedelta
|
| 16 |
+
import random
|
| 17 |
+
import yaml
|
| 18 |
+
|
| 19 |
+
with open("./training/config.yaml", "r") as config:
|
| 20 |
+
config = yaml.safe_load(config)
|
| 21 |
+
|
| 22 |
+
# Setup log file
|
| 23 |
+
log = open(os.path.join(config['outputs']['logging_dir'], "setup_labels_hosp_comm.log"), "w")
|
| 24 |
+
sys.stdout = log
|
| 25 |
+
|
| 26 |
+
############################################################################
|
| 27 |
+
# Define model cohort and training data windows
|
| 28 |
+
############################################################################
|
| 29 |
+
|
| 30 |
+
# Read relevant info from patient details
|
| 31 |
+
patient_details = pd.read_csv(
|
| 32 |
+
config['inputs']['raw_data_paths']['patient_details'],
|
| 33 |
+
usecols=[
|
| 34 |
+
"PatientId",
|
| 35 |
+
"FirstSubmissionDate",
|
| 36 |
+
"MostRecentSubmissionDate",
|
| 37 |
+
"DateOfBirth",
|
| 38 |
+
"Sex",
|
| 39 |
+
"StudyId",
|
| 40 |
+
],
|
| 41 |
+
delimiter="|",
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Select patients for inclusion (those with up to date events in service)
|
| 45 |
+
# Original RECEIVER cohort study id list
|
| 46 |
+
receiver_patients = ["RC{:02d}".format(i) for i in range(1, 85)]
|
| 47 |
+
# This patient needs removing
|
| 48 |
+
receiver_patients.remove("RC34")
|
| 49 |
+
|
| 50 |
+
# Scale up patients (subset)
|
| 51 |
+
scaleup_patients = ["SU{:02d}".format(i) for i in range(1, 219)]
|
| 52 |
+
# scaleup_patients.append('SU287') #Removed as study ID contains 2 patients
|
| 53 |
+
|
| 54 |
+
# List of all valid patients for modelling
|
| 55 |
+
valid_patients = receiver_patients + scaleup_patients
|
| 56 |
+
|
| 57 |
+
# Filter for valid patients accounting for white spaces in StudyId (e.g. RC 26 and RC 52)
|
| 58 |
+
patient_details = patient_details[
|
| 59 |
+
patient_details.StudyId.str.replace(" ", "").isin(valid_patients)
|
| 60 |
+
]
|
| 61 |
+
# Select only non null entries in patient data start/end dates
|
| 62 |
+
patient_details = patient_details[
|
| 63 |
+
(patient_details.FirstSubmissionDate.notna())
|
| 64 |
+
& (patient_details.MostRecentSubmissionDate.notna())
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
# Create a column stating the earliest permitted date for forward validation
|
| 68 |
+
patient_details["EarliestIndexDate"] = config['model_settings']['forward_validation_earliest_date']
|
| 69 |
+
|
| 70 |
+
# Create a column stating the latest date permitted based on events added to service data
|
| 71 |
+
patient_details["LatestPredictionDate"] = config['model_settings']['forward_validation_latest_date']
|
| 72 |
+
|
| 73 |
+
date_cols = ["FirstSubmissionDate", "MostRecentSubmissionDate", "LatestPredictionDate", "EarliestIndexDate"]
|
| 74 |
+
patient_details[date_cols] = patient_details[date_cols].apply(
|
| 75 |
+
lambda x: pd.to_datetime(x, utc=True, format="mixed").dt.normalize(), axis=1
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Choose the earlier date out of the patient's last submission and the latest COPD data
|
| 79 |
+
# events
|
| 80 |
+
patient_details["LatestPredictionDate"] = patient_details[
|
| 81 |
+
["MostRecentSubmissionDate", "LatestPredictionDate"]
|
| 82 |
+
].min(axis=1)
|
| 83 |
+
|
| 84 |
+
# Calculate the latest date that the index date can be for each patient
|
| 85 |
+
patient_details["LatestIndexDate"] = patient_details[
|
| 86 |
+
"LatestPredictionDate"
|
| 87 |
+
] - pd.DateOffset(days=config['model_settings']['prediction_window'])
|
| 88 |
+
|
| 89 |
+
# Add 6 months to start of data window to allow enough of a lookback period
|
| 90 |
+
patient_details["EarliestDataDate"] = patient_details[
|
| 91 |
+
"EarliestIndexDate"
|
| 92 |
+
] - pd.DateOffset(days=config['model_settings']['lookback_period'])
|
| 93 |
+
|
| 94 |
+
# Remove any patients for whom the index start date overlaps the last index
|
| 95 |
+
# date, i.e. they have too short a window of data
|
| 96 |
+
print("Number of total patients", len(patient_details))
|
| 97 |
+
print(
|
| 98 |
+
"Number of patients with too short of a window of data:",
|
| 99 |
+
len(
|
| 100 |
+
patient_details[
|
| 101 |
+
patient_details["EarliestIndexDate"] > patient_details["LatestIndexDate"]
|
| 102 |
+
]
|
| 103 |
+
),
|
| 104 |
+
)
|
| 105 |
+
patient_details = patient_details[
|
| 106 |
+
patient_details["EarliestIndexDate"] < patient_details["LatestIndexDate"]
|
| 107 |
+
]
|
| 108 |
+
patient_details.to_pickle("./data/patient_details_forward_val.pkl")
|
| 109 |
+
|
| 110 |
+
# List of remaining patients
|
| 111 |
+
model_patients = list(patient_details.PatientId.unique())
|
| 112 |
+
model_study_ids = list(patient_details.StudyId.unique())
|
| 113 |
+
|
| 114 |
+
print(
|
| 115 |
+
"Model cohort: {} patients. {} RECEIVER and {} SU".format(
|
| 116 |
+
len(model_patients),
|
| 117 |
+
len(patient_details[patient_details["StudyId"].str.startswith("RC")]),
|
| 118 |
+
len(patient_details[patient_details["StudyId"].str.startswith("SU")]),
|
| 119 |
+
)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
df = patient_details[
|
| 123 |
+
[
|
| 124 |
+
"PatientId",
|
| 125 |
+
"DateOfBirth",
|
| 126 |
+
"Sex",
|
| 127 |
+
"StudyId",
|
| 128 |
+
"EarliestDataDate",
|
| 129 |
+
"EarliestIndexDate",
|
| 130 |
+
"LatestIndexDate",
|
| 131 |
+
"LatestPredictionDate",
|
| 132 |
+
]
|
| 133 |
+
].copy()
|
| 134 |
+
|
| 135 |
+
# Create a row per day between the EarliestDataDate and the LatestPredictionDate
|
| 136 |
+
df["DateOfEvent"] = df.apply(
|
| 137 |
+
lambda x: pd.date_range(x.EarliestDataDate, x.LatestPredictionDate, freq="D"),
|
| 138 |
+
axis=1,
|
| 139 |
+
)
|
| 140 |
+
df = df.explode("DateOfEvent").reset_index(drop=True)
|
| 141 |
+
|
| 142 |
+
############################################################################
|
| 143 |
+
# Extract hospital exacerbations and admissions from COPD service data
|
| 144 |
+
############################################################################
|
| 145 |
+
|
| 146 |
+
# Contains exacerbations among other event types
|
| 147 |
+
patient_events = pd.read_csv(
|
| 148 |
+
config['inputs']['raw_data_paths']['patient_events'],
|
| 149 |
+
delimiter="|",
|
| 150 |
+
usecols=["PatientId", "DateOfEvent", "EventType"],
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Filter for only patients in model cohort
|
| 154 |
+
patient_events = patient_events[patient_events.PatientId.isin(model_patients)]
|
| 155 |
+
|
| 156 |
+
# Identify hospital exacerbation events
|
| 157 |
+
patient_events["IsHospExac"] = model_h.define_service_exac_event(
|
| 158 |
+
events=patient_events.EventType, include_community=False
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Identify hospital admissions (all causes)
|
| 162 |
+
patient_events["IsHospAdmission"] = model_h.define_hospital_admission(
|
| 163 |
+
patient_events.EventType
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
admissions = patient_events[patient_events.IsHospAdmission == 1][
|
| 167 |
+
["PatientId", "DateOfEvent", "IsHospAdmission"]
|
| 168 |
+
]
|
| 169 |
+
hosp_exacs = patient_events[patient_events.IsHospExac == 1][
|
| 170 |
+
["PatientId", "DateOfEvent", "IsHospExac"]
|
| 171 |
+
]
|
| 172 |
+
admissions["DateOfEvent"] = pd.to_datetime(
|
| 173 |
+
admissions.DateOfEvent, utc=True
|
| 174 |
+
).dt.normalize()
|
| 175 |
+
hosp_exacs["DateOfEvent"] = pd.to_datetime(
|
| 176 |
+
hosp_exacs.DateOfEvent, utc=True
|
| 177 |
+
).dt.normalize()
|
| 178 |
+
|
| 179 |
+
hosp_exacs = hosp_exacs.drop_duplicates()
|
| 180 |
+
admissions = admissions.drop_duplicates()
|
| 181 |
+
|
| 182 |
+
# Save hospital exacerbations and admissions data
|
| 183 |
+
hosp_exacs.to_pickle("./data/hospital_exacerbations.pkl")
|
| 184 |
+
admissions.to_pickle("./data/hospital_admissions.pkl")
|
| 185 |
+
|
| 186 |
+
############################################################################
|
| 187 |
+
# Extract patient reported exacerbation events
|
| 188 |
+
############################################################################
|
| 189 |
+
|
| 190 |
+
########################
|
| 191 |
+
# Data post Q5 change
|
| 192 |
+
#######################
|
| 193 |
+
|
| 194 |
+
# Read file containing patient reported events (not patient_events because it contains
|
| 195 |
+
# the dates when patients answered PROs and not which date they reported as having taken
|
| 196 |
+
# their rescue meds)
|
| 197 |
+
symptom_diary = pd.read_csv(
|
| 198 |
+
config['inputs']['raw_data_paths']['pro_symptom_diary'],
|
| 199 |
+
usecols=[
|
| 200 |
+
"PatientId",
|
| 201 |
+
"StudyId",
|
| 202 |
+
"Score",
|
| 203 |
+
"SubmissionTime",
|
| 204 |
+
"SymptomDiaryQ5",
|
| 205 |
+
"SymptomDiaryQ11a",
|
| 206 |
+
"SymptomDiaryQ11b",
|
| 207 |
+
],
|
| 208 |
+
delimiter="|",
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
Q5ChangeDate = pd.to_datetime(config['model_settings']['pro_q5_change_date'], utc=True)
|
| 212 |
+
symptom_diary = model_h.filter_symptom_diary(
|
| 213 |
+
df=symptom_diary, date_cutoff=Q5ChangeDate, patients=model_patients
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
weekly_pros = model_h.get_rescue_med_pro_responses(symptom_diary)
|
| 217 |
+
weekly_pros = model_h.set_pro_exac_dates(weekly_pros)
|
| 218 |
+
weekly_pros = weekly_pros[
|
| 219 |
+
[
|
| 220 |
+
"PatientId",
|
| 221 |
+
"Q5Answered",
|
| 222 |
+
"NegativeQ5",
|
| 223 |
+
"IsCommExac",
|
| 224 |
+
"DateOfEvent",
|
| 225 |
+
"ExacDateUnknown",
|
| 226 |
+
]
|
| 227 |
+
]
|
| 228 |
+
|
| 229 |
+
####################################################################################
|
| 230 |
+
# Merge hospital and patient reported events with daily patient records
|
| 231 |
+
#
|
| 232 |
+
# Exacerbations occurring in Lenus service period include verified clinician events
|
| 233 |
+
# pre-April 2021 (after onboarding) and community exacerbations recorded in weekly
|
| 234 |
+
# PROs post-April 2021. Hospital exacerbations include exacerbations occuring during
|
| 235 |
+
# service period.
|
| 236 |
+
#####################################################################################
|
| 237 |
+
|
| 238 |
+
# Patient reported, clinician verified
|
| 239 |
+
#df = df.merge(verified_exacs, on=["StudyId", "DateOfEvent"], how="left")
|
| 240 |
+
|
| 241 |
+
# Patient reported, new rescue med PRO (April 2021 onwards)
|
| 242 |
+
df = df.merge(weekly_pros, on=["PatientId", "DateOfEvent"], how="left")
|
| 243 |
+
|
| 244 |
+
# Hospital exacerbations
|
| 245 |
+
df = df.merge(hosp_exacs, on=["PatientId", "DateOfEvent"], how="left")
|
| 246 |
+
df = model_h.fill_column_by_patient(df=df, id_col="PatientId", col="StudyId")
|
| 247 |
+
|
| 248 |
+
# Hospital admissions
|
| 249 |
+
df = df.merge(admissions, on=["PatientId", "DateOfEvent"], how="left")
|
| 250 |
+
df = model_h.fill_column_by_patient(df=df, id_col="PatientId", col="StudyId")
|
| 251 |
+
|
| 252 |
+
# Column for whether an exacerbation of any kind occurred on each date. To be filtered
|
| 253 |
+
# using (PRO) LOGIC
|
| 254 |
+
df["IsExac"] = np.where((df.IsCommExac == 1) | (df.IsHospExac == 1), 1, 0)
|
| 255 |
+
|
| 256 |
+
# Resample the df to one day per patient starting from the earliest record
|
| 257 |
+
df = (
|
| 258 |
+
df.set_index("DateOfEvent")
|
| 259 |
+
.groupby("StudyId")
|
| 260 |
+
.resample("D")
|
| 261 |
+
.asfreq()
|
| 262 |
+
.drop("StudyId", axis=1)
|
| 263 |
+
.reset_index()
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Infill binary cols with zero where applicable
|
| 267 |
+
df[
|
| 268 |
+
[
|
| 269 |
+
"Q5Answered",
|
| 270 |
+
"NegativeQ5",
|
| 271 |
+
"IsHospExac",
|
| 272 |
+
"IsCommExac",
|
| 273 |
+
"ExacDateUnknown",
|
| 274 |
+
"IsExac",
|
| 275 |
+
"IsHospAdmission",
|
| 276 |
+
]
|
| 277 |
+
] = df[
|
| 278 |
+
[
|
| 279 |
+
"Q5Answered",
|
| 280 |
+
"NegativeQ5",
|
| 281 |
+
"IsHospExac",
|
| 282 |
+
"IsCommExac",
|
| 283 |
+
"ExacDateUnknown",
|
| 284 |
+
"IsExac",
|
| 285 |
+
"IsHospAdmission",
|
| 286 |
+
]
|
| 287 |
+
].fillna(
|
| 288 |
+
0
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Infill some columns by StudyId to populate entire df
|
| 292 |
+
#df = copd.fill_column_by_patient(df=df, id_col="StudyId", col="FirstSubmissionDate")
|
| 293 |
+
df = model_h.fill_column_by_patient(df=df, id_col="StudyId", col="LatestPredictionDate")
|
| 294 |
+
df = model_h.fill_column_by_patient(df=df, id_col="StudyId", col="PatientId")
|
| 295 |
+
|
| 296 |
+
# Retain only dates before the end of each patient's data window
|
| 297 |
+
df = df[df.DateOfEvent <= df.LatestPredictionDate]
|
| 298 |
+
|
| 299 |
+
print("Starting number of exacerbations: {}".format(df.IsExac.sum()))
|
| 300 |
+
print(
|
| 301 |
+
"Number of exacerbations during COPD service: {}".format(
|
| 302 |
+
len(df[(df.IsExac == 1) & (df.DateOfEvent >= df.EarliestDataDate)])
|
| 303 |
+
)
|
| 304 |
+
)
|
| 305 |
+
print(
|
| 306 |
+
"Number of unique exacerbation patients: {}".format(
|
| 307 |
+
len(df[df.IsExac == 1].PatientId.unique())
|
| 308 |
+
)
|
| 309 |
+
)
|
| 310 |
+
print(
|
| 311 |
+
"Exacerbation breakdown: {} hospital, {} patient reported and {} overlapping".format(
|
| 312 |
+
df.IsHospExac.sum(),
|
| 313 |
+
df.IsCommExac.sum(),
|
| 314 |
+
len(df.loc[(df.IsCommExac == 1) & (df.IsHospExac == 1)]),
|
| 315 |
+
)
|
| 316 |
+
)
|
| 317 |
+
print(
|
| 318 |
+
"Number of hospital exacerbations during COPD service: {} ({} unique patients)".format(
|
| 319 |
+
len(df[(df.IsHospExac == 1) & (df.DateOfEvent >= df.EarliestDataDate)]),
|
| 320 |
+
len(
|
| 321 |
+
df[
|
| 322 |
+
(df.IsHospExac == 1) & (df.DateOfEvent >= df.EarliestDataDate)
|
| 323 |
+
].StudyId.unique()
|
| 324 |
+
),
|
| 325 |
+
)
|
| 326 |
+
)
|
| 327 |
+
print(
|
| 328 |
+
"Community exacerbations from weekly PROs: {} ({} unique patients)".format(
|
| 329 |
+
len(df[df.IsCommExac == 1]), len(df[df.IsCommExac == 1].StudyId.unique())
|
| 330 |
+
)
|
| 331 |
+
)
|
| 332 |
+
print(
|
| 333 |
+
"Number of patient reported exacerbations with unknown dates: {} ({} overlapping\
|
| 334 |
+
with hospital events)".format(
|
| 335 |
+
df.ExacDateUnknown.sum(),
|
| 336 |
+
len(df[(df.IsHospExac == 1) & (df.ExacDateUnknown == 1)]),
|
| 337 |
+
)
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Check for any patient reported events with unknown dates that occurred on the same day
|
| 341 |
+
# as a hospital event. Hospital events are trusted so set the date to known
|
| 342 |
+
df.loc[(df.IsCommExac == 1) & (df.IsHospExac == 1), "ExacDateUnknown"] = 0
|
| 343 |
+
print("Remaining exacerbations with unknown dates: {}".format(df.ExacDateUnknown.sum()))
|
| 344 |
+
|
| 345 |
+
############################################################################
|
| 346 |
+
# Implement PRO LOGIC on hospital and patient reported exacerbation events
|
| 347 |
+
############################################################################
|
| 348 |
+
|
| 349 |
+
# Define min and max days for PRO LOGIC. No predictions made or data used within
|
| 350 |
+
# logic_min_days after an exacerbation. Events falling between logic_min_days and
|
| 351 |
+
# logic_max_days after an event are subject to the weekly rescue med LOGIC criterion
|
| 352 |
+
logic_min_days = config['model_settings']['pro_logic_min_days_after_exac']
|
| 353 |
+
logic_max_days = config['model_settings']['pro_logic_max_days_after_exac']
|
| 354 |
+
|
| 355 |
+
# Calculate the days since the previous exacerbation for all patient days.
|
| 356 |
+
df = (
|
| 357 |
+
df.groupby("StudyId")
|
| 358 |
+
.apply(
|
| 359 |
+
lambda x: model_h.calculate_days_since_last_event(
|
| 360 |
+
df=x, event_col="IsExac", output_col="DaysSinceLastExac"
|
| 361 |
+
)
|
| 362 |
+
)
|
| 363 |
+
.reset_index(drop=True)
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
# Apply exclusion period following all exacerbations
|
| 367 |
+
df["RemoveRow"] = model_h.minimum_period_between_exacerbations(
|
| 368 |
+
df, minimum_days=logic_min_days
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
# Do not remove hospital exacerbations even if they get flagged up by PRO logic
|
| 372 |
+
df["RemoveRow"] = np.where(df["IsHospExac"] == 1, 0, df["RemoveRow"])
|
| 373 |
+
|
| 374 |
+
print(
|
| 375 |
+
"Number of community exacerbations excluded by PRO LOGIC {} day criterion: {}".format(
|
| 376 |
+
logic_min_days, len(df[(df.IsExac == 1) & (df.RemoveRow == 1)])
|
| 377 |
+
)
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Apply criterion for negative weekly Q5 responses - doesn't capture anything post Q5
|
| 381 |
+
# change
|
| 382 |
+
consecutive_replies = config['model_settings']['neg_consecutive_q5_replies']
|
| 383 |
+
df = model_h.apply_logic_response_criterion(
|
| 384 |
+
df,
|
| 385 |
+
minimum_period=logic_min_days,
|
| 386 |
+
maximum_period=logic_max_days,
|
| 387 |
+
N=consecutive_replies,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# Do not remove hospital exacerbations even if they get flagged up by PRO logic
|
| 391 |
+
df["RemoveExac"] = np.where(df["IsHospExac"] == 1, 0, df["RemoveExac"])
|
| 392 |
+
|
| 393 |
+
print(
|
| 394 |
+
"Weekly rescue med (Q5) criterion applied to events occurring between {} and {} \
|
| 395 |
+
days after a previous event. {} consecutive negative replies required for the event to \
|
| 396 |
+
count as a new event".format(
|
| 397 |
+
logic_min_days, logic_max_days, consecutive_replies
|
| 398 |
+
)
|
| 399 |
+
)
|
| 400 |
+
print(
|
| 401 |
+
"Number of exacerbations excluded by PRO LOGIC Q5 response criterion: {}".format(
|
| 402 |
+
df.RemoveExac.sum()
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
print(
|
| 406 |
+
"Earliest and latest exacerbations excluded: {}, {}".format(
|
| 407 |
+
df[df.RemoveExac == 1].DateOfEvent.min(),
|
| 408 |
+
df[df.RemoveExac == 1].DateOfEvent.max(),
|
| 409 |
+
)
|
| 410 |
+
)
|
| 411 |
+
print(
|
| 412 |
+
"Remaining number of exacerbations: {}".format(
|
| 413 |
+
len(df[(df.IsExac == 1) & (df.RemoveRow != 1) & (df.RemoveExac != 1)])
|
| 414 |
+
)
|
| 415 |
+
)
|
| 416 |
+
print(
|
| 417 |
+
"Remaining exacerbations with unknown dates: {}".format(
|
| 418 |
+
len(df[(df.ExacDateUnknown == 1) & (df.RemoveRow != 1) & (df.RemoveExac != 1)])
|
| 419 |
+
)
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
# Remove data between segments of prolonged events, count only first occurrence
|
| 423 |
+
df = model_h.remove_data_between_exacerbations(df)
|
| 424 |
+
|
| 425 |
+
# Remove 7 days before each reported exacerbation within unknown date (meds in last week)
|
| 426 |
+
df = model_h.remove_unknown_date_exacerbations(df)
|
| 427 |
+
|
| 428 |
+
# Remove rows flagged as to remove
|
| 429 |
+
df = df[df["RemoveRow"] != 1]
|
| 430 |
+
|
| 431 |
+
# New df with unwanted rows removed for events breakdown.
|
| 432 |
+
print("---Final exacerbation counts---")
|
| 433 |
+
print("Final number of exacerbations: {}".format(df.IsExac.sum()))
|
| 434 |
+
exac_patients = pd.Series(df[df.IsExac == 1].StudyId.unique())
|
| 435 |
+
print(
|
| 436 |
+
"Number of unique exacerbation patients: {} ({} RC and {} SU)".format(
|
| 437 |
+
len(exac_patients),
|
| 438 |
+
exac_patients.str.startswith("RC").sum(),
|
| 439 |
+
exac_patients.str.startswith("SU").sum(),
|
| 440 |
+
)
|
| 441 |
+
)
|
| 442 |
+
print(
|
| 443 |
+
"Exacerbation breakdown: {} hospital, {} patient reported and {} overlapping".format(
|
| 444 |
+
df.IsHospExac.sum(),
|
| 445 |
+
df.IsCommExac.sum(),
|
| 446 |
+
len(df.loc[(df.IsCommExac == 1) & (df.IsHospExac == 1)]),
|
| 447 |
+
)
|
| 448 |
+
)
|
| 449 |
+
df.to_pickle("./data/hosp_comm_exacs.pkl")
|
| 450 |
+
|
| 451 |
+
############################################################################
|
| 452 |
+
# Calculate the number of rows to include per patient in the dataset. This
|
| 453 |
+
# is calculated based on the average number of exacerbations per patient and
|
| 454 |
+
# is then adjusted to the average time within the service
|
| 455 |
+
############################################################################
|
| 456 |
+
|
| 457 |
+
# Calculate the average time patients have data recorded in the COPD service
|
| 458 |
+
service_time = df[["StudyId", "LatestPredictionDate", "EarliestDataDate"]]
|
| 459 |
+
service_time = service_time.drop_duplicates(subset="StudyId", keep="first")
|
| 460 |
+
service_time["ServiceTime"] = (
|
| 461 |
+
service_time["LatestPredictionDate"] - service_time["EarliestDataDate"]
|
| 462 |
+
).dt.days
|
| 463 |
+
avg_service_time = sum(service_time["ServiceTime"]) / len(service_time["ServiceTime"])
|
| 464 |
+
avg_service_time_months = round(avg_service_time / 30)
|
| 465 |
+
print("Average time in service (days):", avg_service_time)
|
| 466 |
+
print("Average time in service (months):", avg_service_time_months)
|
| 467 |
+
|
| 468 |
+
# Calculate the average number of exacerberations per patient
|
| 469 |
+
avg_exac_per_patient = round(
|
| 470 |
+
len(df[df["IsExac"] == 1]) / df[df["IsExac"] == 1][["StudyId"]].nunique().item(), 2
|
| 471 |
+
)
|
| 472 |
+
print(
|
| 473 |
+
"Number of exac/patient/months: {} exacerbations/patient in {} months".format(
|
| 474 |
+
avg_exac_per_patient, avg_service_time_months
|
| 475 |
+
)
|
| 476 |
+
)
|
| 477 |
+
print(
|
| 478 |
+
"On average, 1 exacerbation occurs in a patient every: {} months".format(
|
| 479 |
+
round(avg_service_time_months / avg_exac_per_patient, 2)
|
| 480 |
+
)
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
#################################################################
|
| 484 |
+
# Calculate index dates. 1 row/patient for every 5 months in service.
|
| 485 |
+
#################################################################
|
| 486 |
+
|
| 487 |
+
# Obtain the number of rows required per patient.
|
| 488 |
+
service_time["NumRows"] = round(service_time["ServiceTime"] / config['model_settings']['one_row_per_days_in_service']).astype("int")
|
| 489 |
+
patient_details = pd.merge(
|
| 490 |
+
patient_details, service_time[["StudyId", "NumRows"]], on="StudyId", how="left"
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# Calculate the number of days between earliest and latest index
|
| 494 |
+
patient_details["NumDaysPossibleIndex"] = (
|
| 495 |
+
patient_details["LatestIndexDate"] - patient_details["EarliestIndexDate"]
|
| 496 |
+
).dt.days
|
| 497 |
+
patient_details.to_csv("./data/pat_details_to_calc_index_dt.csv", index=False)
|
| 498 |
+
|
| 499 |
+
# Make sure the number of rows isn't larger than the number of possible index dates
|
| 500 |
+
patient_details["NumRows"] = np.where(patient_details["NumRows"] > patient_details["NumDaysPossibleIndex"], patient_details["NumDaysPossibleIndex"], patient_details["NumRows"])
|
| 501 |
+
|
| 502 |
+
# Generate random index dates
|
| 503 |
+
# Multiple seeds tested to identify the random index dates that give a good
|
| 504 |
+
# distribution across months. Seed chosen as 2188398760 from check_index_date_dist.py
|
| 505 |
+
random_seed_general = config['model_settings']['index_date_generation_master_seed']
|
| 506 |
+
random.seed(random_seed_general)
|
| 507 |
+
|
| 508 |
+
# Create different random seeds for each patient
|
| 509 |
+
patient_details["RandomSeed"] = random.sample(
|
| 510 |
+
range(0, 2**32), patient_details.shape[0]
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Create random index dates for each patient based on their random seed
|
| 514 |
+
rand_days_dict = {}
|
| 515 |
+
rand_date_dict = {}
|
| 516 |
+
for index, row in patient_details.iterrows():
|
| 517 |
+
np.random.seed(row["RandomSeed"])
|
| 518 |
+
rand_days_dict[row["StudyId"]] = np.random.choice(
|
| 519 |
+
row["NumDaysPossibleIndex"], size=row["NumRows"], replace=False
|
| 520 |
+
)
|
| 521 |
+
rand_date_dict[row["StudyId"]] = [
|
| 522 |
+
row["EarliestIndexDate"] + timedelta(days=int(day))
|
| 523 |
+
for day in rand_days_dict[row["StudyId"]]
|
| 524 |
+
]
|
| 525 |
+
|
| 526 |
+
# Create df from dictionaries containing random index dates
|
| 527 |
+
index_date_df = pd.DataFrame.from_dict(rand_date_dict, orient="index").reset_index()
|
| 528 |
+
index_date_df = index_date_df.rename(columns={"index": "StudyId"})
|
| 529 |
+
|
| 530 |
+
# Convert the multiple columns containing index dates to one column
|
| 531 |
+
index_date_df = (
|
| 532 |
+
pd.melt(index_date_df, id_vars=["StudyId"], value_name="IndexDate")
|
| 533 |
+
.drop(["variable"], axis=1)
|
| 534 |
+
.sort_values(by=["StudyId", "IndexDate"])
|
| 535 |
+
)
|
| 536 |
+
index_date_df = index_date_df.dropna()
|
| 537 |
+
index_date_df = index_date_df.reset_index(drop=True)
|
| 538 |
+
|
| 539 |
+
# Join index dates with exacerbation events
|
| 540 |
+
exac_events = pd.merge(index_date_df, df, on="StudyId", how="left")
|
| 541 |
+
exac_events["IndexDate"] = pd.to_datetime(exac_events["IndexDate"], utc=True)
|
| 542 |
+
|
| 543 |
+
# Calculate whether an exacerbation event occurred within the model time window (3 months)
|
| 544 |
+
# after the index date
|
| 545 |
+
exac_events["TimeToEvent"] = (
|
| 546 |
+
exac_events["DateOfEvent"] - exac_events["IndexDate"]
|
| 547 |
+
).dt.days
|
| 548 |
+
exac_events["ExacWithin3Months"] = np.where(
|
| 549 |
+
(exac_events["TimeToEvent"].between(1, config['model_settings']['prediction_window'], inclusive="both"))
|
| 550 |
+
& (exac_events["IsExac"] == 1),
|
| 551 |
+
1,
|
| 552 |
+
0,
|
| 553 |
+
)
|
| 554 |
+
exac_events["HospExacWithin3Months"] = np.where(
|
| 555 |
+
(exac_events["TimeToEvent"].between(1, config['model_settings']['prediction_window'], inclusive="both"))
|
| 556 |
+
& (exac_events["IsHospExac"] == 1),
|
| 557 |
+
1,
|
| 558 |
+
0,
|
| 559 |
+
)
|
| 560 |
+
exac_events["CommExacWithin3Months"] = np.where(
|
| 561 |
+
(exac_events["TimeToEvent"].between(1, config['model_settings']['prediction_window'], inclusive="both"))
|
| 562 |
+
& (exac_events["IsCommExac"] == 1),
|
| 563 |
+
1,
|
| 564 |
+
0,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
exac_events = exac_events.sort_values(
|
| 568 |
+
by=["StudyId", "IndexDate", "ExacWithin3Months"], ascending=[True, True, False]
|
| 569 |
+
)
|
| 570 |
+
exac_events = exac_events.drop_duplicates(subset=["StudyId", "IndexDate"], keep="first")
|
| 571 |
+
exac_events = exac_events[
|
| 572 |
+
[
|
| 573 |
+
"StudyId",
|
| 574 |
+
"PatientId",
|
| 575 |
+
"IndexDate",
|
| 576 |
+
"DateOfBirth",
|
| 577 |
+
"Sex",
|
| 578 |
+
"ExacWithin3Months",
|
| 579 |
+
"HospExacWithin3Months",
|
| 580 |
+
"CommExacWithin3Months",
|
| 581 |
+
]
|
| 582 |
+
]
|
| 583 |
+
|
| 584 |
+
# Save exac_events
|
| 585 |
+
exac_events.to_pickle("./data/patient_labels_forward_val_hosp_comm.pkl")
|
| 586 |
+
|
| 587 |
+
# Summary info
|
| 588 |
+
class_distribution = (
|
| 589 |
+
exac_events.groupby("ExacWithin3Months").count()[["StudyId"]].reset_index()
|
| 590 |
+
)
|
| 591 |
+
class_distribution.plot.bar(x="ExacWithin3Months", y="StudyId")
|
| 592 |
+
plt.savefig(
|
| 593 |
+
"./plots/class_distributions/final_seed_"
|
| 594 |
+
+ str(random_seed_general)
|
| 595 |
+
+ "_class_distribution_hosp_comm.png",
|
| 596 |
+
bbox_inches="tight",
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
print("---Summary info after setting up labels---")
|
| 600 |
+
print("Number of unique patients:", exac_events["StudyId"].nunique())
|
| 601 |
+
print("Number of rows:", len(exac_events))
|
| 602 |
+
print(
|
| 603 |
+
"Number of exacerbations within 3 months of index date:",
|
| 604 |
+
len(exac_events[exac_events["ExacWithin3Months"] == 1]),
|
| 605 |
+
)
|
| 606 |
+
print(
|
| 607 |
+
"Percentage positive class (num exac/total rows): {} %".format(
|
| 608 |
+
round(
|
| 609 |
+
(len(exac_events[exac_events["ExacWithin3Months"] == 1]) / len(exac_events))
|
| 610 |
+
* 100,
|
| 611 |
+
2,
|
| 612 |
+
)
|
| 613 |
+
)
|
| 614 |
+
)
|
| 615 |
+
print(
|
| 616 |
+
"Percentage negative class: {} %".format(
|
| 617 |
+
round(
|
| 618 |
+
(len(exac_events[exac_events["ExacWithin3Months"] == 0]) / len(exac_events))
|
| 619 |
+
* 100,
|
| 620 |
+
2,
|
| 621 |
+
)
|
| 622 |
+
)
|
| 623 |
+
)
|
| 624 |
+
print(
|
| 625 |
+
"Percentage hospital exacs: {} %".format(
|
| 626 |
+
round(
|
| 627 |
+
(len(exac_events[exac_events["HospExacWithin3Months"] == 1]) / len(exac_events))
|
| 628 |
+
* 100,
|
| 629 |
+
2,
|
| 630 |
+
)
|
| 631 |
+
)
|
| 632 |
+
)
|
| 633 |
+
print(
|
| 634 |
+
"Percentage community exacs: {} %".format(
|
| 635 |
+
round(
|
| 636 |
+
(len(exac_events[exac_events["CommExacWithin3Months"] == 1]) / len(exac_events))
|
| 637 |
+
* 100,
|
| 638 |
+
2,
|
| 639 |
+
)
|
| 640 |
+
)
|
| 641 |
+
)
|
| 642 |
+
print("Class balance:")
|
| 643 |
+
print(class_distribution)
|
training/setup_labels_hosp_comm.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script uses both hospital and community exacerbation events.
|
| 3 |
+
|
| 4 |
+
Collate all hospital, clincian verified and patient reported events and apply
|
| 5 |
+
PRO LOGIC to determine the number of exacerbation events. Use exacerbation events to
|
| 6 |
+
determine the number of rows required per patient in the data and generate random
|
| 7 |
+
index dates and setup labels.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import model_h
|
| 11 |
+
import numpy as np
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
from datetime import timedelta
|
| 17 |
+
import random
|
| 18 |
+
import yaml
|
| 19 |
+
|
| 20 |
+
# Need to have pyyaml in environment to use this
|
| 21 |
+
with open("./training/config.yaml", "r") as config:
|
| 22 |
+
config = yaml.safe_load(config)
|
| 23 |
+
|
| 24 |
+
# Setup log file
|
| 25 |
+
log = open(
|
| 26 |
+
os.path.join(
|
| 27 |
+
config["outputs"]["logging_dir"],
|
| 28 |
+
"setup_labels" + config["model_settings"]["model_type"] + "2023.log",
|
| 29 |
+
),
|
| 30 |
+
"w",
|
| 31 |
+
)
|
| 32 |
+
sys.stdout = log
|
| 33 |
+
|
| 34 |
+
############################################################################
|
| 35 |
+
# Define model cohort and training data windows
|
| 36 |
+
############################################################################
|
| 37 |
+
|
| 38 |
+
# Read relevant info from patient details
|
| 39 |
+
patient_details = pd.read_csv(
|
| 40 |
+
config["inputs"]["raw_data_paths"]["patient_details"],
|
| 41 |
+
usecols=[
|
| 42 |
+
"PatientId",
|
| 43 |
+
"FirstSubmissionDate",
|
| 44 |
+
"MostRecentSubmissionDate",
|
| 45 |
+
"DateOfBirth",
|
| 46 |
+
"Sex",
|
| 47 |
+
"StudyId",
|
| 48 |
+
],
|
| 49 |
+
delimiter="|",
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# Select patients for inclusion (those with up to date events in service)
|
| 53 |
+
# Original RECEIVER cohort study id list
|
| 54 |
+
receiver_patients = ["RC{:02d}".format(i) for i in range(1, 85)]
|
| 55 |
+
# This patient needs removing
|
| 56 |
+
receiver_patients.remove("RC34")
|
| 57 |
+
|
| 58 |
+
# Scale up patients (subset)
|
| 59 |
+
scaleup_patients = ["SU{:02d}".format(i) for i in range(1, 219)]
|
| 60 |
+
|
| 61 |
+
# List of all valid patients for modelling
|
| 62 |
+
valid_patients = receiver_patients + scaleup_patients
|
| 63 |
+
|
| 64 |
+
# Filter for valid patients accounting for white spaces in StudyId (e.g. RC 26 and RC 52)
|
| 65 |
+
patient_details = patient_details[
|
| 66 |
+
patient_details.StudyId.str.replace(" ", "").isin(valid_patients)
|
| 67 |
+
]
|
| 68 |
+
# Select only non null entries in patient data start/end dates
|
| 69 |
+
patient_details = patient_details[
|
| 70 |
+
(patient_details.FirstSubmissionDate.notna())
|
| 71 |
+
& (patient_details.MostRecentSubmissionDate.notna())
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
# Get death data
|
| 75 |
+
patient_deaths = pd.read_csv(
|
| 76 |
+
config["inputs"]["raw_data_paths"]["patient_events"],
|
| 77 |
+
usecols=[
|
| 78 |
+
"PatientId",
|
| 79 |
+
"DateOfEvent",
|
| 80 |
+
"EventType",
|
| 81 |
+
],
|
| 82 |
+
delimiter="|",
|
| 83 |
+
)
|
| 84 |
+
patient_deaths = patient_deaths[patient_deaths["EventType"] == "Death"]
|
| 85 |
+
patient_deaths = patient_deaths.rename(columns={"DateOfEvent": "DeathDate"})
|
| 86 |
+
patient_deaths = patient_deaths.drop(columns=["EventType"])
|
| 87 |
+
|
| 88 |
+
# Merge patient details with deaths
|
| 89 |
+
patient_details = patient_details.merge(patient_deaths, on="PatientId", how="left")
|
| 90 |
+
|
| 91 |
+
############################################################################
|
| 92 |
+
# Define training data windows
|
| 93 |
+
############################################################################
|
| 94 |
+
|
| 95 |
+
# Create a column stating the latest date permitted based on events added to service data
|
| 96 |
+
patient_details["LatestPredictionDate"] = config["model_settings"][
|
| 97 |
+
"latest_date_before_bug_break"
|
| 98 |
+
]
|
| 99 |
+
|
| 100 |
+
# Create a column stating when the events start again
|
| 101 |
+
patient_details["AfterGapStartDate"] = config["model_settings"][
|
| 102 |
+
"after_bug_fixed_start_date"
|
| 103 |
+
]
|
| 104 |
+
patient_details["DataEndDate"] = config["model_settings"]["training_data_end_date"]
|
| 105 |
+
|
| 106 |
+
date_cols = [
|
| 107 |
+
"FirstSubmissionDate",
|
| 108 |
+
"MostRecentSubmissionDate",
|
| 109 |
+
"LatestPredictionDate",
|
| 110 |
+
"AfterGapStartDate",
|
| 111 |
+
"DataEndDate",
|
| 112 |
+
"DeathDate",
|
| 113 |
+
]
|
| 114 |
+
patient_details[date_cols] = patient_details[date_cols].apply(
|
| 115 |
+
lambda x: pd.to_datetime(x, utc=True, format="mixed").dt.normalize(), axis=1
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# Choose the earlier date out of the patient's last submission, death date and the latest
|
| 119 |
+
# COPD data events
|
| 120 |
+
patient_details["LatestPredictionDate"] = patient_details[
|
| 121 |
+
["MostRecentSubmissionDate", "LatestPredictionDate", "DeathDate"]
|
| 122 |
+
].min(axis=1)
|
| 123 |
+
|
| 124 |
+
patient_details["DataEndDate"] = patient_details[
|
| 125 |
+
["MostRecentSubmissionDate", "DataEndDate", "DeathDate"]
|
| 126 |
+
].min(axis=1)
|
| 127 |
+
|
| 128 |
+
# Calculate the latest date that the index date can be for each patient
|
| 129 |
+
patient_details["LatestIndexDate"] = patient_details[
|
| 130 |
+
"LatestPredictionDate"
|
| 131 |
+
] - pd.DateOffset(days=config["model_settings"]["prediction_window"])
|
| 132 |
+
|
| 133 |
+
patient_details["LatestIndexAfterGap"] = patient_details["DataEndDate"] - pd.DateOffset(
|
| 134 |
+
days=config["model_settings"]["prediction_window"]
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# Add 6 months to start of data window to allow enough of a lookback period
|
| 138 |
+
patient_details["EarliestIndexDate"] = patient_details[
|
| 139 |
+
"FirstSubmissionDate"
|
| 140 |
+
] + pd.DateOffset(days=config["model_settings"]["lookback_period"])
|
| 141 |
+
|
| 142 |
+
patient_details["EarliestIndexAfterGap"] = patient_details[
|
| 143 |
+
"AfterGapStartDate"
|
| 144 |
+
] + pd.DateOffset(days=config["model_settings"]["lookback_period"])
|
| 145 |
+
|
| 146 |
+
# Remove any patients for whom the index start date overlaps the last index
|
| 147 |
+
# date, i.e. they have too short a window of data
|
| 148 |
+
print("Number of total patients", len(patient_details))
|
| 149 |
+
print(
|
| 150 |
+
"Number of patients with too short of a window of data:",
|
| 151 |
+
len(
|
| 152 |
+
patient_details[
|
| 153 |
+
patient_details["EarliestIndexDate"] > patient_details["LatestIndexDate"]
|
| 154 |
+
]
|
| 155 |
+
),
|
| 156 |
+
)
|
| 157 |
+
patient_details = patient_details[
|
| 158 |
+
patient_details["EarliestIndexDate"] < patient_details["LatestIndexDate"]
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
patient_details["DatesAfterGap"] = np.where(
|
| 162 |
+
patient_details["EarliestIndexAfterGap"] > patient_details["LatestIndexAfterGap"],
|
| 163 |
+
False,
|
| 164 |
+
True,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Calculate length in service
|
| 168 |
+
patient_details["FirstLength"] = (
|
| 169 |
+
patient_details["LatestIndexDate"] - patient_details["EarliestIndexDate"]
|
| 170 |
+
).dt.days
|
| 171 |
+
patient_details["SecondLength"] = (
|
| 172 |
+
patient_details["LatestIndexAfterGap"] - patient_details["EarliestIndexAfterGap"]
|
| 173 |
+
).dt.days
|
| 174 |
+
patient_details["LengthInService"] = np.where(
|
| 175 |
+
patient_details["DatesAfterGap"] == True,
|
| 176 |
+
patient_details["FirstLength"] + patient_details["SecondLength"],
|
| 177 |
+
patient_details["FirstLength"],
|
| 178 |
+
)
|
| 179 |
+
patient_details["TotalLength1"] = (
|
| 180 |
+
patient_details["LatestPredictionDate"] - patient_details["FirstSubmissionDate"]
|
| 181 |
+
).dt.days
|
| 182 |
+
patient_details["TotalLength2"] = (
|
| 183 |
+
patient_details["DataEndDate"] - patient_details["AfterGapStartDate"]
|
| 184 |
+
).dt.days
|
| 185 |
+
patient_details["TotalLengthInService"] = np.where(
|
| 186 |
+
patient_details["DatesAfterGap"] == True,
|
| 187 |
+
patient_details["TotalLength1"] + patient_details["TotalLength2"],
|
| 188 |
+
patient_details["TotalLength1"],
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# Save patient details
|
| 192 |
+
patient_details.to_pickle(
|
| 193 |
+
os.path.join(config["outputs"]["output_data_dir"], "patient_details.pkl")
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# List of remaining patients
|
| 197 |
+
model_patients = list(patient_details.PatientId.unique())
|
| 198 |
+
model_study_ids = list(patient_details.StudyId.unique())
|
| 199 |
+
|
| 200 |
+
print(
|
| 201 |
+
"Model cohort: {} patients. {} RECEIVER and {} SU".format(
|
| 202 |
+
len(model_patients),
|
| 203 |
+
len(patient_details[patient_details["StudyId"].str.startswith("RC")]),
|
| 204 |
+
len(patient_details[patient_details["StudyId"].str.startswith("SU")]),
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
df1 = patient_details[
|
| 209 |
+
[
|
| 210 |
+
"PatientId",
|
| 211 |
+
"DateOfBirth",
|
| 212 |
+
"Sex",
|
| 213 |
+
"StudyId",
|
| 214 |
+
"FirstSubmissionDate",
|
| 215 |
+
"EarliestIndexDate",
|
| 216 |
+
"LatestIndexDate",
|
| 217 |
+
"LatestPredictionDate",
|
| 218 |
+
"AfterGapStartDate",
|
| 219 |
+
"EarliestIndexAfterGap",
|
| 220 |
+
"LatestIndexAfterGap",
|
| 221 |
+
"DataEndDate",
|
| 222 |
+
]
|
| 223 |
+
].copy()
|
| 224 |
+
df2 = df1.copy()
|
| 225 |
+
|
| 226 |
+
# Create a row per day between the FirstSubmissionDate and the LatestPredictionDate
|
| 227 |
+
df1["DateOfEvent"] = df1.apply(
|
| 228 |
+
lambda x: pd.date_range(x.FirstSubmissionDate, x.LatestPredictionDate, freq="D"),
|
| 229 |
+
axis=1,
|
| 230 |
+
)
|
| 231 |
+
# Create a row per day between AfterGapStartDate and DataEndDate
|
| 232 |
+
df2["DateOfEvent"] = df2.apply(
|
| 233 |
+
lambda x: pd.date_range(x.AfterGapStartDate, x.DataEndDate, freq="D"),
|
| 234 |
+
axis=1,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# Combine dfs from before and after the time gap where exac data is unreliable
|
| 238 |
+
df1 = df1.explode("DateOfEvent").reset_index(drop=True)
|
| 239 |
+
df2 = df2.explode("DateOfEvent").reset_index(drop=True)
|
| 240 |
+
df2 = df2.dropna(subset=["DateOfEvent"])
|
| 241 |
+
df = pd.concat([df1, df2])
|
| 242 |
+
df = df.sort_values(by=["StudyId", "DateOfEvent"])
|
| 243 |
+
|
| 244 |
+
############################################################################
|
| 245 |
+
# Extract hospital exacerbations and admissions from COPD service data
|
| 246 |
+
############################################################################
|
| 247 |
+
|
| 248 |
+
# Contains exacerbations among other event types
|
| 249 |
+
patient_events = pd.read_csv(
|
| 250 |
+
config["inputs"]["raw_data_paths"]["patient_events"],
|
| 251 |
+
delimiter="|",
|
| 252 |
+
usecols=["PatientId", "DateOfEvent", "EventType"],
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Filter for only patients in model cohort
|
| 256 |
+
patient_events = patient_events[patient_events.PatientId.isin(model_patients)]
|
| 257 |
+
|
| 258 |
+
# Identify hospital exacerbation events
|
| 259 |
+
patient_events["IsHospExac"] = model_h.define_service_exac_event(
|
| 260 |
+
events=patient_events.EventType, include_community=False
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Identify hospital admissions (all causes)
|
| 264 |
+
patient_events["IsHospAdmission"] = model_h.define_hospital_admission(
|
| 265 |
+
patient_events.EventType
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
admissions = patient_events[patient_events.IsHospAdmission == 1][
|
| 269 |
+
["PatientId", "DateOfEvent", "IsHospAdmission"]
|
| 270 |
+
]
|
| 271 |
+
hosp_exacs = patient_events[patient_events.IsHospExac == 1][
|
| 272 |
+
["PatientId", "DateOfEvent", "IsHospExac"]
|
| 273 |
+
]
|
| 274 |
+
admissions["DateOfEvent"] = pd.to_datetime(
|
| 275 |
+
admissions.DateOfEvent, utc=True
|
| 276 |
+
).dt.normalize()
|
| 277 |
+
hosp_exacs["DateOfEvent"] = pd.to_datetime(
|
| 278 |
+
hosp_exacs.DateOfEvent, utc=True
|
| 279 |
+
).dt.normalize()
|
| 280 |
+
|
| 281 |
+
hosp_exacs = hosp_exacs.drop_duplicates()
|
| 282 |
+
admissions = admissions.drop_duplicates()
|
| 283 |
+
|
| 284 |
+
# Save hospital exacerbations and admissions data
|
| 285 |
+
hosp_exacs.to_pickle(
|
| 286 |
+
os.path.join(config["outputs"]["output_data_dir"], "hospital_exacerbations.pkl")
|
| 287 |
+
)
|
| 288 |
+
admissions.to_pickle(
|
| 289 |
+
os.path.join(config["outputs"]["output_data_dir"], "hospital_admissions.pkl")
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
############################################################################
|
| 293 |
+
# Extract patient reported exacerbation events
|
| 294 |
+
############################################################################
|
| 295 |
+
|
| 296 |
+
########################
|
| 297 |
+
# Data post Q5 change
|
| 298 |
+
#######################
|
| 299 |
+
|
| 300 |
+
# Read file containing patient reported events (not patient_events because it contains
|
| 301 |
+
# the dates when patients answered PROs and not which date they reported as having taken
|
| 302 |
+
# their rescue meds)
|
| 303 |
+
symptom_diary = pd.read_csv(
|
| 304 |
+
config["inputs"]["raw_data_paths"]["pro_symptom_diary"],
|
| 305 |
+
usecols=[
|
| 306 |
+
"PatientId",
|
| 307 |
+
"StudyId",
|
| 308 |
+
"Score",
|
| 309 |
+
"SubmissionTime",
|
| 310 |
+
"SymptomDiaryQ5",
|
| 311 |
+
"SymptomDiaryQ11a",
|
| 312 |
+
"SymptomDiaryQ11b",
|
| 313 |
+
],
|
| 314 |
+
delimiter="|",
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
Q5ChangeDate = pd.to_datetime(config["model_settings"]["pro_q5_change_date"], utc=True)
|
| 318 |
+
symptom_diary = model_h.filter_symptom_diary(
|
| 319 |
+
df=symptom_diary, date_cutoff=Q5ChangeDate, patients=model_patients
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
weekly_pros = model_h.get_rescue_med_pro_responses(symptom_diary)
|
| 323 |
+
weekly_pros = model_h.set_pro_exac_dates(weekly_pros)
|
| 324 |
+
weekly_pros = weekly_pros[
|
| 325 |
+
[
|
| 326 |
+
"PatientId",
|
| 327 |
+
"Q5Answered",
|
| 328 |
+
"NegativeQ5",
|
| 329 |
+
"IsCommExac",
|
| 330 |
+
"DateOfEvent",
|
| 331 |
+
"ExacDateUnknown",
|
| 332 |
+
]
|
| 333 |
+
]
|
| 334 |
+
|
| 335 |
+
#########################
|
| 336 |
+
# Pre Q5 change events
|
| 337 |
+
#########################
|
| 338 |
+
|
| 339 |
+
# RECEIVER cohort - community events verified up to 16/03/21
|
| 340 |
+
receiver = pd.read_excel(
|
| 341 |
+
config["inputs"]["raw_data_paths"]["receiver_community_verified_events"]
|
| 342 |
+
)
|
| 343 |
+
receiver = receiver.rename(
|
| 344 |
+
columns={"Study number": "StudyId", "Exacerbation recorded": "DateRecorded"}
|
| 345 |
+
)
|
| 346 |
+
receiver_exacs = model_h.extract_clinician_verified_exacerbations(receiver)
|
| 347 |
+
|
| 348 |
+
# Scale up cohort - community events verified up to 17/05/2021
|
| 349 |
+
scaleup = pd.read_excel(
|
| 350 |
+
config["inputs"]["raw_data_paths"]["scale_up_community_verified_events"]
|
| 351 |
+
)
|
| 352 |
+
scaleup = scaleup.rename(
|
| 353 |
+
columns={"Study Number": "StudyId", "Date Exacerbation recorded": "DateRecorded"}
|
| 354 |
+
)
|
| 355 |
+
scaleup["StudyId"] = scaleup["StudyId"].ffill()
|
| 356 |
+
scaleup_exacs = model_h.extract_clinician_verified_exacerbations(scaleup)
|
| 357 |
+
|
| 358 |
+
# Combine RECEIVER and scale up events into one df
|
| 359 |
+
verified_exacs = pd.concat([receiver_exacs, scaleup_exacs])
|
| 360 |
+
verified_exacs = verified_exacs[verified_exacs.StudyId.isin(model_study_ids)]
|
| 361 |
+
|
| 362 |
+
####################################################################################
|
| 363 |
+
# Merge hospital and patient reported events with daily patient records
|
| 364 |
+
#
|
| 365 |
+
# Exacerbations occurring in Lenus service period include verified clinician events
|
| 366 |
+
# pre-April 2021 (after onboarding) and community exacerbations recorded in weekly
|
| 367 |
+
# PROs post-April 2021. Hospital exacerbations include exacerbations occuring during
|
| 368 |
+
# service period.
|
| 369 |
+
#####################################################################################
|
| 370 |
+
|
| 371 |
+
# Patient reported, clinician verified
|
| 372 |
+
df = df.merge(verified_exacs, on=["StudyId", "DateOfEvent"], how="left")
|
| 373 |
+
|
| 374 |
+
# Patient reported, new rescue med PRO (April 2021 onwards)
|
| 375 |
+
df = df.merge(weekly_pros, on=["PatientId", "DateOfEvent"], how="left")
|
| 376 |
+
|
| 377 |
+
# Hospital exacerbations
|
| 378 |
+
df = df.merge(hosp_exacs, on=["PatientId", "DateOfEvent"], how="left")
|
| 379 |
+
df = model_h.fill_column_by_patient(df=df, id_col="PatientId", col="StudyId")
|
| 380 |
+
|
| 381 |
+
# Hospital admissions
|
| 382 |
+
df = df.merge(admissions, on=["PatientId", "DateOfEvent"], how="left")
|
| 383 |
+
df = model_h.fill_column_by_patient(df=df, id_col="PatientId", col="StudyId")
|
| 384 |
+
|
| 385 |
+
# Combine cols from individual datasets into one
|
| 386 |
+
df["ExacDateUnknown"] = np.where(
|
| 387 |
+
(df.ExacDateUnknown_x == 1) | (df.ExacDateUnknown_y == 1), 1, 0
|
| 388 |
+
)
|
| 389 |
+
df["IsCommExac"] = np.where((df.IsCommExac_x == 1) | (df.IsCommExac_y == 1), 1, 0)
|
| 390 |
+
|
| 391 |
+
# Column for whether an exacerbation of any kind occurred on each date. To be filtered
|
| 392 |
+
# using (PRO) LOGIC
|
| 393 |
+
df["IsExac"] = np.where((df.IsCommExac == 1) | (df.IsHospExac == 1), 1, 0)
|
| 394 |
+
|
| 395 |
+
# Resample the df to one day per patient starting from the earliest record
|
| 396 |
+
df = (
|
| 397 |
+
df.set_index("DateOfEvent")
|
| 398 |
+
.groupby("StudyId")
|
| 399 |
+
.resample("D")
|
| 400 |
+
.asfreq()
|
| 401 |
+
.drop("StudyId", axis=1)
|
| 402 |
+
.reset_index()
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
# Infill binary cols with zero where applicable
|
| 406 |
+
df[
|
| 407 |
+
[
|
| 408 |
+
"Q5Answered",
|
| 409 |
+
"NegativeQ5",
|
| 410 |
+
"IsHospExac",
|
| 411 |
+
"IsCommExac",
|
| 412 |
+
"ExacDateUnknown",
|
| 413 |
+
"IsExac",
|
| 414 |
+
"IsHospAdmission",
|
| 415 |
+
]
|
| 416 |
+
] = df[
|
| 417 |
+
[
|
| 418 |
+
"Q5Answered",
|
| 419 |
+
"NegativeQ5",
|
| 420 |
+
"IsHospExac",
|
| 421 |
+
"IsCommExac",
|
| 422 |
+
"ExacDateUnknown",
|
| 423 |
+
"IsExac",
|
| 424 |
+
"IsHospAdmission",
|
| 425 |
+
]
|
| 426 |
+
].fillna(
|
| 427 |
+
0
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# Infill some columns by StudyId to populate entire df
|
| 431 |
+
df = model_h.fill_column_by_patient(df=df, id_col="StudyId", col="FirstSubmissionDate")
|
| 432 |
+
df = model_h.fill_column_by_patient(df=df, id_col="StudyId", col="LatestPredictionDate")
|
| 433 |
+
df = model_h.fill_column_by_patient(df=df, id_col="StudyId", col="PatientId")
|
| 434 |
+
|
| 435 |
+
print("Starting number of exacerbations: {}".format(df.IsExac.sum()))
|
| 436 |
+
print(
|
| 437 |
+
"Number of exacerbations during COPD service: {}".format(
|
| 438 |
+
len(df[(df.IsExac == 1) & (df.DateOfEvent >= df.FirstSubmissionDate)])
|
| 439 |
+
)
|
| 440 |
+
)
|
| 441 |
+
print(
|
| 442 |
+
"Number of unique exacerbation patients: {}".format(
|
| 443 |
+
len(df[df.IsExac == 1].PatientId.unique())
|
| 444 |
+
)
|
| 445 |
+
)
|
| 446 |
+
print(
|
| 447 |
+
"Exacerbation breakdown: {} hospital, {} patient reported and {} overlapping".format(
|
| 448 |
+
df.IsHospExac.sum(),
|
| 449 |
+
df.IsCommExac.sum(),
|
| 450 |
+
len(df.loc[(df.IsCommExac == 1) & (df.IsHospExac == 1)]),
|
| 451 |
+
)
|
| 452 |
+
)
|
| 453 |
+
print(
|
| 454 |
+
"Number of hospital exacerbations during COPD service: {} ({} unique patients)".format(
|
| 455 |
+
len(df[(df.IsHospExac == 1) & (df.DateOfEvent >= df.FirstSubmissionDate)]),
|
| 456 |
+
len(
|
| 457 |
+
df[
|
| 458 |
+
(df.IsHospExac == 1) & (df.DateOfEvent >= df.FirstSubmissionDate)
|
| 459 |
+
].StudyId.unique()
|
| 460 |
+
),
|
| 461 |
+
)
|
| 462 |
+
)
|
| 463 |
+
print(
|
| 464 |
+
"Clinician verified community exacerbations during COPD service: {} ({} unique patients)".format(
|
| 465 |
+
len(df[df.IsCommExac_x == 1]), len(df[df.IsCommExac_x == 1].StudyId.unique())
|
| 466 |
+
)
|
| 467 |
+
)
|
| 468 |
+
print(
|
| 469 |
+
"Community exacerbations from weekly PROs: {} ({} unique patients)".format(
|
| 470 |
+
len(df[df.IsCommExac_y == 1]), len(df[df.IsCommExac_y == 1].StudyId.unique())
|
| 471 |
+
)
|
| 472 |
+
)
|
| 473 |
+
print(
|
| 474 |
+
"Number of patient reported exacerbations with unknown dates: {} ({} overlapping\
|
| 475 |
+
with hospital events)".format(
|
| 476 |
+
df.ExacDateUnknown.sum(),
|
| 477 |
+
len(df[(df.IsHospExac == 1) & (df.ExacDateUnknown == 1)]),
|
| 478 |
+
)
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
# Check for any patient reported events with unknown dates that occurred on the same day
|
| 482 |
+
# as a hospital event. Hospital events are trusted so set the date to known
|
| 483 |
+
df.loc[(df.IsCommExac == 1) & (df.IsHospExac == 1), "ExacDateUnknown"] = 0
|
| 484 |
+
print("Remaining exacerbations with unknown dates: {}".format(df.ExacDateUnknown.sum()))
|
| 485 |
+
|
| 486 |
+
df = df.drop(
|
| 487 |
+
columns=["IsCommExac_x", "IsCommExac_y", "ExacDateUnknown_x", "ExacDateUnknown_y"]
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
############################################################################
|
| 491 |
+
# Implement PRO LOGIC on hospital and patient reported exacerbation events
|
| 492 |
+
############################################################################
|
| 493 |
+
|
| 494 |
+
# Define min and max days for PRO LOGIC. No predictions made or data used within
|
| 495 |
+
# logic_min_days after an exacerbation. Events falling between logic_min_days and
|
| 496 |
+
# logic_max_days after an event are subject to the weekly rescue med LOGIC criterion
|
| 497 |
+
logic_min_days = config["model_settings"]["pro_logic_min_days_after_exac"]
|
| 498 |
+
logic_max_days = config["model_settings"]["pro_logic_max_days_after_exac"]
|
| 499 |
+
|
| 500 |
+
# Calculate the days since the previous exacerbation for all patient days.
|
| 501 |
+
df = (
|
| 502 |
+
df.groupby("StudyId")
|
| 503 |
+
.apply(
|
| 504 |
+
lambda x: model_h.calculate_days_since_last_event(
|
| 505 |
+
df=x, event_col="IsExac", output_col="DaysSinceLastExac"
|
| 506 |
+
)
|
| 507 |
+
)
|
| 508 |
+
.reset_index(drop=True)
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# Apply exclusion period following all exacerbations
|
| 512 |
+
df["RemoveRow"] = model_h.minimum_period_between_exacerbations(
|
| 513 |
+
df, minimum_days=logic_min_days
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
# Do not remove hospital exacerbations even if they get flagged up by PRO logic
|
| 517 |
+
df["RemoveRow"] = np.where(df["IsHospExac"] == 1, 0, df["RemoveRow"])
|
| 518 |
+
|
| 519 |
+
print(
|
| 520 |
+
"Number of community exacerbations excluded by PRO LOGIC {} day criterion: {}".format(
|
| 521 |
+
logic_min_days, len(df[(df.IsExac == 1) & (df.RemoveRow == 1)])
|
| 522 |
+
)
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# Apply criterion for negative weekly Q5 responses - doesn't capture anything post Q5
|
| 526 |
+
# change
|
| 527 |
+
consecutive_replies = config["model_settings"]["neg_consecutive_q5_replies"]
|
| 528 |
+
df = model_h.apply_logic_response_criterion(
|
| 529 |
+
df,
|
| 530 |
+
minimum_period=logic_min_days,
|
| 531 |
+
maximum_period=logic_max_days,
|
| 532 |
+
N=consecutive_replies,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
# Do not remove hospital exacerbations even if they get flagged up by PRO logic
|
| 536 |
+
df["RemoveExac"] = np.where(df["IsHospExac"] == 1, 0, df["RemoveExac"])
|
| 537 |
+
|
| 538 |
+
print(
|
| 539 |
+
"Weekly rescue med (Q5) criterion applied to events occurring between {} and {} \
|
| 540 |
+
days after a previous event. {} consecutive negative replies required for the event to \
|
| 541 |
+
count as a new event".format(
|
| 542 |
+
logic_min_days, logic_max_days, consecutive_replies
|
| 543 |
+
)
|
| 544 |
+
)
|
| 545 |
+
print(
|
| 546 |
+
"Number of exacerbations excluded by PRO LOGIC Q5 response criterion: {}".format(
|
| 547 |
+
df.RemoveExac.sum()
|
| 548 |
+
)
|
| 549 |
+
)
|
| 550 |
+
print(
|
| 551 |
+
"Earliest and latest exacerbations excluded: {}, {}".format(
|
| 552 |
+
df[df.RemoveExac == 1].DateOfEvent.min(),
|
| 553 |
+
df[df.RemoveExac == 1].DateOfEvent.max(),
|
| 554 |
+
)
|
| 555 |
+
)
|
| 556 |
+
print(
|
| 557 |
+
"Remaining number of exacerbations: {}".format(
|
| 558 |
+
len(df[(df.IsExac == 1) & (df.RemoveRow != 1) & (df.RemoveExac != 1)])
|
| 559 |
+
)
|
| 560 |
+
)
|
| 561 |
+
print(
|
| 562 |
+
"Remaining exacerbations with unknown dates: {}".format(
|
| 563 |
+
len(df[(df.ExacDateUnknown == 1) & (df.RemoveRow != 1) & (df.RemoveExac != 1)])
|
| 564 |
+
)
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
# Remove data between segments of prolonged events, count only first occurrence
|
| 568 |
+
df = model_h.remove_data_between_exacerbations(df)
|
| 569 |
+
|
| 570 |
+
# Remove 7 days before each reported exacerbation within unknown date (meds in last week)
|
| 571 |
+
df = model_h.remove_unknown_date_exacerbations(df)
|
| 572 |
+
|
| 573 |
+
# Remove rows flagged as to remove
|
| 574 |
+
df = df[df["RemoveRow"] != 1]
|
| 575 |
+
|
| 576 |
+
# New df with unwanted rows removed for events breakdown.
|
| 577 |
+
print("---Final exacerbation counts---")
|
| 578 |
+
print("Final number of exacerbations: {}".format(df.IsExac.sum()))
|
| 579 |
+
exac_patients = pd.Series(df[df.IsExac == 1].StudyId.unique())
|
| 580 |
+
print(
|
| 581 |
+
"Number of unique exacerbation patients: {} ({} RC and {} SU)".format(
|
| 582 |
+
len(exac_patients),
|
| 583 |
+
exac_patients.str.startswith("RC").sum(),
|
| 584 |
+
exac_patients.str.startswith("SU").sum(),
|
| 585 |
+
)
|
| 586 |
+
)
|
| 587 |
+
print(
|
| 588 |
+
"Exacerbation breakdown: {} hospital, {} patient reported and {} overlapping".format(
|
| 589 |
+
df.IsHospExac.sum(),
|
| 590 |
+
df.IsCommExac.sum(),
|
| 591 |
+
len(df.loc[(df.IsCommExac == 1) & (df.IsHospExac == 1)]),
|
| 592 |
+
)
|
| 593 |
+
)
|
| 594 |
+
df.to_pickle(os.path.join(config["outputs"]["output_data_dir"], "hosp_comm_exacs.pkl"))
|
| 595 |
+
|
| 596 |
+
############################################################################
|
| 597 |
+
# Calculate the number of rows to include per patient in the dataset. This
|
| 598 |
+
# is calculated based on the average number of exacerbations per patient and
|
| 599 |
+
# is then adjusted to the average time within the service
|
| 600 |
+
############################################################################
|
| 601 |
+
|
| 602 |
+
# Calculate the average time patients have data recorded in the COPD service
|
| 603 |
+
service_time = patient_details[["StudyId", "TotalLengthInService"]]
|
| 604 |
+
service_time = service_time.drop_duplicates(subset="StudyId", keep="first")
|
| 605 |
+
print(service_time)
|
| 606 |
+
|
| 607 |
+
avg_service_time = sum(service_time["TotalLengthInService"]) / len(
|
| 608 |
+
service_time["TotalLengthInService"]
|
| 609 |
+
)
|
| 610 |
+
avg_service_time_months = round(avg_service_time / 30)
|
| 611 |
+
print("Average time in service (days):", avg_service_time)
|
| 612 |
+
print("Average time in service (months):", avg_service_time_months)
|
| 613 |
+
|
| 614 |
+
# Calculate the average number of exacerberations per patient
|
| 615 |
+
avg_exac_per_patient = round(
|
| 616 |
+
len(df[(df["IsExac"] == 1)]) / df[(df["IsExac"] == 1) | (df["IsExac"] == 0)][["StudyId"]].nunique().item(), 2
|
| 617 |
+
)
|
| 618 |
+
print(
|
| 619 |
+
"Number of exac/patient/months: {} exacerbations/patient in {} months".format(
|
| 620 |
+
avg_exac_per_patient, avg_service_time_months
|
| 621 |
+
)
|
| 622 |
+
)
|
| 623 |
+
print(
|
| 624 |
+
"On average, 1 exacerbation occurs in a patient every: {} months".format(
|
| 625 |
+
round(avg_service_time_months / avg_exac_per_patient, 2)
|
| 626 |
+
)
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
#################################################################
|
| 630 |
+
# Calculate index dates. 1 row/patient for every 5 months in service.
|
| 631 |
+
#################################################################
|
| 632 |
+
|
| 633 |
+
# Obtain the number of rows required per patient.
|
| 634 |
+
service_time["NumRows"] = round(
|
| 635 |
+
service_time["TotalLengthInService"]
|
| 636 |
+
/ config["model_settings"]["one_row_per_days_in_service"]
|
| 637 |
+
).astype("int")
|
| 638 |
+
# If patient has not been in service for 5 months, make sure they have 1 row
|
| 639 |
+
service_time["NumRows"] = np.where(
|
| 640 |
+
service_time["NumRows"] < 1, 1, service_time["NumRows"]
|
| 641 |
+
)
|
| 642 |
+
patient_details = pd.merge(
|
| 643 |
+
patient_details, service_time[["StudyId", "NumRows"]], on="StudyId", how="left"
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
# Calculate the number of days between earliest and latest index
|
| 647 |
+
patient_details["NumDaysPossibleIndex"] = (
|
| 648 |
+
patient_details["LatestIndexDate"] - patient_details["EarliestIndexDate"]
|
| 649 |
+
).dt.days
|
| 650 |
+
|
| 651 |
+
patient_details["NumDaysPossibleIndex2"] = (
|
| 652 |
+
patient_details["LatestIndexAfterGap"] - patient_details["EarliestIndexAfterGap"]
|
| 653 |
+
).dt.days
|
| 654 |
+
|
| 655 |
+
patient_details.to_csv(
|
| 656 |
+
os.path.join(
|
| 657 |
+
config["outputs"]["output_data_dir"], "pat_details_to_calc_index_dt.csv"
|
| 658 |
+
),
|
| 659 |
+
index=False,
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
# Generate random index dates
|
| 663 |
+
# Multiple seeds tested to identify the random index dates that give a good
|
| 664 |
+
# distribution across months. Seed chosen as 2188398760 from check_index_date_dist.py
|
| 665 |
+
random_seed_general = config["model_settings"]["index_date_generation_master_seed"]
|
| 666 |
+
random.seed(random_seed_general)
|
| 667 |
+
|
| 668 |
+
# Create different random seeds for each patient
|
| 669 |
+
patient_details["RandomSeed"] = random.sample(range(0, 2**32), patient_details.shape[0])
|
| 670 |
+
|
| 671 |
+
# Create random index dates for each patient based on their random seed
|
| 672 |
+
rand_days_dict = {}
|
| 673 |
+
rand_date_dict = {}
|
| 674 |
+
for index, row in patient_details.iterrows():
|
| 675 |
+
np.random.seed(row["RandomSeed"])
|
| 676 |
+
rand_days_dict[row["StudyId"]] = np.random.choice(
|
| 677 |
+
row["LengthInService"], size=row["NumRows"], replace=False
|
| 678 |
+
)
|
| 679 |
+
rand_date_dict[row["StudyId"]] = [
|
| 680 |
+
(
|
| 681 |
+
row["EarliestIndexDate"] + timedelta(days=int(day))
|
| 682 |
+
if day <= row["NumDaysPossibleIndex"]
|
| 683 |
+
else row["EarliestIndexAfterGap"]
|
| 684 |
+
+ timedelta(days=int(day - row["NumDaysPossibleIndex"]))
|
| 685 |
+
)
|
| 686 |
+
for day in rand_days_dict[row["StudyId"]]
|
| 687 |
+
]
|
| 688 |
+
|
| 689 |
+
# Create df from dictionaries containing random index dates
|
| 690 |
+
index_date_df = pd.DataFrame.from_dict(rand_date_dict, orient="index").reset_index()
|
| 691 |
+
index_date_df = index_date_df.rename(columns={"index": "StudyId"})
|
| 692 |
+
|
| 693 |
+
# Convert the multiple columns containing index dates to one column
|
| 694 |
+
index_date_df = (
|
| 695 |
+
pd.melt(index_date_df, id_vars=["StudyId"], value_name="IndexDate")
|
| 696 |
+
.drop(["variable"], axis=1)
|
| 697 |
+
.sort_values(by=["StudyId", "IndexDate"])
|
| 698 |
+
)
|
| 699 |
+
index_date_df = index_date_df.dropna()
|
| 700 |
+
index_date_df = index_date_df.reset_index(drop=True)
|
| 701 |
+
|
| 702 |
+
# Join index dates with exacerbation events
|
| 703 |
+
exac_events = pd.merge(index_date_df, df, on="StudyId", how="left")
|
| 704 |
+
exac_events["IndexDate"] = pd.to_datetime(exac_events["IndexDate"], utc=True)
|
| 705 |
+
|
| 706 |
+
# Calculate whether an exacerbation event occurred within the model time window (3 months)
|
| 707 |
+
# after the index date
|
| 708 |
+
exac_events["TimeToEvent"] = (
|
| 709 |
+
exac_events["DateOfEvent"] - exac_events["IndexDate"]
|
| 710 |
+
).dt.days
|
| 711 |
+
exac_events["ExacWithin3Months"] = np.where(
|
| 712 |
+
(
|
| 713 |
+
exac_events["TimeToEvent"].between(
|
| 714 |
+
1, config["model_settings"]["prediction_window"], inclusive="both"
|
| 715 |
+
)
|
| 716 |
+
)
|
| 717 |
+
& (exac_events["IsExac"] == 1),
|
| 718 |
+
1,
|
| 719 |
+
0,
|
| 720 |
+
)
|
| 721 |
+
exac_events["HospExacWithin3Months"] = np.where(
|
| 722 |
+
(
|
| 723 |
+
exac_events["TimeToEvent"].between(
|
| 724 |
+
1, config["model_settings"]["prediction_window"], inclusive="both"
|
| 725 |
+
)
|
| 726 |
+
)
|
| 727 |
+
& (exac_events["IsHospExac"] == 1),
|
| 728 |
+
1,
|
| 729 |
+
0,
|
| 730 |
+
)
|
| 731 |
+
exac_events["CommExacWithin3Months"] = np.where(
|
| 732 |
+
(
|
| 733 |
+
exac_events["TimeToEvent"].between(
|
| 734 |
+
1, config["model_settings"]["prediction_window"], inclusive="both"
|
| 735 |
+
)
|
| 736 |
+
)
|
| 737 |
+
& (exac_events["IsCommExac"] == 1),
|
| 738 |
+
1,
|
| 739 |
+
0,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
exac_events = exac_events.sort_values(
|
| 743 |
+
by=["StudyId", "IndexDate", "ExacWithin3Months"], ascending=[True, True, False]
|
| 744 |
+
)
|
| 745 |
+
exac_events = exac_events.drop_duplicates(subset=["StudyId", "IndexDate"], keep="first")
|
| 746 |
+
exac_events = exac_events[
|
| 747 |
+
[
|
| 748 |
+
"StudyId",
|
| 749 |
+
"PatientId",
|
| 750 |
+
"IndexDate",
|
| 751 |
+
"DateOfBirth",
|
| 752 |
+
"Sex",
|
| 753 |
+
"ExacWithin3Months",
|
| 754 |
+
"HospExacWithin3Months",
|
| 755 |
+
"CommExacWithin3Months",
|
| 756 |
+
]
|
| 757 |
+
]
|
| 758 |
+
|
| 759 |
+
# Save exac_events
|
| 760 |
+
exac_events.to_pickle(
|
| 761 |
+
os.path.join(config["outputs"]["output_data_dir"], "patient_labels_hosp_comm.pkl")
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
# Summary info
|
| 765 |
+
class_distribution = (
|
| 766 |
+
exac_events.groupby("ExacWithin3Months").count()[["StudyId"]].reset_index()
|
| 767 |
+
)
|
| 768 |
+
class_distribution.plot.bar(x="ExacWithin3Months", y="StudyId")
|
| 769 |
+
plt.savefig(
|
| 770 |
+
"./plots/class_distributions/final_seed_"
|
| 771 |
+
+ str(random_seed_general)
|
| 772 |
+
+ "_class_distribution_hosp_comm.png",
|
| 773 |
+
bbox_inches="tight",
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
print("---Summary info after setting up labels---")
|
| 777 |
+
print("Number of unique patients:", exac_events["StudyId"].nunique())
|
| 778 |
+
print("Number of rows:", len(exac_events))
|
| 779 |
+
print(
|
| 780 |
+
"Number of exacerbations within 3 months of index date:",
|
| 781 |
+
len(exac_events[exac_events["ExacWithin3Months"] == 1]),
|
| 782 |
+
)
|
| 783 |
+
print(
|
| 784 |
+
"Percentage positive class (num exac/total rows): {} %".format(
|
| 785 |
+
round(
|
| 786 |
+
(len(exac_events[exac_events["ExacWithin3Months"] == 1]) / len(exac_events))
|
| 787 |
+
* 100,
|
| 788 |
+
2,
|
| 789 |
+
)
|
| 790 |
+
)
|
| 791 |
+
)
|
| 792 |
+
print(
|
| 793 |
+
"Percentage negative class: {} %".format(
|
| 794 |
+
round(
|
| 795 |
+
(len(exac_events[exac_events["ExacWithin3Months"] == 0]) / len(exac_events))
|
| 796 |
+
* 100,
|
| 797 |
+
2,
|
| 798 |
+
)
|
| 799 |
+
)
|
| 800 |
+
)
|
| 801 |
+
print(
|
| 802 |
+
"Percentage hospital exacs: {} %".format(
|
| 803 |
+
round(
|
| 804 |
+
(
|
| 805 |
+
len(exac_events[exac_events["HospExacWithin3Months"] == 1])
|
| 806 |
+
/ len(exac_events)
|
| 807 |
+
)
|
| 808 |
+
* 100,
|
| 809 |
+
2,
|
| 810 |
+
)
|
| 811 |
+
)
|
| 812 |
+
)
|
| 813 |
+
print(
|
| 814 |
+
"Percentage community exacs: {} %".format(
|
| 815 |
+
round(
|
| 816 |
+
(
|
| 817 |
+
len(exac_events[exac_events["CommExacWithin3Months"] == 1])
|
| 818 |
+
/ len(exac_events)
|
| 819 |
+
)
|
| 820 |
+
* 100,
|
| 821 |
+
2,
|
| 822 |
+
)
|
| 823 |
+
)
|
| 824 |
+
)
|
| 825 |
+
print("Class balance:")
|
| 826 |
+
print(class_distribution)
|
| 827 |
+
|
| 828 |
+
print("---Events based on dates---")
|
| 829 |
+
verified_events = exac_events[
|
| 830 |
+
exac_events["IndexDate"] <= pd.to_datetime("2021-11-30", utc=True)
|
| 831 |
+
]
|
| 832 |
+
unverified_events = exac_events[
|
| 833 |
+
exac_events["IndexDate"] > pd.to_datetime("2021-11-30", utc=True)
|
| 834 |
+
]
|
| 835 |
+
print("---Verified events---")
|
| 836 |
+
print(
|
| 837 |
+
"Percentage positive class (num exac/total rows): {} %".format(
|
| 838 |
+
round(
|
| 839 |
+
(
|
| 840 |
+
len(verified_events[verified_events["ExacWithin3Months"] == 1])
|
| 841 |
+
/ len(verified_events)
|
| 842 |
+
)
|
| 843 |
+
* 100,
|
| 844 |
+
2,
|
| 845 |
+
)
|
| 846 |
+
)
|
| 847 |
+
)
|
| 848 |
+
print(
|
| 849 |
+
"Percentage negative class: {} %".format(
|
| 850 |
+
round(
|
| 851 |
+
(
|
| 852 |
+
len(verified_events[verified_events["ExacWithin3Months"] == 0])
|
| 853 |
+
/ len(verified_events)
|
| 854 |
+
)
|
| 855 |
+
* 100,
|
| 856 |
+
2,
|
| 857 |
+
)
|
| 858 |
+
)
|
| 859 |
+
)
|
| 860 |
+
print(
|
| 861 |
+
"Percentage hospital exacs: {} %".format(
|
| 862 |
+
round(
|
| 863 |
+
(
|
| 864 |
+
len(verified_events[verified_events["HospExacWithin3Months"] == 1])
|
| 865 |
+
/ len(verified_events)
|
| 866 |
+
)
|
| 867 |
+
* 100,
|
| 868 |
+
2,
|
| 869 |
+
)
|
| 870 |
+
)
|
| 871 |
+
)
|
| 872 |
+
print(
|
| 873 |
+
"Percentage community exacs: {} %".format(
|
| 874 |
+
round(
|
| 875 |
+
(
|
| 876 |
+
len(verified_events[verified_events["CommExacWithin3Months"] == 1])
|
| 877 |
+
/ len(verified_events)
|
| 878 |
+
)
|
| 879 |
+
* 100,
|
| 880 |
+
2,
|
| 881 |
+
)
|
| 882 |
+
)
|
| 883 |
+
)
|
| 884 |
+
print("---Unverified events---")
|
| 885 |
+
print(
|
| 886 |
+
"Percentage positive class (num exac/total rows): {} %".format(
|
| 887 |
+
round(
|
| 888 |
+
(
|
| 889 |
+
len(unverified_events[unverified_events["ExacWithin3Months"] == 1])
|
| 890 |
+
/ len(unverified_events)
|
| 891 |
+
)
|
| 892 |
+
* 100,
|
| 893 |
+
2,
|
| 894 |
+
)
|
| 895 |
+
)
|
| 896 |
+
)
|
| 897 |
+
print(
|
| 898 |
+
"Percentage negative class: {} %".format(
|
| 899 |
+
round(
|
| 900 |
+
(
|
| 901 |
+
len(unverified_events[unverified_events["ExacWithin3Months"] == 0])
|
| 902 |
+
/ len(unverified_events)
|
| 903 |
+
)
|
| 904 |
+
* 100,
|
| 905 |
+
2,
|
| 906 |
+
)
|
| 907 |
+
)
|
| 908 |
+
)
|
| 909 |
+
print(
|
| 910 |
+
"Percentage hospital exacs: {} %".format(
|
| 911 |
+
round(
|
| 912 |
+
(
|
| 913 |
+
len(unverified_events[unverified_events["HospExacWithin3Months"] == 1])
|
| 914 |
+
/ len(unverified_events)
|
| 915 |
+
)
|
| 916 |
+
* 100,
|
| 917 |
+
2,
|
| 918 |
+
)
|
| 919 |
+
)
|
| 920 |
+
)
|
| 921 |
+
print(
|
| 922 |
+
"Percentage community exacs: {} %".format(
|
| 923 |
+
round(
|
| 924 |
+
(
|
| 925 |
+
len(unverified_events[unverified_events["CommExacWithin3Months"] == 1])
|
| 926 |
+
/ len(unverified_events)
|
| 927 |
+
)
|
| 928 |
+
* 100,
|
| 929 |
+
2,
|
| 930 |
+
)
|
| 931 |
+
)
|
| 932 |
+
)
|
| 933 |
+
print(
|
| 934 |
+
"Train date range", exac_events["IndexDate"].min(), exac_events["IndexDate"].max()
|
| 935 |
+
)
|
training/setup_labels_only_hosp.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script uses only hospital exacerbation events.
|
| 3 |
+
|
| 4 |
+
Collate all hospital to determine the number of exacerbation events. Use exacerbation
|
| 5 |
+
events to determine the number of rows required per patient in the data and generate
|
| 6 |
+
random index dates and setup labels.
|
| 7 |
+
"""
|
| 8 |
+
import model_h
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import sys
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from datetime import timedelta
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
data_dir_service = "<YOUR_DATA_PATH>/copd-dataset"
|
| 18 |
+
data_dir_model = "./data"
|
| 19 |
+
|
| 20 |
+
# Setup log file
|
| 21 |
+
log = open("./training/logging/setup_labels_only_hosp.log", "w")
|
| 22 |
+
sys.stdout = log
|
| 23 |
+
|
| 24 |
+
# Model time window in days to predict exacerbation
|
| 25 |
+
model_time_window = 90
|
| 26 |
+
|
| 27 |
+
############################################################################
|
| 28 |
+
# Define model cohort and training data windows
|
| 29 |
+
############################################################################
|
| 30 |
+
|
| 31 |
+
# Read relevant info from patient details
|
| 32 |
+
patient_details = pd.read_csv(
|
| 33 |
+
os.path.join(data_dir_service, "CopdDatasetPatientDetails.txt"),
|
| 34 |
+
usecols=[
|
| 35 |
+
"PatientId",
|
| 36 |
+
"FirstSubmissionDate",
|
| 37 |
+
"MostRecentSubmissionDate",
|
| 38 |
+
"DateOfBirth",
|
| 39 |
+
"Sex",
|
| 40 |
+
"StudyId",
|
| 41 |
+
],
|
| 42 |
+
delimiter="|",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Select patients for inclusion (those with up to date events in service)
|
| 46 |
+
# Create list of patients for model inclusion
|
| 47 |
+
# Original RECEIVER cohort study id list
|
| 48 |
+
receiver_patients = ["RC{:02d}".format(i) for i in range(1, 85)]
|
| 49 |
+
# This patient needs removing
|
| 50 |
+
receiver_patients.remove("RC34")
|
| 51 |
+
# Scale up patients (subset)
|
| 52 |
+
scaleup_patients = ["SU{:02d}".format(i) for i in range(1, 219)]
|
| 53 |
+
# scaleup_patients.append('SU287') #Removed as study ID contains 2 patients
|
| 54 |
+
|
| 55 |
+
# List of all valid patients for modelling
|
| 56 |
+
valid_patients = receiver_patients + scaleup_patients
|
| 57 |
+
|
| 58 |
+
# Filter for valid patients accounting for white spaces in StudyId (e.g. RC 26 and RC 52)
|
| 59 |
+
patient_details = patient_details[
|
| 60 |
+
patient_details.StudyId.str.replace(" ", "").isin(valid_patients)
|
| 61 |
+
]
|
| 62 |
+
# Select only non null entries in patient data start/end dates
|
| 63 |
+
patient_details = patient_details[
|
| 64 |
+
(patient_details.FirstSubmissionDate.notna())
|
| 65 |
+
& (patient_details.MostRecentSubmissionDate.notna())
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
# Create a column stating the latest date permitted based on events added to service data
|
| 69 |
+
patient_details["LatestPredictionDate"] = "2022-02-28"
|
| 70 |
+
|
| 71 |
+
date_cols = ["FirstSubmissionDate", "MostRecentSubmissionDate", "LatestPredictionDate"]
|
| 72 |
+
patient_details[date_cols] = patient_details[date_cols].apply(
|
| 73 |
+
lambda x: pd.to_datetime(x, utc=True, format="mixed").dt.normalize(), axis=1
|
| 74 |
+
)
|
| 75 |
+
# Choose the earlier date out of the patient's last submission and the latest COPD data
|
| 76 |
+
# events
|
| 77 |
+
patient_details["LatestPredictionDate"] = patient_details[
|
| 78 |
+
["MostRecentSubmissionDate", "LatestPredictionDate"]
|
| 79 |
+
].min(axis=1)
|
| 80 |
+
|
| 81 |
+
# Calculate the latest date that the index date can be for each patient
|
| 82 |
+
patient_details["LatestIndexDate"] = patient_details[
|
| 83 |
+
"LatestPredictionDate"
|
| 84 |
+
] - pd.DateOffset(days=model_time_window)
|
| 85 |
+
|
| 86 |
+
# Add 6 months to start of data window to allow enough of a lookback period
|
| 87 |
+
patient_details["EarliestIndexDate"] = patient_details[
|
| 88 |
+
"FirstSubmissionDate"
|
| 89 |
+
] + pd.DateOffset(days=180)
|
| 90 |
+
|
| 91 |
+
# Remove any patients for whom the index start date overlaps the last index
|
| 92 |
+
# date, i.e. they have too short a window of data
|
| 93 |
+
print("Number of total patients", len(patient_details))
|
| 94 |
+
print(
|
| 95 |
+
"Number of patients with too short of a window of data:",
|
| 96 |
+
len(
|
| 97 |
+
patient_details[
|
| 98 |
+
patient_details["EarliestIndexDate"] > patient_details["LatestIndexDate"]
|
| 99 |
+
]
|
| 100 |
+
),
|
| 101 |
+
)
|
| 102 |
+
patient_details = patient_details[
|
| 103 |
+
patient_details["EarliestIndexDate"] < patient_details["LatestIndexDate"]
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
# List of remaining patients
|
| 107 |
+
model_patients = list(patient_details.PatientId.unique())
|
| 108 |
+
model_study_ids = list(patient_details.StudyId.unique())
|
| 109 |
+
|
| 110 |
+
print(
|
| 111 |
+
"Model cohort: {} patients. {} RECEIVER and {} SU".format(
|
| 112 |
+
len(model_patients),
|
| 113 |
+
len(patient_details[patient_details["StudyId"].str.startswith("RC")]),
|
| 114 |
+
len(patient_details[patient_details["StudyId"].str.startswith("SU")]),
|
| 115 |
+
)
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
df = patient_details[
|
| 119 |
+
[
|
| 120 |
+
"PatientId",
|
| 121 |
+
"DateOfBirth",
|
| 122 |
+
"Sex",
|
| 123 |
+
"StudyId",
|
| 124 |
+
"FirstSubmissionDate",
|
| 125 |
+
"EarliestIndexDate",
|
| 126 |
+
"LatestIndexDate",
|
| 127 |
+
"LatestPredictionDate",
|
| 128 |
+
]
|
| 129 |
+
].copy()
|
| 130 |
+
|
| 131 |
+
############################################################################
|
| 132 |
+
# Extract hospital exacerbations and admissions from COPD service data
|
| 133 |
+
############################################################################
|
| 134 |
+
|
| 135 |
+
# Load hospital exacerbations and admissions data
|
| 136 |
+
hosp_exacs = pd.read_pickle(os.path.join(data_dir_model, "hospital_exacerbations.pkl"))
|
| 137 |
+
admissions = pd.read_pickle(os.path.join(data_dir_model, "hospital_admissions.pkl"))
|
| 138 |
+
|
| 139 |
+
# Merge hospital exacs and admissions data
|
| 140 |
+
hosp_exacs = hosp_exacs.merge(admissions, on=["PatientId", "DateOfEvent"], how="outer")
|
| 141 |
+
|
| 142 |
+
# Fill missing values in PatientId and StudyId using a lookup table
|
| 143 |
+
patient_id_lookup = patient_details[["PatientId", "StudyId"]]
|
| 144 |
+
hosp_exacs["StudyId"] = np.NaN
|
| 145 |
+
hosp_exacs["StudyId"] = np.where(
|
| 146 |
+
hosp_exacs.StudyId.isnull(),
|
| 147 |
+
hosp_exacs.PatientId.map(patient_id_lookup.set_index("PatientId").StudyId),
|
| 148 |
+
hosp_exacs.StudyId,
|
| 149 |
+
)
|
| 150 |
+
hosp_exacs = hosp_exacs.sort_values(
|
| 151 |
+
by=["StudyId", "DateOfEvent", "IsHospExac", "IsHospAdmission"],
|
| 152 |
+
ascending=[True, True, False, False],
|
| 153 |
+
)
|
| 154 |
+
exac_data = hosp_exacs.drop_duplicates(subset=["StudyId", "DateOfEvent"], keep="first")
|
| 155 |
+
exac_data.to_pickle(os.path.join(data_dir_model, "only_hosp_exacs.pkl"))
|
| 156 |
+
|
| 157 |
+
# Merge with patient details
|
| 158 |
+
exac_data = pd.merge(
|
| 159 |
+
exac_data,
|
| 160 |
+
df[["StudyId", "PatientId", "FirstSubmissionDate", "LatestPredictionDate"]],
|
| 161 |
+
on=["StudyId", "PatientId"],
|
| 162 |
+
how="left",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Remove exacerbations before onboarding to COPD service
|
| 166 |
+
exac_data = exac_data[exac_data["DateOfEvent"] > exac_data["FirstSubmissionDate"]]
|
| 167 |
+
|
| 168 |
+
# Retain only dates before the end of each patient's data window
|
| 169 |
+
exac_data = exac_data[exac_data.DateOfEvent <= exac_data.LatestPredictionDate]
|
| 170 |
+
exac_data = exac_data.drop(columns=["FirstSubmissionDate", "LatestPredictionDate"])
|
| 171 |
+
|
| 172 |
+
df = pd.merge(df, exac_data, on=["StudyId", "PatientId"], how="left")
|
| 173 |
+
df = df.rename(columns={"IsHospExac": "IsExac"})
|
| 174 |
+
|
| 175 |
+
print("Starting number of exacerbations: {}".format(df.IsExac.sum()))
|
| 176 |
+
print(
|
| 177 |
+
"Number of unique exacerbation patients: {}".format(
|
| 178 |
+
len(df[df.IsExac == 1].PatientId.unique())
|
| 179 |
+
)
|
| 180 |
+
)
|
| 181 |
+
print(
|
| 182 |
+
"Hospital exacerbations: {} ({} unique patients)".format(
|
| 183 |
+
len(df[(df.IsExac == 1)]), len(df[(df.IsExac == 1)].StudyId.unique())
|
| 184 |
+
)
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
#####################################################################
|
| 188 |
+
# Calculate the number of rows to include per patient in the dataset.
|
| 189 |
+
# This is calculated based on the average number of exacerbations per
|
| 190 |
+
# patient and is then adjusted to the average time within the service
|
| 191 |
+
#####################################################################
|
| 192 |
+
|
| 193 |
+
# Calculate the average time patients have data recorded in the COPD service
|
| 194 |
+
service_time = df[["StudyId", "LatestPredictionDate", "FirstSubmissionDate"]]
|
| 195 |
+
service_time = service_time.drop_duplicates(subset="StudyId", keep="first")
|
| 196 |
+
service_time["ServiceTime"] = (
|
| 197 |
+
service_time["LatestPredictionDate"] - service_time["FirstSubmissionDate"]
|
| 198 |
+
).dt.days
|
| 199 |
+
avg_service_time = sum(service_time["ServiceTime"]) / len(service_time["ServiceTime"])
|
| 200 |
+
avg_service_time_months = round(avg_service_time / 30)
|
| 201 |
+
print("Average time in service (days):", avg_service_time)
|
| 202 |
+
print("Average time in service (months):", avg_service_time_months)
|
| 203 |
+
|
| 204 |
+
# Calculate the average number of exacerberations per patient
|
| 205 |
+
avg_exac_per_patient = round(
|
| 206 |
+
len(df[df["IsExac"] == 1]) / df[df["IsExac"] == 1][["StudyId"]].nunique().item(), 2
|
| 207 |
+
)
|
| 208 |
+
print(
|
| 209 |
+
"Number of exac/patient/months: {} exacerbations/patient in {} months".format(
|
| 210 |
+
avg_exac_per_patient, avg_service_time_months
|
| 211 |
+
)
|
| 212 |
+
)
|
| 213 |
+
print(
|
| 214 |
+
"On average, 1 exacerbation occurs in a patient every: {} months".format(
|
| 215 |
+
round(avg_service_time_months / avg_exac_per_patient, 2)
|
| 216 |
+
)
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
#################################################################
|
| 220 |
+
# Calculate index dates. 1 row/patient for every 6 months in service.
|
| 221 |
+
#################################################################
|
| 222 |
+
|
| 223 |
+
# Obtain the number of rows required per patient. One row per patient for every 6 months in service.
|
| 224 |
+
service_time["NumRows"] = round(service_time["ServiceTime"] / 180).astype("int")
|
| 225 |
+
patient_details = pd.merge(
|
| 226 |
+
patient_details, service_time[["StudyId", "NumRows"]], on="StudyId", how="left"
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Calculate the number of days between earliest and latest index
|
| 230 |
+
patient_details["NumDaysPossibleIndex"] = (
|
| 231 |
+
patient_details["LatestIndexDate"] - patient_details["EarliestIndexDate"]
|
| 232 |
+
).dt.days
|
| 233 |
+
# patient_details['NumRows'] = patient_details['NumRows'].astype('int')
|
| 234 |
+
patient_details.to_csv("./data/pat_details_to_calc_index_dt.csv", index=False)
|
| 235 |
+
|
| 236 |
+
# Generate random index dates
|
| 237 |
+
# Multiple seeds tested to identify the random index dates that give a good
|
| 238 |
+
# distribution across months. Seed chosen as 2188398760 from check_index_date_dist.py
|
| 239 |
+
random_seed_general = 2188398760
|
| 240 |
+
random.seed(random_seed_general)
|
| 241 |
+
|
| 242 |
+
# Create different random seeds for each patient
|
| 243 |
+
patient_details["RandomSeed"] = random.sample(
|
| 244 |
+
range(0, 2**32), patient_details.shape[0]
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# Create random index dates for each patient based on their random seed
|
| 248 |
+
rand_days_dict = {}
|
| 249 |
+
rand_date_dict = {}
|
| 250 |
+
for index, row in patient_details.iterrows():
|
| 251 |
+
np.random.seed(row["RandomSeed"])
|
| 252 |
+
rand_days_dict[row["StudyId"]] = np.random.choice(
|
| 253 |
+
row["NumDaysPossibleIndex"], size=row["NumRows"], replace=False
|
| 254 |
+
)
|
| 255 |
+
rand_date_dict[row["StudyId"]] = [
|
| 256 |
+
row["EarliestIndexDate"] + timedelta(days=int(day))
|
| 257 |
+
for day in rand_days_dict[row["StudyId"]]
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
+
# Create df from dictionaries containing random index dates
|
| 261 |
+
index_date_df = pd.DataFrame.from_dict(rand_date_dict, orient="index").reset_index()
|
| 262 |
+
index_date_df = index_date_df.rename(columns={"index": "StudyId"})
|
| 263 |
+
|
| 264 |
+
# Convert the multiple columns containg index dates to one column
|
| 265 |
+
index_date_df = (
|
| 266 |
+
pd.melt(index_date_df, id_vars=["StudyId"], value_name="IndexDate")
|
| 267 |
+
.drop(["variable"], axis=1)
|
| 268 |
+
.sort_values(by=["StudyId", "IndexDate"])
|
| 269 |
+
)
|
| 270 |
+
index_date_df = index_date_df.dropna()
|
| 271 |
+
index_date_df = index_date_df.reset_index(drop=True)
|
| 272 |
+
|
| 273 |
+
# Join index dates with exacerbation events
|
| 274 |
+
exac_events = pd.merge(index_date_df, df, on="StudyId", how="left")
|
| 275 |
+
exac_events["IndexDate"] = pd.to_datetime(exac_events["IndexDate"], utc=True)
|
| 276 |
+
|
| 277 |
+
# Calculate whether an exacerbation event occurred within
|
| 278 |
+
# the model time window (3 months) after the index date
|
| 279 |
+
exac_events["TimeToEvent"] = (
|
| 280 |
+
exac_events["DateOfEvent"] - exac_events["IndexDate"]
|
| 281 |
+
).dt.days
|
| 282 |
+
exac_events["ExacWithin3Months"] = np.where(
|
| 283 |
+
(exac_events["TimeToEvent"].between(1, model_time_window, inclusive="both"))
|
| 284 |
+
& (exac_events["IsExac"] == 1),
|
| 285 |
+
1,
|
| 286 |
+
0,
|
| 287 |
+
)
|
| 288 |
+
exac_events = exac_events.sort_values(
|
| 289 |
+
by=["StudyId", "IndexDate", "ExacWithin3Months"], ascending=[True, True, False]
|
| 290 |
+
)
|
| 291 |
+
exac_events = exac_events.drop_duplicates(subset=["StudyId", "IndexDate"], keep="first")
|
| 292 |
+
exac_events = exac_events[
|
| 293 |
+
["StudyId", "PatientId", "IndexDate", "DateOfBirth", "Sex", "ExacWithin3Months"]
|
| 294 |
+
]
|
| 295 |
+
|
| 296 |
+
# Save exac_events
|
| 297 |
+
exac_events.to_pickle(os.path.join(data_dir_model, "patient_labels_only_hosp.pkl"))
|
| 298 |
+
|
| 299 |
+
# Summary info
|
| 300 |
+
class_distribution = (
|
| 301 |
+
exac_events.groupby("ExacWithin3Months").count()[["StudyId"]].reset_index()
|
| 302 |
+
)
|
| 303 |
+
class_distribution.plot.bar(x="ExacWithin3Months", y="StudyId")
|
| 304 |
+
plt.title("Class distribution of hospital exacerbations occuring within 3 months")
|
| 305 |
+
plt.savefig(
|
| 306 |
+
"./plots/class_distributions/final_seed_"
|
| 307 |
+
+ str(random_seed_general)
|
| 308 |
+
+ "_class_distribution_only_hosp.png",
|
| 309 |
+
bbox_inches="tight",
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
print("---Summary info after setting up labels---")
|
| 313 |
+
print("Number of unique patients:", exac_events["StudyId"].nunique())
|
| 314 |
+
print("Number of rows:", len(exac_events))
|
| 315 |
+
print(
|
| 316 |
+
"Number of exacerbations within 3 months of index date:",
|
| 317 |
+
len(exac_events[exac_events["ExacWithin3Months"] == 1]),
|
| 318 |
+
)
|
| 319 |
+
print(
|
| 320 |
+
"Percentage positive class (num exac/total rows): {} %".format(
|
| 321 |
+
round(
|
| 322 |
+
(len(exac_events[exac_events["ExacWithin3Months"] == 1]) / len(exac_events))
|
| 323 |
+
* 100,
|
| 324 |
+
2,
|
| 325 |
+
)
|
| 326 |
+
)
|
| 327 |
+
)
|
| 328 |
+
print(
|
| 329 |
+
"Percentage negative class: {} %".format(
|
| 330 |
+
round(
|
| 331 |
+
(len(exac_events[exac_events["ExacWithin3Months"] == 0]) / len(exac_events))
|
| 332 |
+
* 100,
|
| 333 |
+
2,
|
| 334 |
+
)
|
| 335 |
+
)
|
| 336 |
+
)
|
| 337 |
+
print("Class balance:")
|
| 338 |
+
print(class_distribution)
|
training/split_train_test_val.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Splits the model H cohort into train, test and balanced cross validation folds.
|
| 2 |
+
|
| 3 |
+
The train set retains class ratio, sex and age distributions from the full dataset.
|
| 4 |
+
Patients can only appear in either train or test set.
|
| 5 |
+
|
| 6 |
+
This script also splits the train data into balanced folds for cross-validation. Patient
|
| 7 |
+
IDs for train, test and all data folds are stored for use in subsequent scripts.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import os
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import pickle
|
| 14 |
+
import sys
|
| 15 |
+
import yaml
|
| 16 |
+
import splitting
|
| 17 |
+
|
| 18 |
+
with open("./training/config.yaml", "r") as config:
|
| 19 |
+
config = yaml.safe_load(config)
|
| 20 |
+
|
| 21 |
+
##############################################################
|
| 22 |
+
# Load correct files
|
| 23 |
+
##############################################################
|
| 24 |
+
save_cohort_info = True
|
| 25 |
+
|
| 26 |
+
# Specify which model to perform split on
|
| 27 |
+
model_type = config["model_settings"]["model_type"]
|
| 28 |
+
|
| 29 |
+
# Setup log file
|
| 30 |
+
log = open("./training/logging/split_train_test_" + model_type + ".log", "w")
|
| 31 |
+
sys.stdout = log
|
| 32 |
+
|
| 33 |
+
demographics = pd.read_pickle(
|
| 34 |
+
os.path.join(
|
| 35 |
+
config["outputs"]["processed_data_dir"],
|
| 36 |
+
"demographics_{}.pkl".format(model_type),
|
| 37 |
+
)
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
##############################################################
|
| 41 |
+
# Split data into a train and hold out test set
|
| 42 |
+
##############################################################
|
| 43 |
+
train_data, test_data = splitting.subject_wise_train_test_split(
|
| 44 |
+
data=demographics,
|
| 45 |
+
target_col="ExacWithin3Months",
|
| 46 |
+
id_col="StudyId",
|
| 47 |
+
test_size=0.2,
|
| 48 |
+
stratify_by_sex=True,
|
| 49 |
+
sex_col="Sex_F",
|
| 50 |
+
stratify_by_age=True,
|
| 51 |
+
age_bin_col="AgeBinned",
|
| 52 |
+
)
|
| 53 |
+
print(demographics.Sex_F.value_counts() / demographics.Sex_F.count())
|
| 54 |
+
print(train_data.Sex_F.value_counts() / train_data.Sex_F.count())
|
| 55 |
+
print(test_data.Sex_F.value_counts() / test_data.Sex_F.count())
|
| 56 |
+
print(demographics.Age.mean())
|
| 57 |
+
print(train_data.Age.mean())
|
| 58 |
+
print(test_data.Age.mean())
|
| 59 |
+
|
| 60 |
+
train_ids = train_data.StudyId.unique()
|
| 61 |
+
test_ids = test_data.StudyId.unique()
|
| 62 |
+
##############################################################
|
| 63 |
+
# Split training data into groups for cross validation
|
| 64 |
+
##############################################################
|
| 65 |
+
fold_patients = splitting.subject_wise_kfold_split(
|
| 66 |
+
train_data=train_data,
|
| 67 |
+
target_col="ExacWithin3Months",
|
| 68 |
+
id_col="StudyId",
|
| 69 |
+
num_folds=5,
|
| 70 |
+
sex_col="Sex_F",
|
| 71 |
+
age_col="Age",
|
| 72 |
+
stratify_by_sex=True,
|
| 73 |
+
print_log=True,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
##############################################################
|
| 77 |
+
# Save cohort info
|
| 78 |
+
##############################################################
|
| 79 |
+
fold_patients = np.array(fold_patients, dtype="object")
|
| 80 |
+
|
| 81 |
+
if save_cohort_info:
|
| 82 |
+
# Save train and test ID info
|
| 83 |
+
os.makedirs(config["outputs"]["cohort_info_dir"], exist_ok=True)
|
| 84 |
+
with open(
|
| 85 |
+
os.path.join(
|
| 86 |
+
config["outputs"]["cohort_info_dir"], "test_ids_" + model_type + ".pkl"
|
| 87 |
+
),
|
| 88 |
+
"wb",
|
| 89 |
+
) as f:
|
| 90 |
+
pickle.dump(list(test_ids), f)
|
| 91 |
+
with open(
|
| 92 |
+
os.path.join(
|
| 93 |
+
config["outputs"]["cohort_info_dir"], "train_ids_" + model_type + ".pkl"
|
| 94 |
+
),
|
| 95 |
+
"wb",
|
| 96 |
+
) as f:
|
| 97 |
+
pickle.dump(list(train_ids), f)
|
| 98 |
+
print("Train and test patient IDs saved")
|
| 99 |
+
|
| 100 |
+
# Save cross validation fold info
|
| 101 |
+
np.save(
|
| 102 |
+
os.path.join(
|
| 103 |
+
config["outputs"]["cohort_info_dir"], "fold_patients_" + model_type + ".npy"
|
| 104 |
+
),
|
| 105 |
+
fold_patients,
|
| 106 |
+
allow_pickle=True,
|
| 107 |
+
)
|
| 108 |
+
print("Cross validation fold information saved")
|
training/splitting.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module to perform splitting of data into train/test or K-folds."""
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.model_selection import StratifiedGroupKFold
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def subject_wise_train_test_split(
|
| 9 |
+
data,
|
| 10 |
+
target_col,
|
| 11 |
+
id_col,
|
| 12 |
+
test_size,
|
| 13 |
+
stratify_by_sex=False,
|
| 14 |
+
sex_col=None,
|
| 15 |
+
stratify_by_age=False,
|
| 16 |
+
age_bin_col=None,
|
| 17 |
+
shuffle=False,
|
| 18 |
+
random_state=None,
|
| 19 |
+
):
|
| 20 |
+
"""Subject wise splitting data into train and test sets.
|
| 21 |
+
|
| 22 |
+
Splits data into train and test sets ensuring that the same patient can only appear in
|
| 23 |
+
either train or test set. Stratifies data according to class label, with additional
|
| 24 |
+
options to stratify by sex and age.
|
| 25 |
+
|
| 26 |
+
Parameters
|
| 27 |
+
----------
|
| 28 |
+
data : pd.DataFrame
|
| 29 |
+
dataframe containing features and target.
|
| 30 |
+
target_col : str
|
| 31 |
+
name of target column.
|
| 32 |
+
id_col : str
|
| 33 |
+
name of patient id column.
|
| 34 |
+
test_size : float
|
| 35 |
+
represents the proportion of the dataset to include in the test split. Float should
|
| 36 |
+
be between 0 and 1.
|
| 37 |
+
stratify_by_sex : bool, optional
|
| 38 |
+
option to stratify data by sex, by default False.
|
| 39 |
+
sex_col : str, optional
|
| 40 |
+
name of sex column, by default None.
|
| 41 |
+
stratify_by_age : bool, optional
|
| 42 |
+
option to stratify data by age, by default False.
|
| 43 |
+
age_bin_col : str, optional
|
| 44 |
+
name of age column, by default None. Age column must be provided in binned format.
|
| 45 |
+
shuffle : bool, optional
|
| 46 |
+
whether to shuffle each class's samples before splitting into batches, by default
|
| 47 |
+
False.
|
| 48 |
+
random_state : int, optional
|
| 49 |
+
when shuffle is True, random_state affects the ordering of the indices, by default
|
| 50 |
+
None.
|
| 51 |
+
|
| 52 |
+
Returns
|
| 53 |
+
-------
|
| 54 |
+
train_data : pd.DataFrame
|
| 55 |
+
train data stratified by class. Also stratified by age/sex as specified in input
|
| 56 |
+
parameters.
|
| 57 |
+
test_data : pd.DataFrame
|
| 58 |
+
test data stratified by class. Also stratified by age/sex as specified in input
|
| 59 |
+
parameters.
|
| 60 |
+
|
| 61 |
+
Raises
|
| 62 |
+
-------
|
| 63 |
+
ValueError : error raised when boolean stratify_by_age or stratify_by_sex is True but
|
| 64 |
+
the respective columns are not provided for stratifying.
|
| 65 |
+
|
| 66 |
+
"""
|
| 67 |
+
# Raise error if stratify_by_sex/stratify_by_age is True but the respective columns are
|
| 68 |
+
# not provided
|
| 69 |
+
if (stratify_by_age is True) & (age_bin_col is None):
|
| 70 |
+
raise ValueError(
|
| 71 |
+
"Parameter stratify_by_age is True but age_bin_col not provided."
|
| 72 |
+
)
|
| 73 |
+
if (stratify_by_sex is True) & (sex_col is None):
|
| 74 |
+
raise ValueError("Parameter stratify_by_sex is True but sex_col not provided.")
|
| 75 |
+
|
| 76 |
+
# Adapt target column to contain all variables to split by to allow stratified splitting
|
| 77 |
+
# by StratifiedGroupKFold.
|
| 78 |
+
if (stratify_by_sex is True) and (stratify_by_age is True):
|
| 79 |
+
data["TempTarget"] = (
|
| 80 |
+
data[target_col].astype(str) + data[sex_col].astype(str) + data[age_bin_col]
|
| 81 |
+
)
|
| 82 |
+
elif (stratify_by_sex is True) and (stratify_by_age is False):
|
| 83 |
+
data["TempTarget"] = data[target_col].astype(str) + data[sex_col].astype(str)
|
| 84 |
+
elif (stratify_by_sex is False) and (stratify_by_age is True):
|
| 85 |
+
data["TempTarget"] = data[target_col].astype(str) + data[sex_col].astype(str)
|
| 86 |
+
else:
|
| 87 |
+
data["TempTarget"] = data[target_col]
|
| 88 |
+
temp_target_col = "TempTarget"
|
| 89 |
+
|
| 90 |
+
# Calculate the number of folds to split data to using the size of the test data.
|
| 91 |
+
num_folds = round(1 / test_size)
|
| 92 |
+
sgkf = StratifiedGroupKFold(
|
| 93 |
+
n_splits=num_folds, shuffle=shuffle, random_state=random_state
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Retrieve the first split and create train and test dfs
|
| 97 |
+
train_test_splits = next(
|
| 98 |
+
sgkf.split(data, data[temp_target_col], groups=data[id_col])
|
| 99 |
+
)
|
| 100 |
+
train_ids = train_test_splits[0].tolist()
|
| 101 |
+
test_ids = train_test_splits[1].tolist()
|
| 102 |
+
train_data = data.iloc[train_ids]
|
| 103 |
+
test_data = data.iloc[test_ids]
|
| 104 |
+
|
| 105 |
+
# Drop temporary target column
|
| 106 |
+
train_data = train_data.drop(columns=temp_target_col)
|
| 107 |
+
test_data = test_data.drop(columns=temp_target_col)
|
| 108 |
+
|
| 109 |
+
return train_data, test_data
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def subject_wise_kfold_split(
|
| 113 |
+
train_data,
|
| 114 |
+
target_col,
|
| 115 |
+
id_col,
|
| 116 |
+
num_folds,
|
| 117 |
+
sex_col=None,
|
| 118 |
+
age_col=None,
|
| 119 |
+
stratify_by_sex=False,
|
| 120 |
+
stratify_by_age=False,
|
| 121 |
+
age_bin_col=None,
|
| 122 |
+
shuffle=False,
|
| 123 |
+
random_state=None,
|
| 124 |
+
print_log=False,
|
| 125 |
+
):
|
| 126 |
+
"""Subject wise splitting data into balanced K-folds.
|
| 127 |
+
|
| 128 |
+
Splits data into K-folds ensuring that the same patient can only appear in
|
| 129 |
+
either train or validation set. Stratifies data according to class label, with additional
|
| 130 |
+
options to stratify by sex and age.
|
| 131 |
+
|
| 132 |
+
Parameters
|
| 133 |
+
----------
|
| 134 |
+
train_data : pd.DataFrame
|
| 135 |
+
dataframe containing features and target.
|
| 136 |
+
target_col : str
|
| 137 |
+
name of target column.
|
| 138 |
+
id_col : str
|
| 139 |
+
name of patient id column.
|
| 140 |
+
num_folds : int
|
| 141 |
+
number of folds.
|
| 142 |
+
sex_col : str, optional
|
| 143 |
+
name of sex column, by default None. Required if stratify_by_sex is True. Can be
|
| 144 |
+
included when stratify_by_sex is False to get info on sex ratio across folds.
|
| 145 |
+
age_col : str, optional
|
| 146 |
+
name of age column, by default None. Column must be a continous variable. Can be
|
| 147 |
+
included to get info on mean age across folds.
|
| 148 |
+
stratify_by_sex : bool, optional
|
| 149 |
+
option to stratify data by sex, by default False.
|
| 150 |
+
stratify_by_age : bool, optional
|
| 151 |
+
option to stratify data by age, by default False. The binned age (age_bin_col) is
|
| 152 |
+
used for stratifying by age rather than age_col.
|
| 153 |
+
age_bin_col : str, optional
|
| 154 |
+
name of age column, by default None. Age column must be provided in binned format.
|
| 155 |
+
shuffle : bool, optional
|
| 156 |
+
whether to shuffle each class's samples before splitting into batches, by default
|
| 157 |
+
False.
|
| 158 |
+
random_state : int, optional
|
| 159 |
+
when shuffle is True, random_state affects the ordering of the indices, by default
|
| 160 |
+
None.
|
| 161 |
+
print_log : bool, optional
|
| 162 |
+
flag to print distributions across folds, by default False.
|
| 163 |
+
|
| 164 |
+
Returns
|
| 165 |
+
-------
|
| 166 |
+
validation_fold_ids : list of arrays
|
| 167 |
+
each array contains the validation patient IDs for each fold.
|
| 168 |
+
|
| 169 |
+
Raises
|
| 170 |
+
-------
|
| 171 |
+
ValueError : error raised when boolean stratify_by_age or stratify_by_sex is True but
|
| 172 |
+
the respective columns are not provided for stratifying.
|
| 173 |
+
|
| 174 |
+
"""
|
| 175 |
+
# Raise error if stratify_by_sex/stratify_by_age is True but the respective columns are
|
| 176 |
+
# not provided
|
| 177 |
+
if (stratify_by_age is True) & (age_bin_col is None):
|
| 178 |
+
raise ValueError(
|
| 179 |
+
"Parameter stratify_by_age is True but age_bin_col not provided."
|
| 180 |
+
)
|
| 181 |
+
if (stratify_by_sex is True) & (sex_col is None):
|
| 182 |
+
raise ValueError("Parameter stratify_by_sex is True but sex_col not provided.")
|
| 183 |
+
|
| 184 |
+
# Adapt target column to contain all variables to split by to allow stratified splitting
|
| 185 |
+
# by StratifiedGroupKFold.
|
| 186 |
+
if (stratify_by_sex is True) and (stratify_by_age is True):
|
| 187 |
+
train_data["TempTarget"] = (
|
| 188 |
+
train_data[target_col].astype(str)
|
| 189 |
+
+ train_data[sex_col].astype(str)
|
| 190 |
+
+ train_data[age_bin_col]
|
| 191 |
+
)
|
| 192 |
+
elif (stratify_by_sex is True) and (stratify_by_age is False):
|
| 193 |
+
train_data["TempTarget"] = train_data[target_col].astype(str) + train_data[
|
| 194 |
+
sex_col
|
| 195 |
+
].astype(str)
|
| 196 |
+
elif (stratify_by_sex is False) and (stratify_by_age is True):
|
| 197 |
+
train_data["TempTarget"] = train_data[target_col].astype(str) + train_data[
|
| 198 |
+
sex_col
|
| 199 |
+
].astype(str)
|
| 200 |
+
else:
|
| 201 |
+
train_data["TempTarget"] = train_data[target_col]
|
| 202 |
+
temp_target_col = "TempTarget"
|
| 203 |
+
|
| 204 |
+
sgkf_train = StratifiedGroupKFold(
|
| 205 |
+
n_splits=num_folds, shuffle=shuffle, random_state=random_state
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
validation_fold_ids = []
|
| 209 |
+
class_fold_ratios = []
|
| 210 |
+
sex_fold_ratios = []
|
| 211 |
+
fold_mean_ages = []
|
| 212 |
+
for i, (train_index, validation_index) in enumerate(
|
| 213 |
+
sgkf_train.split(
|
| 214 |
+
train_data, train_data[temp_target_col], groups=train_data[id_col]
|
| 215 |
+
)
|
| 216 |
+
):
|
| 217 |
+
# Get patient ID's for each validation fold
|
| 218 |
+
validation_ids = train_data[id_col].iloc[validation_index].unique()
|
| 219 |
+
validation_fold_ids.append(validation_ids)
|
| 220 |
+
|
| 221 |
+
# Get class, sex and age distributions for the training data for each fold
|
| 222 |
+
train_fold_data = train_data[~train_data[id_col].isin(validation_ids)]
|
| 223 |
+
class_ratio = train_fold_data[target_col].value_counts()[1] / len(
|
| 224 |
+
train_fold_data
|
| 225 |
+
)
|
| 226 |
+
class_fold_ratios.append(class_ratio)
|
| 227 |
+
if not sex_col is None:
|
| 228 |
+
sex_ratio = train_fold_data[sex_col].value_counts()[1] / len(
|
| 229 |
+
train_fold_data
|
| 230 |
+
)
|
| 231 |
+
sex_fold_ratios.append(sex_ratio)
|
| 232 |
+
if not age_col is None:
|
| 233 |
+
mean_age = train_fold_data[age_col].mean()
|
| 234 |
+
fold_mean_ages.append(mean_age)
|
| 235 |
+
|
| 236 |
+
if print_log is True:
|
| 237 |
+
print("Fold proportions:")
|
| 238 |
+
print("Train class ratio:", class_fold_ratios)
|
| 239 |
+
if not sex_col is None:
|
| 240 |
+
print("Sex class ratio:", sex_fold_ratios)
|
| 241 |
+
if not age_col is None:
|
| 242 |
+
print("Mean age:", fold_mean_ages)
|
| 243 |
+
|
| 244 |
+
# Allows inhomogenous array to be saved with np.save
|
| 245 |
+
validation_fold_ids = np.asarray(validation_fold_ids, dtype="object")
|
| 246 |
+
|
| 247 |
+
# Delete temporary target column
|
| 248 |
+
del train_data[temp_target_col]
|
| 249 |
+
|
| 250 |
+
return validation_fold_ids
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def get_cv_fold_indices(validation_fold_ids, train_data, id_col):
|
| 254 |
+
"""
|
| 255 |
+
Find train/val dataframe indices for each fold and format for cross validation.
|
| 256 |
+
|
| 257 |
+
Creates a tuple with training and validation indices for each K-fold using the
|
| 258 |
+
validation_fold_ids. These patients are assigned to the validation portion of the data
|
| 259 |
+
and all other patients are assigned to the train portion for that fold.
|
| 260 |
+
Checks that all patient IDs listed for the K folds are contained in the train data.
|
| 261 |
+
For each fold, extracts the dataframe indices for patient data belonging that fold
|
| 262 |
+
and assigns all other indices to the 'train' portion. The list returned contains tuples
|
| 263 |
+
required to be passed to sklearn's cross_validate function (through the cv argument).
|
| 264 |
+
|
| 265 |
+
Parameters
|
| 266 |
+
----------
|
| 267 |
+
fold_patients : array
|
| 268 |
+
lists of patient IDs for each of the K folds.
|
| 269 |
+
train_data : pd.DataFrame
|
| 270 |
+
train data (must contain id_col).
|
| 271 |
+
id_col : str
|
| 272 |
+
name of column containing patient ID.
|
| 273 |
+
|
| 274 |
+
Returns
|
| 275 |
+
-------
|
| 276 |
+
cross_validation_fold_indices : list of tuples
|
| 277 |
+
K lists of val/train dataframe indices.
|
| 278 |
+
|
| 279 |
+
"""
|
| 280 |
+
# Create a tuple with training and validation indices for each fold.
|
| 281 |
+
cross_val_fold_indices = []
|
| 282 |
+
for fold in validation_fold_ids:
|
| 283 |
+
fold_val_ids = train_data[train_data[id_col].isin(fold)]
|
| 284 |
+
fold_train_ids = train_data[~train_data[id_col].isin(fold)]
|
| 285 |
+
|
| 286 |
+
# Get index of rows in val and train
|
| 287 |
+
fold_val_index = fold_val_ids.index
|
| 288 |
+
fold_train_index = fold_train_ids.index
|
| 289 |
+
|
| 290 |
+
# Append tuple of training and val indices
|
| 291 |
+
cross_val_fold_indices.append((fold_train_index, fold_val_index))
|
| 292 |
+
return cross_val_fold_indices
|