Upload 34 files
Browse files- README.md +153 -0
- configurations/__pycache__/fastai_configs.cpython-310.pyc +0 -0
- configurations/__pycache__/fastai_configs.cpython-39.pyc +0 -0
- configurations/fastai_configs.py +137 -0
- configurations/wavelet_configs.py +17 -0
- evaluation/Model_Evaluation.ipynb +0 -0
- experiments/__pycache__/scp_experiment.cpython-310.pyc +0 -0
- experiments/__pycache__/scp_experiment.cpython-39.pyc +0 -0
- experiments/scp_experiment.py +227 -0
- exploratory_data_analysis/AutoECG_EDA.ipynb +0 -0
- main.py +30 -0
- models/__pycache__/base_model.cpython-39.pyc +0 -0
- models/__pycache__/basicconv1d.cpython-39.pyc +0 -0
- models/__pycache__/fastaiModel.cpython-310.pyc +0 -0
- models/__pycache__/fastaiModel.cpython-39.pyc +0 -0
- models/__pycache__/inception1d.cpython-39.pyc +0 -0
- models/__pycache__/resnet1d.cpython-39.pyc +0 -0
- models/__pycache__/rnn1d.cpython-39.pyc +0 -0
- models/__pycache__/wavelet.cpython-39.pyc +0 -0
- models/__pycache__/xresnet1d.cpython-39.pyc +0 -0
- models/base_model.py +10 -0
- models/basicconv1d.py +240 -0
- models/fastaiModel.py +513 -0
- models/inception1d.py +137 -0
- models/resnet1d.py +299 -0
- models/rnn1d.py +67 -0
- models/wavelet.py +158 -0
- models/xresnet1d.py +239 -0
- requirements.txt +23 -0
- utilities/__pycache__/timeseries_utils.cpython-39.pyc +0 -0
- utilities/__pycache__/utils.cpython-39.pyc +0 -0
- utilities/stratify.py +173 -0
- utilities/timeseries_utils.py +649 -0
- utilities/utils.py +509 -0
README.md
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automated ECG Interpretation
|
| 2 |
+
|
| 3 |
+
[![Contributors][contributors-shield]][contributors-url]
|
| 4 |
+
[](https://github.com/AutoECG/Automated-ECG-Interpretation/network)
|
| 5 |
+
[](https://github.com/AutoECG/Automated-ECG-Interpretation/stargazers)
|
| 6 |
+
[](https://github.com/AutoECG/Automated-ECG-Interpretation/issues)
|
| 7 |
+
|
| 8 |
+
<br>
|
| 9 |
+
|
| 10 |
+
<div align="center">
|
| 11 |
+
<img src="https://user-images.githubusercontent.com/46399191/191921241-495090db-a088-46b6-bd09-0f7f21170b0a.png" height="350"/>
|
| 12 |
+
</div>
|
| 13 |
+
|
| 14 |
+
## Summary
|
| 15 |
+
|
| 16 |
+
Electrocardiography (ECG) is a key diagnostic tool to assess the cardiac condition of a patient. Automatic ECG interpretation algorithms as diagnosis support systems promise large reliefs for the medical personnel - only based on the number of ECGs that are routinely taken. However, the development of such algorithms requires large training datasets and clear benchmark procedures.
|
| 17 |
+
|
| 18 |
+
## Data Description
|
| 19 |
+
|
| 20 |
+
The [PTB-XL ECG dataset](https://physionet.org/content/ptb-xl/1.0.1/) is a large dataset of 21837 clinical 12-lead ECGs from 18885 patients of 10 second length. The raw waveform data was annotated by up to two cardiologists, who assigned potentially multiple ECG statements to each record. In total 71 different ECG statements conform to the SCP-ECG standard and cover diagnostic, form, and rhythm statements. Combined with the extensive annotation, this turns the dataset into a rich resource for training and evaluating automatic ECG interpretation algorithms. The dataset is complemented by extensive metadata on demographics, infarction characteristics, likelihoods for diagnostic ECG statements, and annotated signal properties.
|
| 21 |
+
|
| 22 |
+
In general, the dataset is organized as follows:
|
| 23 |
+
|
| 24 |
+
```
|
| 25 |
+
ptbxl
|
| 26 |
+
├── ptbxl_database.csv
|
| 27 |
+
├── scp_statements.csv
|
| 28 |
+
├── records100
|
| 29 |
+
├── 00000
|
| 30 |
+
│ │ ├── 00001_lr.dat
|
| 31 |
+
│ │ ├── 00001_lr.hea
|
| 32 |
+
│ │ ├── ...
|
| 33 |
+
│ │ ├── 00999_lr.dat
|
| 34 |
+
│ │ └── 00999_lr.hea
|
| 35 |
+
│ ├── ...
|
| 36 |
+
│ └── 21000
|
| 37 |
+
│ ├── 21001_lr.dat
|
| 38 |
+
│ ├── 21001_lr.hea
|
| 39 |
+
│ ├── ...
|
| 40 |
+
│ ├── 21837_lr.dat
|
| 41 |
+
│ └── 21837_lr.hea
|
| 42 |
+
└── records500
|
| 43 |
+
├── 00000
|
| 44 |
+
│ ├── 00001_hr.dat
|
| 45 |
+
│ ├── 00001_hr.hea
|
| 46 |
+
│ ├── ...
|
| 47 |
+
│ ├── 00999_hr.dat
|
| 48 |
+
│ └── 00999_hr.hea
|
| 49 |
+
├── ...
|
| 50 |
+
└── 21000
|
| 51 |
+
├── 21001_hr.dat
|
| 52 |
+
├── 21001_hr.hea
|
| 53 |
+
├── ...
|
| 54 |
+
├── 21837_hr.dat
|
| 55 |
+
└── 21837_hr.hea
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
The dataset comprises 21837 clinical 12-lead ECG records of 10 seconds length from 18885 patients, where 52% are male and 48% are female with ages covering the whole range from 0 to 95 years (median 62 and interquantile range of 22). The value of the dataset results from the comprehensive collection of many different co-occurring pathologies, but also from a large proportion of healthy control samples.
|
| 59 |
+
|
| 60 |
+
| Records | Superclass | Description |
|
| 61 |
+
|:---|:---|:---|
|
| 62 |
+
9528 | NORM | Normal ECG |
|
| 63 |
+
5486 | MI | Myocardial Infarction |
|
| 64 |
+
5250 | STTC | ST/T Change |
|
| 65 |
+
4907 | CD | Conduction Disturbance |
|
| 66 |
+
2655 | HYP | Hypertrophy |
|
| 67 |
+
|
| 68 |
+
The waveform files are stored in WaveForm DataBase (WFDB) format with 16-bit precision at a resolution of 1μV/LSB and a sampling frequency of 500Hz (records500/) beside downsampled versions of the waveform data at a sampling frequency of 100Hz (records100/).
|
| 69 |
+
|
| 70 |
+
All relevant metadata is stored in ptbxldatabase.csv with one row per record identified by ecgid and it contains 28 columns.
|
| 71 |
+
|
| 72 |
+
All information related to the used annotation scheme is stored in a dedicated scp_statements.csv that was enriched with mappings to other annotation standards.
|
| 73 |
+
|
| 74 |
+
## Setup
|
| 75 |
+
|
| 76 |
+
### Install dependencies
|
| 77 |
+
Install the dependencies (wfdb, pytorch, torchvision, cudatoolkit, fastai, fastprogress) by creating a conda environment:
|
| 78 |
+
|
| 79 |
+
conda env create -f requirements.yml
|
| 80 |
+
conda activate autoecg_env
|
| 81 |
+
|
| 82 |
+
### Get data
|
| 83 |
+
Download the dataset (PTB-XL) via the follwing bash-script:
|
| 84 |
+
|
| 85 |
+
get_dataset.sh
|
| 86 |
+
|
| 87 |
+
This script first downloads [PTB-XL from PhysioNet](https://physionet.org/content/ptb-xl/) and stores it in `data/ptbxl/`.
|
| 88 |
+
|
| 89 |
+
## Usage
|
| 90 |
+
|
| 91 |
+
python main.py
|
| 92 |
+
|
| 93 |
+
This will perform all experiments for inception1d.
|
| 94 |
+
Depending on the executing environment, this will take up to several hours.
|
| 95 |
+
Once finished, all trained models, predictions and results are stored in `output/`,
|
| 96 |
+
where for each experiment a sub-folder is created each with `data/`, `models/` and `results/` sub-sub-folders.
|
| 97 |
+
|
| 98 |
+
| Model | AUC ↓ | Experiment |
|
| 99 |
+
|:---|:---|:---|
|
| 100 |
+
| inception1d | 0.927(00) | All statements |
|
| 101 |
+
| inception1d | 0.929(00) | Diagnostic statements |
|
| 102 |
+
| inception1d | 0.926(00) | Diagnostic subclasses |
|
| 103 |
+
| inception1d | 0.919(00) | Diagnostic superclasses |
|
| 104 |
+
| inception1d | 0.883(00) | Form statements |
|
| 105 |
+
| inception1d | 0.949(00) | Rhythm statements |
|
| 106 |
+
|
| 107 |
+
### Download model and results
|
| 108 |
+
|
| 109 |
+
We also provide a [compressed zip-archive](https://drive.google.com/drive/folders/17za6IanRm7rpb1ZGHLQ80mJvBj_53LXJ?usp=sharing) containing the `output` folder corresponding to our runs including trained model and predictions.
|
| 110 |
+
|
| 111 |
+
## Results for Inception1d Model
|
| 112 |
+
|
| 113 |
+
| Experiment name | Accuracy | Precision | Recall | F1_Score | Specificity |
|
| 114 |
+
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
|
| 115 |
+
| All | 0.9792 | 0.8949 | 0.1408 | 0.4824 | 0.9921 |
|
| 116 |
+
| Diagnostic | 0.9806 | 0.8440 | 0.1556 | 0.4746 | 0.9952 |
|
| 117 |
+
| Sub-Diagnostic | 0.9660 | 0.8315 | 0.3021 | 0.5119 | 0.9887 |
|
| 118 |
+
| Super-Diagnostic | 0.8847 | 0.7938 | 0.6757 | 0.7157 | 0.9251 |
|
| 119 |
+
| Form | 0.9452 | 0.5619 | 0.1420 | 0.3843 | 0.9916 |
|
| 120 |
+
| Rhythm | 0.9844 | 0.7676 | 0.4489 | 0.7290 | 0.9722 |
|
| 121 |
+
|
| 122 |
+
For more evaluation (Confusion Matrix, ROC curve) information and visualizations visit: [Model Evaluation](https://github.com/AutoECG/Automated-ECG-Interpretation/blob/main/evaluation/Model_Evaluation.ipynb)
|
| 123 |
+
|
| 124 |
+
## Contribution
|
| 125 |
+
|
| 126 |
+
Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
|
| 127 |
+
|
| 128 |
+
If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement".
|
| 129 |
+
Don't forget to give the project a star! Thanks again!
|
| 130 |
+
|
| 131 |
+
1. [Fork the Project](https://github.com/AutoECG/Automated-ECG-Interpretation/fork)
|
| 132 |
+
2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
|
| 133 |
+
3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
|
| 134 |
+
4. Push to the Branch (`git push origin feature/AmazingFeature`)
|
| 135 |
+
5. Open a Pull Request
|
| 136 |
+
|
| 137 |
+
## Future Works
|
| 138 |
+
|
| 139 |
+
1. Model Deployment.
|
| 140 |
+
2. Continue Preprocessing new ECG data from hospitals to test model reliability and accuracy.
|
| 141 |
+
3. Figure out different parsing options for xml ecg files from different ECG machines versions.
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
## Contact
|
| 145 |
+
|
| 146 |
+
Feel free to reach out to us:
|
| 147 |
+
- DM [Zaki Kurdya](https://twitter.com/ZakiKurdya)
|
| 148 |
+
- DM [Zeina Saadeddin](https://twitter.com/jszeina)
|
| 149 |
+
- DM [Salam Thabit](https://twitter.com/salamThabetDo)
|
| 150 |
+
|
| 151 |
+
<!-- MARKDOWN LINKS -->
|
| 152 |
+
[contributors-shield]: https://img.shields.io/github/contributors/AutoECG/Automated-ECG-Interpretation.svg?style=flat-square&color=blue
|
| 153 |
+
[contributors-url]: https://github.com/AutoECG/Automated-ECG-Interpretation/graphs/contributors
|
configurations/__pycache__/fastai_configs.cpython-310.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
configurations/__pycache__/fastai_configs.cpython-39.pyc
ADDED
|
Binary file (3.66 kB). View file
|
|
|
configurations/fastai_configs.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
conf_fastai_resnet1d18 = {'model_name': 'fastai_resnet1d18', 'model_type': 'FastaiModel',
|
| 2 |
+
'parameters': dict()}
|
| 3 |
+
|
| 4 |
+
conf_fastai_resnet1d34 = {'model_name': 'fastai_resnet1d34', 'model_type': 'FastaiModel',
|
| 5 |
+
'parameters': dict()}
|
| 6 |
+
|
| 7 |
+
conf_fastai_resnet1d50 = {'model_name': 'fastai_resnet1d50', 'model_type': 'FastaiModel',
|
| 8 |
+
'parameters': dict()}
|
| 9 |
+
|
| 10 |
+
conf_fastai_resnet1d101 = {'model_name': 'fastai_resnet1d101', 'model_type': 'FastaiModel',
|
| 11 |
+
'parameters': dict()}
|
| 12 |
+
|
| 13 |
+
conf_fastai_resnet1d152 = {'model_name': 'fastai_resnet1d152', 'model_type': 'FastaiModel',
|
| 14 |
+
'parameters': dict()}
|
| 15 |
+
|
| 16 |
+
conf_fastai_resnet1d_wang = {'model_name': 'fastai_resnet1d_wang', 'model_type': 'FastaiModel',
|
| 17 |
+
'parameters': dict()}
|
| 18 |
+
|
| 19 |
+
conf_fastai_wrn1d_22 = {'model_name': 'fastai_wrn1d_22', 'model_type': 'FastaiModel',
|
| 20 |
+
'parameters': dict()}
|
| 21 |
+
|
| 22 |
+
conf_fastai_xresnet1d18 = {'model_name': 'fastai_xresnet1d18', 'model_type': 'FastaiModel',
|
| 23 |
+
'parameters': dict()}
|
| 24 |
+
|
| 25 |
+
conf_fastai_xresnet1d34 = {'model_name': 'fastai_xresnet1d34', 'model_type': 'FastaiModel',
|
| 26 |
+
'parameters': dict()}
|
| 27 |
+
|
| 28 |
+
conf_fastai_xresnet1d50 = {'model_name': 'fastai_xresnet1d50', 'model_type': 'FastaiModel',
|
| 29 |
+
'parameters': dict()}
|
| 30 |
+
|
| 31 |
+
# more xresnet50s
|
| 32 |
+
conf_fastai_xresnet1d50_ep30 = {'model_name': 'fastai_xresnet1d50_ep30', 'model_type': 'FastaiModel',
|
| 33 |
+
'parameters': dict(epochs=30)}
|
| 34 |
+
|
| 35 |
+
conf_fastai_xresnet1d50_validloss_ep30 = {'model_name': 'fastai_xresnet1d50_validloss_ep30',
|
| 36 |
+
'model_type': 'FastaiModel',
|
| 37 |
+
'parameters': dict(early_stopping="valid_loss", epochs=30)}
|
| 38 |
+
|
| 39 |
+
conf_fastai_xresnet1d50_macroauc_ep30 = {'model_name': 'fastai_xresnet1d50_macroauc_ep30', 'model_type': 'FastaiModel',
|
| 40 |
+
'parameters': dict(early_stopping="macro_auc", epochs=30)}
|
| 41 |
+
|
| 42 |
+
conf_fastai_xresnet1d50_fmax_ep30 = {'model_name': 'fastai_xresnet1d50_fmax_ep30', 'model_type': 'FastaiModel',
|
| 43 |
+
'parameters': dict(early_stopping="fmax", epochs=30)}
|
| 44 |
+
|
| 45 |
+
conf_fastai_xresnet1d50_ep50 = {'model_name': 'fastai_xresnet1d50_ep50', 'model_type': 'FastaiModel',
|
| 46 |
+
'parameters': dict(epochs=50)}
|
| 47 |
+
|
| 48 |
+
conf_fastai_xresnet1d50_validloss_ep50 = {'model_name': 'fastai_xresnet1d50_validloss_ep50',
|
| 49 |
+
'model_type': 'FastaiModel',
|
| 50 |
+
'parameters': dict(early_stopping="valid_loss", epochs=50)}
|
| 51 |
+
|
| 52 |
+
conf_fastai_xresnet1d50_macroauc_ep50 = {'model_name': 'fastai_xresnet1d50_macroauc_ep50', 'model_type': 'FastaiModel',
|
| 53 |
+
'parameters': dict(early_stopping="macro_auc", epochs=50)}
|
| 54 |
+
|
| 55 |
+
conf_fastai_xresnet1d50_fmax_ep50 = {'model_name': 'fastai_xresnet1d50_fmax_ep50', 'model_type': 'FastaiModel',
|
| 56 |
+
'parameters': dict(early_stopping="fmax", epochs=50)}
|
| 57 |
+
|
| 58 |
+
conf_fastai_xresnet1d101 = {'model_name': 'fastai_xresnet1d101', 'model_type': 'FastaiModel',
|
| 59 |
+
'parameters': dict()}
|
| 60 |
+
|
| 61 |
+
conf_fastai_xresnet1d152 = {'model_name': 'fastai_xresnet1d152', 'model_type': 'FastaiModel',
|
| 62 |
+
'parameters': dict()}
|
| 63 |
+
|
| 64 |
+
conf_fastai_xresnet1d18_deep = {'model_name': 'fastai_xresnet1d18_deep', 'model_type': 'FastaiModel',
|
| 65 |
+
'parameters': dict()}
|
| 66 |
+
|
| 67 |
+
conf_fastai_xresnet1d34_deep = {'model_name': 'fastai_xresnet1d34_deep', 'model_type': 'FastaiModel',
|
| 68 |
+
'parameters': dict()}
|
| 69 |
+
|
| 70 |
+
conf_fastai_xresnet1d50_deep = {'model_name': 'fastai_xresnet1d50_deep', 'model_type': 'FastaiModel',
|
| 71 |
+
'parameters': dict()}
|
| 72 |
+
|
| 73 |
+
conf_fastai_xresnet1d18_deeper = {'model_name': 'fastai_xresnet1d18_deeper', 'model_type': 'FastaiModel',
|
| 74 |
+
'parameters': dict()}
|
| 75 |
+
|
| 76 |
+
conf_fastai_xresnet1d34_deeper = {'model_name': 'fastai_xresnet1d34_deeper', 'model_type': 'FastaiModel',
|
| 77 |
+
'parameters': dict()}
|
| 78 |
+
|
| 79 |
+
conf_fastai_xresnet1d50_deeper = {'model_name': 'fastai_xresnet1d50_deeper', 'model_type': 'FastaiModel',
|
| 80 |
+
'parameters': dict()}
|
| 81 |
+
|
| 82 |
+
conf_fastai_inception1d = {'model_name': 'fastai_inception1d', 'model_type': 'FastaiModel',
|
| 83 |
+
'parameters': dict()}
|
| 84 |
+
|
| 85 |
+
conf_fastai_inception1d_input256 = {'model_name': 'fastai_inception1d_input256', 'model_type': 'FastaiModel',
|
| 86 |
+
'parameters': dict(input_size=256)}
|
| 87 |
+
|
| 88 |
+
conf_fastai_inception1d_input512 = {'model_name': 'fastai_inception1d_input512', 'model_type': 'FastaiModel',
|
| 89 |
+
'parameters': dict(input_size=512)}
|
| 90 |
+
|
| 91 |
+
conf_fastai_inception1d_input1000 = {'model_name': 'fastai_inception1d_input1000', 'model_type': 'FastaiModel',
|
| 92 |
+
'parameters': dict(input_size=1000)}
|
| 93 |
+
|
| 94 |
+
conf_fastai_inception1d_no_residual = {'model_name': 'fastai_inception1d_no_residual', 'model_type': 'FastaiModel',
|
| 95 |
+
'parameters': dict()}
|
| 96 |
+
|
| 97 |
+
conf_fastai_fcn = {'model_name': 'fastai_fcn', 'model_type': 'FastaiModel',
|
| 98 |
+
'parameters': dict()}
|
| 99 |
+
|
| 100 |
+
conf_fastai_fcn_wang = {'model_name': 'fastai_fcn_wang', 'model_type': 'FastaiModel',
|
| 101 |
+
'parameters': dict()}
|
| 102 |
+
|
| 103 |
+
conf_fastai_schirrmeister = {'model_name': 'fastai_schirrmeister', 'model_type': 'FastaiModel',
|
| 104 |
+
'parameters': dict()}
|
| 105 |
+
|
| 106 |
+
conf_fastai_sen = {'model_name': 'fastai_sen', 'model_type': 'FastaiModel',
|
| 107 |
+
'parameters': dict()}
|
| 108 |
+
|
| 109 |
+
conf_fastai_basic1d = {'model_name': 'fastai_basic1d', 'model_type': 'FastaiModel',
|
| 110 |
+
'parameters': dict()}
|
| 111 |
+
|
| 112 |
+
conf_fastai_lstm = {'model_name': 'fastai_lstm', 'model_type': 'FastaiModel',
|
| 113 |
+
'parameters': dict(lr=1e-3)}
|
| 114 |
+
|
| 115 |
+
conf_fastai_gru = {'model_name': 'fastai_gru', 'model_type': 'FastaiModel',
|
| 116 |
+
'parameters': dict(lr=1e-3)}
|
| 117 |
+
|
| 118 |
+
conf_fastai_lstm_bidir = {'model_name': 'fastai_lstm_bidir', 'model_type': 'FastaiModel',
|
| 119 |
+
'parameters': dict(lr=1e-3)}
|
| 120 |
+
|
| 121 |
+
conf_fastai_gru_bidir = {'model_name': 'fastai_gru', 'model_type': 'FastaiModel',
|
| 122 |
+
'parameters': dict(lr=1e-3)}
|
| 123 |
+
|
| 124 |
+
conf_fastai_lstm_input1000 = {'model_name': 'fastai_lstm_input1000', 'model_type': 'FastaiModel',
|
| 125 |
+
'parameters': dict(input_size=1000, lr=1e-3)}
|
| 126 |
+
|
| 127 |
+
conf_fastai_gru_input1000 = {'model_name': 'fastai_gru_input1000', 'model_type': 'FastaiModel',
|
| 128 |
+
'parameters': dict(input_size=1000, lr=1e-3)}
|
| 129 |
+
|
| 130 |
+
conf_fastai_schirrmeister_input500 = {'model_name': 'fastai_schirrmeister_input500', 'model_type': 'FastaiModel',
|
| 131 |
+
'parameters': dict(input_size=500)}
|
| 132 |
+
|
| 133 |
+
conf_fastai_inception1d_input500 = {'model_name': 'fastai_inception1d_input500', 'model_type': 'FastaiModel',
|
| 134 |
+
'parameters': dict(input_size=500)}
|
| 135 |
+
|
| 136 |
+
conf_fastai_fcn_wang_input500 = {'model_name': 'fastai_fcn_wang_input500', 'model_type': 'FastaiModel',
|
| 137 |
+
'parameters': dict(input_size=500)}
|
configurations/wavelet_configs.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
conf_wavelet_standard_lr = {'model_name': 'Wavelet+LR', 'model_type': 'WAVELET',
|
| 2 |
+
'parameters': dict(
|
| 3 |
+
regularizer_C=.001,
|
| 4 |
+
classifier='LR'
|
| 5 |
+
)}
|
| 6 |
+
|
| 7 |
+
conf_wavelet_standard_rf = {'model_name': 'Wavelet+RF', 'model_type': 'WAVELET',
|
| 8 |
+
'parameters': dict(
|
| 9 |
+
regularizer_C=.001,
|
| 10 |
+
classifier='RF'
|
| 11 |
+
)}
|
| 12 |
+
|
| 13 |
+
conf_wavelet_standard_nn = {'model_name': 'Wavelet+NN', 'model_type': 'WAVELET',
|
| 14 |
+
'parameters': dict(
|
| 15 |
+
regularizer_C=.001,
|
| 16 |
+
classifier='NN'
|
| 17 |
+
)}
|
evaluation/Model_Evaluation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
experiments/__pycache__/scp_experiment.cpython-310.pyc
ADDED
|
Binary file (5.81 kB). View file
|
|
|
experiments/__pycache__/scp_experiment.cpython-39.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
experiments/scp_experiment.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing
|
| 2 |
+
from itertools import repeat
|
| 3 |
+
|
| 4 |
+
from models import fastaiModel
|
| 5 |
+
from models.wavelet import WaveletModel
|
| 6 |
+
from utilities.utils import *
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SCPExperiment:
|
| 10 |
+
"""
|
| 11 |
+
Experiment on SCP-ECG statements.
|
| 12 |
+
All experiments based on SCP are performed and evaluated the same way.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, experiment_name, task, data_folder, output_folder, models,
|
| 16 |
+
sampling_frequency=100, min_samples=0, train_fold=8, val_fold=9,
|
| 17 |
+
test_fold=10, folds_type='strat'):
|
| 18 |
+
self.models = models
|
| 19 |
+
self.min_samples = min_samples
|
| 20 |
+
self.task = task
|
| 21 |
+
self.train_fold = train_fold
|
| 22 |
+
self.val_fold = val_fold
|
| 23 |
+
self.test_fold = test_fold
|
| 24 |
+
self.folds_type = folds_type
|
| 25 |
+
self.experiment_name = experiment_name
|
| 26 |
+
self.output_folder = output_folder
|
| 27 |
+
self.data_folder = data_folder
|
| 28 |
+
self.sampling_frequency = sampling_frequency
|
| 29 |
+
|
| 30 |
+
# create folder structure if needed
|
| 31 |
+
if not os.path.exists(self.output_folder + self.experiment_name):
|
| 32 |
+
os.makedirs(self.output_folder + self.experiment_name)
|
| 33 |
+
if not os.path.exists(self.output_folder + self.experiment_name + '/results/'):
|
| 34 |
+
os.makedirs(self.output_folder + self.experiment_name + '/results/')
|
| 35 |
+
if not os.path.exists(output_folder + self.experiment_name + '/models/'):
|
| 36 |
+
os.makedirs(self.output_folder + self.experiment_name + '/models/')
|
| 37 |
+
if not os.path.exists(output_folder + self.experiment_name + '/data/'):
|
| 38 |
+
os.makedirs(self.output_folder + self.experiment_name + '/data/')
|
| 39 |
+
|
| 40 |
+
def prepare(self):
|
| 41 |
+
# Load PTB-XL data
|
| 42 |
+
self.data, self.raw_labels = load_dataset(self.data_folder, self.sampling_frequency)
|
| 43 |
+
|
| 44 |
+
# Preprocess label data
|
| 45 |
+
self.labels = compute_label_aggregations(self.raw_labels, self.data_folder, self.task)
|
| 46 |
+
|
| 47 |
+
# Select relevant data and convert to one-hot
|
| 48 |
+
self.data, self.labels, self.Y, _ = select_data(self.data, self.labels, self.task, self.min_samples,
|
| 49 |
+
self.output_folder + self.experiment_name + '/data/')
|
| 50 |
+
self.input_shape = self.data[0].shape
|
| 51 |
+
|
| 52 |
+
# 10th fold for testing (9th for now)
|
| 53 |
+
self.X_test = self.data[self.labels.strat_fold == self.test_fold]
|
| 54 |
+
self.y_test = self.Y[self.labels.strat_fold == self.test_fold]
|
| 55 |
+
# 9th fold for validation (8th for now)
|
| 56 |
+
self.X_val = self.data[self.labels.strat_fold == self.val_fold]
|
| 57 |
+
self.y_val = self.Y[self.labels.strat_fold == self.val_fold]
|
| 58 |
+
# rest for training
|
| 59 |
+
self.X_train = self.data[self.labels.strat_fold <= self.train_fold]
|
| 60 |
+
self.y_train = self.Y[self.labels.strat_fold <= self.train_fold]
|
| 61 |
+
|
| 62 |
+
# Preprocess signal data
|
| 63 |
+
self.X_train, self.X_val, self.X_test = preprocess_signals(self.X_train, self.X_val, self.X_test,
|
| 64 |
+
self.output_folder + self.experiment_name + '/data/')
|
| 65 |
+
self.n_classes = self.y_train.shape[1]
|
| 66 |
+
|
| 67 |
+
# save train and test labels
|
| 68 |
+
self.y_train.dump(self.output_folder + self.experiment_name + '/data/y_train.npy')
|
| 69 |
+
self.y_val.dump(self.output_folder + self.experiment_name + '/data/y_val.npy')
|
| 70 |
+
self.y_test.dump(self.output_folder + self.experiment_name + '/data/y_test.npy')
|
| 71 |
+
|
| 72 |
+
model_name = 'naive'
|
| 73 |
+
# create most naive predictions via simple mean in training
|
| 74 |
+
mpath = self.output_folder + self.experiment_name + '/models/' + model_name + '/'
|
| 75 |
+
# create folder for model outputs
|
| 76 |
+
if not os.path.exists(mpath):
|
| 77 |
+
os.makedirs(mpath)
|
| 78 |
+
if not os.path.exists(mpath + 'results/'):
|
| 79 |
+
os.makedirs(mpath + 'results/')
|
| 80 |
+
|
| 81 |
+
mean_y = np.mean(self.y_train, axis=0)
|
| 82 |
+
np.array([mean_y] * len(self.y_train)).dump(mpath + 'y_train_pred.npy')
|
| 83 |
+
np.array([mean_y] * len(self.y_test)).dump(mpath + 'y_test_pred.npy')
|
| 84 |
+
np.array([mean_y] * len(self.y_val)).dump(mpath + 'y_val_pred.npy')
|
| 85 |
+
|
| 86 |
+
def perform(self):
|
| 87 |
+
for model_description in self.models:
|
| 88 |
+
model_name = model_description['model_name']
|
| 89 |
+
model_type = model_description['model_type']
|
| 90 |
+
model_params = model_description['parameters']
|
| 91 |
+
|
| 92 |
+
mpath = self.output_folder + self.experiment_name + '/models/' + model_name + '/'
|
| 93 |
+
# create folder for model outputs
|
| 94 |
+
if not os.path.exists(mpath):
|
| 95 |
+
os.makedirs(mpath)
|
| 96 |
+
if not os.path.exists(mpath + 'results/'):
|
| 97 |
+
os.makedirs(mpath + 'results/')
|
| 98 |
+
|
| 99 |
+
n_classes = self.Y.shape[1]
|
| 100 |
+
# load respective model
|
| 101 |
+
if model_type == 'WAVELET':
|
| 102 |
+
model = WaveletModel(model_name, n_classes, self.sampling_frequency, mpath, self.input_shape,
|
| 103 |
+
**model_params)
|
| 104 |
+
elif model_type == "FastaiModel":
|
| 105 |
+
model = fastaiModel.FastaiModel(model_name, n_classes, self.sampling_frequency, mpath, self.input_shape,
|
| 106 |
+
**model_params)
|
| 107 |
+
else:
|
| 108 |
+
assert True
|
| 109 |
+
break
|
| 110 |
+
# Print to check
|
| 111 |
+
print("Shape of input", self.X_train.shape)
|
| 112 |
+
# fit model
|
| 113 |
+
model.fit(self.X_train, self.y_train, self.X_val, self.y_val)
|
| 114 |
+
# predict and dump
|
| 115 |
+
model.predict(self.X_train).dump(mpath + 'y_train_pred.npy')
|
| 116 |
+
model.predict(self.X_val).dump(mpath + 'y_val_pred.npy')
|
| 117 |
+
model.predict(self.X_test).dump(mpath + 'y_test_pred.npy')
|
| 118 |
+
|
| 119 |
+
model_name = 'ensemble'
|
| 120 |
+
# create ensemble predictions via simple mean across model predictions (except naive predictions)
|
| 121 |
+
ensemblepath = self.output_folder + self.experiment_name + '/models/' + model_name + '/'
|
| 122 |
+
# create folder for model outputs
|
| 123 |
+
if not os.path.exists(ensemblepath):
|
| 124 |
+
os.makedirs(ensemblepath)
|
| 125 |
+
if not os.path.exists(ensemblepath + 'results/'):
|
| 126 |
+
os.makedirs(ensemblepath + 'results/')
|
| 127 |
+
# load all predictions
|
| 128 |
+
ensemble_train, ensemble_val, ensemble_test = [], [], []
|
| 129 |
+
for model_description in os.listdir(self.output_folder + self.experiment_name + '/models/'):
|
| 130 |
+
if not model_description in ['ensemble', 'naive']:
|
| 131 |
+
mpath = self.output_folder + self.experiment_name + '/models/' + model_description + '/'
|
| 132 |
+
ensemble_train.append(np.load(mpath + 'y_train_pred.npy', allow_pickle=True))
|
| 133 |
+
ensemble_val.append(np.load(mpath + 'y_val_pred.npy', allow_pickle=True))
|
| 134 |
+
ensemble_test.append(np.load(mpath + 'y_test_pred.npy', allow_pickle=True))
|
| 135 |
+
# dump mean predictions
|
| 136 |
+
np.array(ensemble_train).mean(axis=0).dump(ensemblepath + 'y_train_pred.npy')
|
| 137 |
+
np.array(ensemble_test).mean(axis=0).dump(ensemblepath + 'y_test_pred.npy')
|
| 138 |
+
np.array(ensemble_val).mean(axis=0).dump(ensemblepath + 'y_val_pred.npy')
|
| 139 |
+
|
| 140 |
+
def evaluate(self, n_bootstraping_samples=100, n_jobs=20, bootstrap_eval=False, dumped_bootstraps=True):
|
| 141 |
+
# get labels
|
| 142 |
+
global train_samples, val_samples
|
| 143 |
+
y_train = np.load(self.output_folder + self.experiment_name + '/data/y_train.npy', allow_pickle=True)
|
| 144 |
+
y_val = np.load(self.output_folder + self.experiment_name + '/data/y_val.npy', allow_pickle=True)
|
| 145 |
+
y_test = np.load(self.output_folder + self.experiment_name + '/data/y_test.npy', allow_pickle=True)
|
| 146 |
+
|
| 147 |
+
# if bootstrapping then generate appropriate samples for each
|
| 148 |
+
if bootstrap_eval:
|
| 149 |
+
if not dumped_bootstraps:
|
| 150 |
+
train_samples = np.array(get_appropriate_bootstrap_samples(y_train, n_bootstraping_samples))
|
| 151 |
+
test_samples = np.array(get_appropriate_bootstrap_samples(y_test, n_bootstraping_samples))
|
| 152 |
+
val_samples = np.array(get_appropriate_bootstrap_samples(y_val, n_bootstraping_samples))
|
| 153 |
+
else:
|
| 154 |
+
test_samples = np.load(self.output_folder + self.experiment_name + '/test_bootstrap_ids.npy',
|
| 155 |
+
allow_pickle=True)
|
| 156 |
+
else:
|
| 157 |
+
train_samples = np.array([range(len(y_train))])
|
| 158 |
+
test_samples = np.array([range(len(y_test))])
|
| 159 |
+
val_samples = np.array([range(len(y_val))])
|
| 160 |
+
|
| 161 |
+
# store samples for future evaluations
|
| 162 |
+
train_samples.dump(self.output_folder + self.experiment_name + '/train_bootstrap_ids.npy')
|
| 163 |
+
test_samples.dump(self.output_folder + self.experiment_name + '/test_bootstrap_ids.npy')
|
| 164 |
+
val_samples.dump(self.output_folder + self.experiment_name + '/val_bootstrap_ids.npy')
|
| 165 |
+
|
| 166 |
+
# iterate over all models fitted so far
|
| 167 |
+
for m in sorted(os.listdir(self.output_folder + self.experiment_name + '/models')):
|
| 168 |
+
print(m)
|
| 169 |
+
mpath = self.output_folder + self.experiment_name + '/models/' + m + '/'
|
| 170 |
+
rpath = self.output_folder + self.experiment_name + '/models/' + m + '/results/'
|
| 171 |
+
|
| 172 |
+
# load predictions
|
| 173 |
+
y_train_pred = np.load(mpath + 'y_train_pred.npy', allow_pickle=True)
|
| 174 |
+
y_val_pred = np.load(mpath + 'y_val_pred.npy', allow_pickle=True)
|
| 175 |
+
y_test_pred = np.load(mpath + 'y_test_pred.npy', allow_pickle=True)
|
| 176 |
+
|
| 177 |
+
if self.experiment_name == 'exp_ICBEB':
|
| 178 |
+
# compute classwise thresholds such that recall-focused Gbeta is optimized
|
| 179 |
+
thresholds = find_optimal_cutoff_thresholds_for_Gbeta(y_train, y_train_pred)
|
| 180 |
+
else:
|
| 181 |
+
thresholds = None
|
| 182 |
+
|
| 183 |
+
pool = multiprocessing.Pool(n_jobs)
|
| 184 |
+
|
| 185 |
+
tr_df = pd.concat(pool.starmap(generate_results,
|
| 186 |
+
zip(train_samples, repeat(y_train), repeat(y_train_pred),
|
| 187 |
+
repeat(thresholds))))
|
| 188 |
+
tr_df_point = generate_results(range(len(y_train)), y_train, y_train_pred, thresholds)
|
| 189 |
+
tr_df_result = pd.DataFrame(
|
| 190 |
+
np.array([
|
| 191 |
+
tr_df_point.mean().values,
|
| 192 |
+
tr_df.mean().values,
|
| 193 |
+
tr_df.quantile(0.05).values,
|
| 194 |
+
tr_df.quantile(0.95).values]),
|
| 195 |
+
columns=tr_df.columns,
|
| 196 |
+
index=['point', 'mean', 'lower', 'upper'])
|
| 197 |
+
|
| 198 |
+
te_df = pd.concat(pool.starmap(generate_results,
|
| 199 |
+
zip(test_samples, repeat(y_test), repeat(y_test_pred), repeat(thresholds))))
|
| 200 |
+
te_df_point = generate_results(range(len(y_test)), y_test, y_test_pred, thresholds)
|
| 201 |
+
te_df_result = pd.DataFrame(
|
| 202 |
+
np.array([
|
| 203 |
+
te_df_point.mean().values,
|
| 204 |
+
te_df.mean().values,
|
| 205 |
+
te_df.quantile(0.05).values,
|
| 206 |
+
te_df.quantile(0.95).values]),
|
| 207 |
+
columns=te_df.columns,
|
| 208 |
+
index=['point', 'mean', 'lower', 'upper'])
|
| 209 |
+
|
| 210 |
+
val_df = pd.concat(pool.starmap(generate_results,
|
| 211 |
+
zip(val_samples, repeat(y_val), repeat(y_val_pred), repeat(thresholds))))
|
| 212 |
+
val_df_point = generate_results(range(len(y_val)), y_val, y_val_pred, thresholds)
|
| 213 |
+
val_df_result = pd.DataFrame(
|
| 214 |
+
np.array([
|
| 215 |
+
val_df_point.mean().values,
|
| 216 |
+
val_df.mean().values,
|
| 217 |
+
val_df.quantile(0.05).values,
|
| 218 |
+
val_df.quantile(0.95).values]),
|
| 219 |
+
columns=val_df.columns,
|
| 220 |
+
index=['point', 'mean', 'lower', 'upper'])
|
| 221 |
+
|
| 222 |
+
pool.close()
|
| 223 |
+
|
| 224 |
+
# dump results
|
| 225 |
+
tr_df_result.to_csv(rpath + 'tr_results.csv')
|
| 226 |
+
val_df_result.to_csv(rpath + 'val_results.csv')
|
| 227 |
+
te_df_result.to_csv(rpath + 'te_results.csv')
|
exploratory_data_analysis/AutoECG_EDA.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
main.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model configs
|
| 2 |
+
from configurations.fastai_configs import conf_fastai_inception1d
|
| 3 |
+
from experiments.scp_experiment import SCPExperiment
|
| 4 |
+
from utilities.utils import generate_ptbxl_summary_table
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def main():
|
| 8 |
+
data_folder = 'data/ptbxl/'
|
| 9 |
+
output_folder = 'output/'
|
| 10 |
+
|
| 11 |
+
models = [conf_fastai_inception1d]
|
| 12 |
+
|
| 13 |
+
# STANDARD SCP EXPERIMENTS ON PTB-XL
|
| 14 |
+
|
| 15 |
+
experiments = [
|
| 16 |
+
('exp1.1', 'subdiagnostic')
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
for name, task in experiments:
|
| 20 |
+
e = SCPExperiment(name, task, data_folder, output_folder, models)
|
| 21 |
+
e.prepare()
|
| 22 |
+
e.perform()
|
| 23 |
+
e.evaluate()
|
| 24 |
+
|
| 25 |
+
# generate summary table
|
| 26 |
+
generate_ptbxl_summary_table()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if __name__ == "__main__":
|
| 30 |
+
main()
|
models/__pycache__/base_model.cpython-39.pyc
ADDED
|
Binary file (760 Bytes). View file
|
|
|
models/__pycache__/basicconv1d.cpython-39.pyc
ADDED
|
Binary file (9.01 kB). View file
|
|
|
models/__pycache__/fastaiModel.cpython-310.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
models/__pycache__/fastaiModel.cpython-39.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
models/__pycache__/inception1d.cpython-39.pyc
ADDED
|
Binary file (5.52 kB). View file
|
|
|
models/__pycache__/resnet1d.cpython-39.pyc
ADDED
|
Binary file (9.51 kB). View file
|
|
|
models/__pycache__/rnn1d.cpython-39.pyc
ADDED
|
Binary file (2.99 kB). View file
|
|
|
models/__pycache__/wavelet.cpython-39.pyc
ADDED
|
Binary file (5.4 kB). View file
|
|
|
models/__pycache__/xresnet1d.cpython-39.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
models/base_model.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ClassificationModel(object):
|
| 2 |
+
|
| 3 |
+
def __init__(self):
|
| 4 |
+
pass
|
| 5 |
+
|
| 6 |
+
def fit(self, X_train, y_train, X_val, y_val):
|
| 7 |
+
pass
|
| 8 |
+
|
| 9 |
+
def predict(self, X, full_sequence=True):
|
| 10 |
+
pass
|
models/basicconv1d.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
from fastai.layers import *
|
| 6 |
+
from fastai.data.core import *
|
| 7 |
+
from typing import Optional, Collection, Union
|
| 8 |
+
from collections.abc import Iterable
|
| 9 |
+
|
| 10 |
+
'''
|
| 11 |
+
This layer creates a convolution kernel that is convolved with the layer input
|
| 12 |
+
over a single spatial (or temporal) dimension to produce a tensor of outputs.
|
| 13 |
+
If use_bias is True, a bias vector is created and added to the outputs.
|
| 14 |
+
Finally, if activation is not None, it is applied to the outputs as well.
|
| 15 |
+
https://keras.io/api/layers/convolution_layers/convolution1d/
|
| 16 |
+
'''
|
| 17 |
+
def listify(o):
|
| 18 |
+
if o is None: return []
|
| 19 |
+
if isinstance(o, list): return o
|
| 20 |
+
if isinstance(o, str): return [o]
|
| 21 |
+
if isinstance(o, Iterable): return list(o)
|
| 22 |
+
return [o]
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
|
| 25 |
+
def bn_drop_lin(ni, no, bn=True, p=0., actn=None):
|
| 26 |
+
layers = []
|
| 27 |
+
if bn: layers.append(nn.BatchNorm1d(ni))
|
| 28 |
+
if p != 0.: layers.append(nn.Dropout(p))
|
| 29 |
+
layers.append(nn.Linear(ni, no))
|
| 30 |
+
if actn is not None: layers.append(actn)
|
| 31 |
+
return layers
|
| 32 |
+
|
| 33 |
+
def _conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
|
| 34 |
+
lst = []
|
| 35 |
+
if (drop_p > 0):
|
| 36 |
+
lst.append(nn.Dropout(drop_p))
|
| 37 |
+
lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
|
| 38 |
+
dilation=dilation, bias=not bn))
|
| 39 |
+
if bn:
|
| 40 |
+
lst.append(nn.BatchNorm1d(out_planes))
|
| 41 |
+
if act == "relu":
|
| 42 |
+
lst.append(nn.ReLU(True))
|
| 43 |
+
if act == "elu":
|
| 44 |
+
lst.append(nn.ELU(True))
|
| 45 |
+
if act == "prelu":
|
| 46 |
+
lst.append(nn.PReLU(True))
|
| 47 |
+
return nn.Sequential(*lst)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _fc(in_planes, out_planes, act="relu", bn=True):
|
| 51 |
+
lst = [nn.Linear(in_planes, out_planes, bias=not (bn))]
|
| 52 |
+
if bn:
|
| 53 |
+
lst.append(nn.BatchNorm1d(out_planes))
|
| 54 |
+
if act == "relu":
|
| 55 |
+
lst.append(nn.ReLU(True))
|
| 56 |
+
if act == "elu":
|
| 57 |
+
lst.append(nn.ELU(True))
|
| 58 |
+
if act == "prelu":
|
| 59 |
+
lst.append(nn.PReLU(True))
|
| 60 |
+
return nn.Sequential(*lst)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def cd_adaptive_concat_pool(relevant, irrelevant, module):
|
| 64 |
+
mpr, mpi = module.mp.attrib(relevant, irrelevant)
|
| 65 |
+
apr, api = module.ap.attrib(relevant, irrelevant)
|
| 66 |
+
return torch.cat([mpr, apr], 1), torch.cat([mpi, api], 1)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def attrib_adaptive_concat_pool(self, relevant, irrelevant):
|
| 70 |
+
return cd_adaptive_concat_pool(relevant, irrelevant, self)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class AdaptiveConcatPool1d(nn.Module):
|
| 74 |
+
"""Layer that concat `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`."""
|
| 75 |
+
|
| 76 |
+
def __init__(self, sz: Optional[int] = None):
|
| 77 |
+
"""Output will be 2*sz or 2 if sz is None"""
|
| 78 |
+
super().__init__()
|
| 79 |
+
sz = sz or 1
|
| 80 |
+
self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
|
| 81 |
+
|
| 82 |
+
def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
|
| 83 |
+
|
| 84 |
+
def attrib(self, relevant, irrelevant):
|
| 85 |
+
return attrib_adaptive_concat_pool(self, relevant, irrelevant)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class SqueezeExcite1d(nn.Module):
|
| 89 |
+
"""squeeze excite block as used for example in LSTM FCN"""
|
| 90 |
+
|
| 91 |
+
def __init__(self, channels, reduction=16):
|
| 92 |
+
super().__init__()
|
| 93 |
+
channels_reduced = channels // reduction
|
| 94 |
+
self.w1 = torch.nn.Parameter(torch.randn(channels_reduced, channels).unsqueeze(0))
|
| 95 |
+
self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))
|
| 96 |
+
|
| 97 |
+
def forward(self, x):
|
| 98 |
+
# input is bs,ch,seq
|
| 99 |
+
z = torch.mean(x, dim=2, keepdim=True) # bs,ch
|
| 100 |
+
intermed = F.relu(torch.matmul(self.w1, z)) # (1,ch_red,ch * bs,ch,1) = (bs, ch_red, 1)
|
| 101 |
+
s = F.sigmoid(torch.matmul(self.w2, intermed)) # (1,ch,ch_red * bs, ch_red, 1=bs, ch, 1
|
| 102 |
+
return s * x # bs,ch,seq * bs, ch,1 = bs,ch,seq
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def weight_init(m):
|
| 106 |
+
"""call weight initialization for model n via n.apply(weight_init)"""
|
| 107 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
|
| 108 |
+
nn.init.kaiming_normal_(m.weight)
|
| 109 |
+
if m.bias is not None:
|
| 110 |
+
nn.init.zeros_(m.bias)
|
| 111 |
+
if isinstance(m, nn.BatchNorm1d):
|
| 112 |
+
nn.init.constant_(m.weight, 1)
|
| 113 |
+
nn.init.constant_(m.bias, 0)
|
| 114 |
+
if isinstance(m, SqueezeExcite1d):
|
| 115 |
+
stdv1 = math.sqrt(2. / m.w1.size[0])
|
| 116 |
+
nn.init.normal_(m.w1, 0., stdv1)
|
| 117 |
+
stdv2 = math.sqrt(1. / m.w2.size[1])
|
| 118 |
+
nn.init.normal_(m.w2, 0., stdv2)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def create_head1d(nf: int, nc: int, lin_ftrs: Optional[Collection[int]] = None, ps: Union[float, Collection[float]] = 0.5,
|
| 122 |
+
bn_final: bool = False, bn: bool = True, act="relu", concat_pooling=True):
|
| 123 |
+
"""Model head that takes `nf` features, runs through `lin_ftrs`, and about `nc` classes; added bn and act here"""
|
| 124 |
+
lin_ftrs = [2 * nf if concat_pooling else nf, nc] if lin_ftrs is None else [
|
| 125 |
+
2 * nf if concat_pooling else nf] + lin_ftrs + [
|
| 126 |
+
nc] # was [nf, 512,nc]
|
| 127 |
+
ps = listify(ps)
|
| 128 |
+
if len(ps) == 1: ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
|
| 129 |
+
actns = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
|
| 130 |
+
layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()]
|
| 131 |
+
for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
|
| 132 |
+
layers += bn_drop_lin(ni, no, bn, p, actn)
|
| 133 |
+
if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
|
| 134 |
+
return nn.Sequential(*layers)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# basic convolutional architecture
|
| 138 |
+
|
| 139 |
+
class BasicConv1d(nn.Sequential):
|
| 140 |
+
"""basic conv1d"""
|
| 141 |
+
|
| 142 |
+
def __init__(self, filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1,
|
| 143 |
+
squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,
|
| 144 |
+
split_first_layer=False, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
|
| 145 |
+
act_head="relu", concat_pooling=True):
|
| 146 |
+
if filters is None:
|
| 147 |
+
filters = [128, 128, 128, 128]
|
| 148 |
+
layers = []
|
| 149 |
+
if isinstance(kernel_size, int):
|
| 150 |
+
kernel_size = [kernel_size] * len(filters)
|
| 151 |
+
for i in range(len(filters)):
|
| 152 |
+
layers_tmp = [_conv1d(input_channels if i == 0 else filters[i - 1], filters[i], kernel_size=kernel_size[i],
|
| 153 |
+
stride=(1 if (split_first_layer is True and i == 0) else stride), dilation=dilation,
|
| 154 |
+
act="none" if ((headless is True and i == len(filters) - 1) or (
|
| 155 |
+
split_first_layer is True and i == 0)) else act,
|
| 156 |
+
bn=False if (headless is True and i == len(filters) - 1) else bn,
|
| 157 |
+
drop_p=(0. if i == 0 else drop_p))]
|
| 158 |
+
|
| 159 |
+
if split_first_layer is True and i == 0:
|
| 160 |
+
layers_tmp.append(_conv1d(filters[0], filters[0], kernel_size=1, stride=1, act=act, bn=bn, drop_p=0.))
|
| 161 |
+
# layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn)))
|
| 162 |
+
# layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn))
|
| 163 |
+
if pool > 0 and i < len(filters) - 1:
|
| 164 |
+
layers_tmp.append(nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2))
|
| 165 |
+
if squeeze_excite_reduction > 0:
|
| 166 |
+
layers_tmp.append(SqueezeExcite1d(filters[i], squeeze_excite_reduction))
|
| 167 |
+
layers.append(nn.Sequential(*layers_tmp))
|
| 168 |
+
|
| 169 |
+
# head layers.append(nn.AdaptiveAvgPool1d(1)) layers.append(nn.Linear(filters[-1],num_classes)) head
|
| 170 |
+
# #inplace=True leads to a runtime error see ReLU+ dropout
|
| 171 |
+
# https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
|
| 172 |
+
self.headless = headless
|
| 173 |
+
if headless is True:
|
| 174 |
+
head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten())
|
| 175 |
+
else:
|
| 176 |
+
head = create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
|
| 177 |
+
bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
|
| 178 |
+
layers.append(head)
|
| 179 |
+
|
| 180 |
+
super().__init__(*layers)
|
| 181 |
+
|
| 182 |
+
def get_layer_groups(self):
|
| 183 |
+
return self[2], self[-1]
|
| 184 |
+
|
| 185 |
+
def get_output_layer(self):
|
| 186 |
+
if self.headless is False:
|
| 187 |
+
return self[-1][-1]
|
| 188 |
+
else:
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
def set_output_layer(self, x):
|
| 192 |
+
if self.headless is False:
|
| 193 |
+
self[-1][-1] = x
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# convenience functions for basic convolutional architectures
|
| 197 |
+
def fcn(filters=None, num_classes=2, input_channels=8):
|
| 198 |
+
if filters is None:
|
| 199 |
+
filters = [128] * 5
|
| 200 |
+
filters_in = filters + [num_classes]
|
| 201 |
+
return BasicConv1d(filters=filters_in, kernel_size=3, stride=1, pool=2, pool_stride=2,
|
| 202 |
+
input_channels=input_channels, act="relu", bn=True, headless=True)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def fcn_wang(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
|
| 206 |
+
act_head="relu", concat_pooling=True):
|
| 207 |
+
return BasicConv1d(filters=[128, 256, 128], kernel_size=[8, 5, 3], stride=1, pool=0, pool_stride=2,
|
| 208 |
+
num_classes=num_classes, input_channels=input_channels, act="relu", bn=True,
|
| 209 |
+
lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head,
|
| 210 |
+
act_head=act_head, concat_pooling=concat_pooling)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def schirrmeister(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
|
| 214 |
+
act_head="relu", concat_pooling=True):
|
| 215 |
+
return BasicConv1d(filters=[25, 50, 100, 200], kernel_size=10, stride=3, pool=3, pool_stride=1,
|
| 216 |
+
num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, headless=False,
|
| 217 |
+
split_first_layer=True, drop_p=0.5, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head,
|
| 218 |
+
bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def sen(filters=None, num_classes=2, input_channels=8, squeeze_excite_reduction=16, drop_p=0., lin_ftrs_head=None,
|
| 222 |
+
ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
|
| 223 |
+
if filters is None:
|
| 224 |
+
filters = [128] * 5
|
| 225 |
+
return BasicConv1d(filters=filters, kernel_size=3, stride=2, pool=0, pool_stride=0, input_channels=input_channels,
|
| 226 |
+
act="relu", bn=True, num_classes=num_classes, squeeze_excite_reduction=squeeze_excite_reduction,
|
| 227 |
+
drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
|
| 228 |
+
bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def basic1d(filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0,
|
| 232 |
+
num_classes=2, input_channels=8, act="relu", bn=True, headless=False, drop_p=0., lin_ftrs_head=None,
|
| 233 |
+
ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
|
| 234 |
+
if filters is None:
|
| 235 |
+
filters = [128] * 5
|
| 236 |
+
return BasicConv1d(filters=filters, kernel_size=kernel_size, stride=stride, dilation=dilation, pool=pool,
|
| 237 |
+
pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction,
|
| 238 |
+
num_classes=num_classes, input_channels=input_channels, act=act, bn=bn, headless=headless,
|
| 239 |
+
drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
|
| 240 |
+
bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
|
models/fastaiModel.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastai.data.core import *
|
| 2 |
+
from fastai.learner import *
|
| 3 |
+
from fastai.callback.schedule import *
|
| 4 |
+
from fastai.torch_core import *
|
| 5 |
+
from fastai.callback.tracker import SaveModelCallback
|
| 6 |
+
# from fastai.callback.gradient import GradientClipping
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from functools import partial
|
| 9 |
+
import math
|
| 10 |
+
# from fastai.callback import GradientClipping
|
| 11 |
+
import torch
|
| 12 |
+
from fastai.tabular.core import range_of
|
| 13 |
+
import numpy as np
|
| 14 |
+
import matplotlib
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
from fastai.callback.core import Callback
|
| 17 |
+
from fastai.data.core import DataLoaders
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
# from fastai.metrics import add_metrics
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from fastcore.utils import ifnone
|
| 22 |
+
import pandas as pd
|
| 23 |
+
from models.base_model import ClassificationModel
|
| 24 |
+
from models.basicconv1d import weight_init, fcn_wang, fcn, schirrmeister, sen, basic1d
|
| 25 |
+
from models.inception1d import inception1d
|
| 26 |
+
from models.resnet1d import resnet1d18, resnet1d34, resnet1d50, resnet1d101, resnet1d152, resnet1d_wang, \
|
| 27 |
+
wrn1d_22
|
| 28 |
+
from models.rnn1d import RNN1d
|
| 29 |
+
from utilities.timeseries_utils import TimeseriesDatasetCrops, ToTensor, aggregate_predictions
|
| 30 |
+
from models.xresnet1d import xresnet1d18_deeper, xresnet1d34_deeper, xresnet1d50_deeper, xresnet1d18_deep, \
|
| 31 |
+
xresnet1d34_deep, xresnet1d50_deep, xresnet1d18, xresnet1d34, xresnet1d101, xresnet1d50, xresnet1d152
|
| 32 |
+
from utilities.utils import evaluate_experiment
|
| 33 |
+
def add_metrics(last_metrics, new_metric):
|
| 34 |
+
"""
|
| 35 |
+
Adds a new metric to the list of last metrics.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
last_metrics (list): List of previous metrics.
|
| 39 |
+
new_metric (float or list): New metric(s) to add.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
list: Updated list of metrics.
|
| 43 |
+
"""
|
| 44 |
+
if isinstance(new_metric, list):
|
| 45 |
+
return last_metrics + new_metric
|
| 46 |
+
else:
|
| 47 |
+
return last_metrics + [new_metric]
|
| 48 |
+
|
| 49 |
+
class MetricFunc(Callback):
|
| 50 |
+
"""Obtains score using user-supplied function func (potentially ignoring targets with ignore_idx)"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, func, name="MetricFunc", ignore_idx=None, one_hot_encode_target=True, argmax_pred=False,
|
| 53 |
+
softmax_pred=True, flatten_target=True, sigmoid_pred=False, metric_component=None):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.metric_complete = self.func(self.y_true, self.y_pred)
|
| 56 |
+
self.y_true = None
|
| 57 |
+
self.y_pred = None
|
| 58 |
+
self.func = func
|
| 59 |
+
self.ignore_idx = ignore_idx
|
| 60 |
+
self.one_hot_encode_target = one_hot_encode_target
|
| 61 |
+
self.argmax_pred = argmax_pred
|
| 62 |
+
self.softmax_pred = softmax_pred
|
| 63 |
+
self.flatten_target = flatten_target
|
| 64 |
+
self.sigmoid_pred = sigmoid_pred
|
| 65 |
+
self.metric_component = metric_component
|
| 66 |
+
self.name = name
|
| 67 |
+
|
| 68 |
+
def on_epoch_begin(self, **kwargs):
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
def on_batch_end(self, last_output, last_target, **kwargs):
|
| 72 |
+
# flatten everything (to make it also work for annotation tasks)
|
| 73 |
+
y_pred_flat = last_output.view((-1, last_output.size()[-1]))
|
| 74 |
+
|
| 75 |
+
if self.flatten_target:
|
| 76 |
+
last_target.view(-1)
|
| 77 |
+
y_true_flat = last_target
|
| 78 |
+
|
| 79 |
+
# optionally take argmax of predictions
|
| 80 |
+
if self.argmax_pred is True:
|
| 81 |
+
y_pred_flat = y_pred_flat.argmax(dim=1)
|
| 82 |
+
elif self.softmax_pred is True:
|
| 83 |
+
y_pred_flat = F.softmax(y_pred_flat, dim=1)
|
| 84 |
+
elif self.sigmoid_pred is True:
|
| 85 |
+
y_pred_flat = torch.sigmoid(y_pred_flat)
|
| 86 |
+
|
| 87 |
+
# potentially remove ignore_idx entries
|
| 88 |
+
if self.ignore_idx is not None:
|
| 89 |
+
selected_indices = (y_true_flat != self.ignore_idx).nonzero().squeeze()
|
| 90 |
+
y_pred_flat = y_pred_flat[selected_indices]
|
| 91 |
+
y_true_flat = y_true_flat[selected_indices]
|
| 92 |
+
|
| 93 |
+
y_pred_flat = to_np(y_pred_flat)
|
| 94 |
+
y_true_flat = to_np(y_true_flat)
|
| 95 |
+
|
| 96 |
+
if self.one_hot_encode_target is True:
|
| 97 |
+
y_true_flat = np.one_hot_np(y_true_flat, last_output.size()[-1])
|
| 98 |
+
|
| 99 |
+
if self.y_pred is None:
|
| 100 |
+
self.y_pred = y_pred_flat
|
| 101 |
+
self.y_true = y_true_flat
|
| 102 |
+
else:
|
| 103 |
+
self.y_pred = np.concatenate([self.y_pred, y_pred_flat], axis=0)
|
| 104 |
+
self.y_true = np.concatenate([self.y_true, y_true_flat], axis=0)
|
| 105 |
+
|
| 106 |
+
def on_epoch_end(self, last_metrics, **kwargs):
|
| 107 |
+
# access full metric (possibly multiple components) via self.metric_complete
|
| 108 |
+
if self.metric_component is not None:
|
| 109 |
+
return add_metrics(last_metrics, self.metric_complete[self.metric_component])
|
| 110 |
+
else:
|
| 111 |
+
return add_metrics(last_metrics, self.metric_complete)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def fmax_metric(targs, preds):
|
| 115 |
+
return evaluate_experiment(targs, preds)["Fmax"]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def auc_metric(targs, preds):
|
| 119 |
+
return evaluate_experiment(targs, preds)["macro_auc"]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def mse_flat(preds, targs):
|
| 123 |
+
return torch.mean(torch.pow(preds.view(-1) - targs.view(-1), 2))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def nll_regression(preds, targs):
|
| 127 |
+
# preds: bs, 2
|
| 128 |
+
# targs: bs, 1
|
| 129 |
+
preds_mean = preds[:, 0]
|
| 130 |
+
# warning: output goes through exponential map to ensure positivity
|
| 131 |
+
preds_var = torch.clamp(torch.exp(preds[:, 1]), 1e-4, 1e10)
|
| 132 |
+
# print(to_np(preds_mean)[0],to_np(targs)[0,0],to_np(torch.sqrt(preds_var))[0])
|
| 133 |
+
return torch.mean(torch.log(2 * math.pi * preds_var) / 2) + torch.mean(
|
| 134 |
+
torch.pow(preds_mean - targs[:, 0], 2) / 2 / preds_var)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def nll_regression_init(m):
|
| 138 |
+
assert (isinstance(m, nn.Linear))
|
| 139 |
+
nn.init.normal_(m.weight, 0., 0.001)
|
| 140 |
+
nn.init.constant_(m.bias, 4)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def lr_find_plot(learner, path, filename="lr_find", n_skip=10, n_skip_end=2):
|
| 144 |
+
"""
|
| 145 |
+
saves lr_find plot as file (normally only jupyter output)
|
| 146 |
+
on the x-axis is lrs[-1]
|
| 147 |
+
"""
|
| 148 |
+
learner.lr_find()
|
| 149 |
+
|
| 150 |
+
backend_old = matplotlib.get_backend()
|
| 151 |
+
plt.switch_backend('agg')
|
| 152 |
+
plt.ylabel("loss")
|
| 153 |
+
plt.xlabel("learning rate (log scale)")
|
| 154 |
+
losses = [to_np(x) for x in learner.recorder.losses[n_skip:-(n_skip_end + 1)]]
|
| 155 |
+
# print(learner.recorder.val_losses)
|
| 156 |
+
# val_losses = [ to_np(x) for x in learner.recorder.val_losses[n_skip:-(n_skip_end+1)]]
|
| 157 |
+
|
| 158 |
+
plt.plot(learner.recorder.lrs[n_skip:-(n_skip_end + 1)], losses)
|
| 159 |
+
# plt.plot(learner.recorder.lrs[n_skip:-(n_skip_end+1)],val_losses )
|
| 160 |
+
|
| 161 |
+
plt.xscale('log')
|
| 162 |
+
plt.savefig(str(path / (filename + '.png')))
|
| 163 |
+
plt.switch_backend(backend_old)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def losses_plot(learner, path, filename="losses", last: int = None):
|
| 167 |
+
"""
|
| 168 |
+
saves lr_find plot as file (normally only jupyter output)
|
| 169 |
+
on the x-axis is lrs[-1]
|
| 170 |
+
"""
|
| 171 |
+
backend_old = matplotlib.get_backend()
|
| 172 |
+
plt.switch_backend('agg')
|
| 173 |
+
plt.ylabel("loss")
|
| 174 |
+
plt.xlabel("Batches processed")
|
| 175 |
+
|
| 176 |
+
last = ifnone(last, len(learner.recorder.nb_batches))
|
| 177 |
+
l_b = np.sum(learner.recorder.nb_batches[-last:])
|
| 178 |
+
iterations = range_of(learner.recorder.losses)[-l_b:]
|
| 179 |
+
plt.plot(iterations, learner.recorder.losses[-l_b:], label='Train')
|
| 180 |
+
val_iter = learner.recorder.nb_batches[-last:]
|
| 181 |
+
val_iter = np.cumsum(val_iter) + np.sum(learner.recorder.nb_batches[:-last])
|
| 182 |
+
plt.plot(val_iter, learner.recorder.val_losses[-last:], label='Validation')
|
| 183 |
+
plt.legend()
|
| 184 |
+
|
| 185 |
+
plt.savefig(str(path / (filename + '.png')))
|
| 186 |
+
plt.switch_backend(backend_old)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class FastaiModel(ClassificationModel):
|
| 190 |
+
def __init__(self, name, n_classes, freq, output_folder, input_shape, pretrained=False, input_size=2.5,
|
| 191 |
+
input_channels=12, chunkify_train=False, chunkify_valid=True, bs=128, ps_head=0.5, lin_ftrs_head=None,
|
| 192 |
+
wd=1e-2, epochs=50, lr=1e-2, kernel_size=5, loss="binary_cross_entropy", pretrained_folder=None,
|
| 193 |
+
n_classes_pretrained=None, gradual_unfreezing=True, discriminative_lrs=True, epochs_finetuning=30,
|
| 194 |
+
early_stopping=None, aggregate_fn="max", concat_train_val=False):
|
| 195 |
+
super().__init__()
|
| 196 |
+
|
| 197 |
+
if lin_ftrs_head is None:
|
| 198 |
+
lin_ftrs_head = [128]
|
| 199 |
+
self.name = name
|
| 200 |
+
self.num_classes = n_classes if loss != "nll_regression" else 2
|
| 201 |
+
self.target_fs = freq
|
| 202 |
+
self.output_folder = Path(output_folder)
|
| 203 |
+
|
| 204 |
+
self.input_size = int(input_size * self.target_fs)
|
| 205 |
+
self.input_channels = input_channels
|
| 206 |
+
|
| 207 |
+
self.chunkify_train = chunkify_train
|
| 208 |
+
self.chunkify_valid = chunkify_valid
|
| 209 |
+
|
| 210 |
+
self.chunk_length_train = 2 * self.input_size # target_fs*6
|
| 211 |
+
self.chunk_length_valid = self.input_size
|
| 212 |
+
|
| 213 |
+
self.min_chunk_length = self.input_size # chunk_length
|
| 214 |
+
|
| 215 |
+
self.stride_length_train = self.input_size # chunk_length_train//8
|
| 216 |
+
self.stride_length_valid = self.input_size // 2 # chunk_length_valid
|
| 217 |
+
|
| 218 |
+
self.copies_valid = 0 # >0 should only be used with chunkify_valid=False
|
| 219 |
+
|
| 220 |
+
self.bs = bs
|
| 221 |
+
self.ps_head = ps_head
|
| 222 |
+
self.lin_ftrs_head = lin_ftrs_head
|
| 223 |
+
self.wd = wd
|
| 224 |
+
self.epochs = epochs
|
| 225 |
+
self.lr = lr
|
| 226 |
+
self.kernel_size = kernel_size
|
| 227 |
+
self.loss = loss
|
| 228 |
+
self.input_shape = input_shape
|
| 229 |
+
|
| 230 |
+
if pretrained:
|
| 231 |
+
if pretrained_folder is None:
|
| 232 |
+
pretrained_folder = Path('../output/exp0/models/' + name.split("_pretrained")[0] + '/')
|
| 233 |
+
# pretrained_folder = Path('/output/exp0/models/'+name.split("_pretrained")[0]+'/')
|
| 234 |
+
|
| 235 |
+
if n_classes_pretrained is None:
|
| 236 |
+
n_classes_pretrained = 71
|
| 237 |
+
|
| 238 |
+
self.pretrained_folder = None if pretrained_folder is None else Path(pretrained_folder)
|
| 239 |
+
self.n_classes_pretrained = n_classes_pretrained
|
| 240 |
+
self.discriminative_lrs = discriminative_lrs
|
| 241 |
+
self.gradual_unfreezing = gradual_unfreezing
|
| 242 |
+
self.epochs_finetuning = epochs_finetuning
|
| 243 |
+
|
| 244 |
+
self.early_stopping = early_stopping
|
| 245 |
+
self.aggregate_fn = aggregate_fn
|
| 246 |
+
self.concat_train_val = concat_train_val
|
| 247 |
+
|
| 248 |
+
def fit(self, X_train, y_train, X_val, y_val):
|
| 249 |
+
# convert everything to float32
|
| 250 |
+
X_train = [l.astype(np.float32) for l in X_train]
|
| 251 |
+
X_val = [l.astype(np.float32) for l in X_val]
|
| 252 |
+
y_train = [l.astype(np.float32) for l in y_train]
|
| 253 |
+
y_val = [l.astype(np.float32) for l in y_val]
|
| 254 |
+
|
| 255 |
+
if self.concat_train_val:
|
| 256 |
+
X_train += X_val
|
| 257 |
+
y_train += y_val
|
| 258 |
+
|
| 259 |
+
if self.pretrained_folder is None: # from scratch
|
| 260 |
+
print("Training from scratch...")
|
| 261 |
+
learn = self._get_learner(X_train, y_train, X_val, y_val)
|
| 262 |
+
|
| 263 |
+
# if(self.discriminative_lrs):
|
| 264 |
+
# layer_groups=learn.model.get_layer_groups()
|
| 265 |
+
# learn.split(layer_groups)
|
| 266 |
+
learn.model.apply(weight_init)
|
| 267 |
+
|
| 268 |
+
# initialization for regression output
|
| 269 |
+
if self.loss == "nll_regression" or self.loss == "mse":
|
| 270 |
+
output_layer_new = learn.model.get_output_layer()
|
| 271 |
+
output_layer_new.apply(nll_regression_init)
|
| 272 |
+
learn.model.set_output_layer(output_layer_new)
|
| 273 |
+
|
| 274 |
+
lr_find_plot(learn, self.output_folder)
|
| 275 |
+
learn.fit_one_cycle(self.epochs, self.lr) # slice(self.lr) if self.discriminative_lrs else self.lr)
|
| 276 |
+
losses_plot(learn, self.output_folder)
|
| 277 |
+
else: # finetuning
|
| 278 |
+
print("Finetuning...")
|
| 279 |
+
# create learner
|
| 280 |
+
learn = self._get_learner(X_train, y_train, X_val, y_val, self.n_classes_pretrained)
|
| 281 |
+
|
| 282 |
+
# load pretrained model
|
| 283 |
+
learn.path = self.pretrained_folder
|
| 284 |
+
learn.load(self.pretrained_folder.stem)
|
| 285 |
+
learn.path = self.output_folder
|
| 286 |
+
|
| 287 |
+
# exchange top layer
|
| 288 |
+
output_layer = learn.model.get_output_layer()
|
| 289 |
+
output_layer_new = nn.Linear(output_layer.in_features, self.num_classes).cuda()
|
| 290 |
+
apply_init(output_layer_new, nn.init.kaiming_normal_)
|
| 291 |
+
learn.model.set_output_layer(output_layer_new)
|
| 292 |
+
|
| 293 |
+
# layer groups
|
| 294 |
+
if self.discriminative_lrs:
|
| 295 |
+
layer_groups = learn.model.get_layer_groups()
|
| 296 |
+
learn.split(layer_groups)
|
| 297 |
+
|
| 298 |
+
learn.train_bn = True # make sure if bn mode is train
|
| 299 |
+
|
| 300 |
+
# train
|
| 301 |
+
lr = self.lr
|
| 302 |
+
if self.gradual_unfreezing:
|
| 303 |
+
assert (self.discriminative_lrs is True)
|
| 304 |
+
learn.freeze()
|
| 305 |
+
lr_find_plot(learn, self.output_folder, "lr_find0")
|
| 306 |
+
learn.fit_one_cycle(self.epochs_finetuning, lr)
|
| 307 |
+
losses_plot(learn, self.output_folder, "losses0")
|
| 308 |
+
# for n in [0]:#range(len(layer_groups)): learn.freeze_to(-n-1) lr_find_plot(learn,
|
| 309 |
+
# self.output_folder,"lr_find"+str(n)) learn.fit_one_cycle(self.epochs_gradual_unfreezing,slice(lr))
|
| 310 |
+
# losses_plot(learn, self.output_folder,"losses"+str(n)) if(n==0):#reduce lr after first step lr/=10.
|
| 311 |
+
# if(n>0 and (self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru"))):#reduce lr
|
| 312 |
+
# further for RNNs lr/=10
|
| 313 |
+
|
| 314 |
+
learn.unfreeze()
|
| 315 |
+
lr_find_plot(learn, self.output_folder, "lr_find" + str(len(layer_groups)))
|
| 316 |
+
learn.fit_one_cycle(self.epochs_finetuning, slice(lr / 1000, lr / 10))
|
| 317 |
+
losses_plot(learn, self.output_folder, "losses" + str(len(layer_groups)))
|
| 318 |
+
|
| 319 |
+
learn.save(self.name) # even for early stopping the best model will have been loaded again
|
| 320 |
+
|
| 321 |
+
def predict(self, X):
|
| 322 |
+
X = [l.astype(np.float32) for l in X]
|
| 323 |
+
y_dummy = [np.ones(self.num_classes, dtype=np.float32) for _ in range(len(X))]
|
| 324 |
+
|
| 325 |
+
learn = self._get_learner(X, y_dummy, X, y_dummy)
|
| 326 |
+
learn.load(self.name)
|
| 327 |
+
|
| 328 |
+
preds, targs = learn.get_preds()
|
| 329 |
+
preds = to_np(preds)
|
| 330 |
+
|
| 331 |
+
idmap = learn.data.valid_ds.get_id_mapping()
|
| 332 |
+
|
| 333 |
+
return aggregate_predictions(preds, idmap=idmap,
|
| 334 |
+
aggregate_fn=np.mean if self.aggregate_fn == "mean" else np.amax)
|
| 335 |
+
|
| 336 |
+
def _get_learner(self, X_train, y_train, X_val, y_val, num_classes=None):
|
| 337 |
+
df_train = pd.DataFrame({"data": range(len(X_train)), "label": y_train})
|
| 338 |
+
df_valid = pd.DataFrame({"data": range(len(X_val)), "label": y_val})
|
| 339 |
+
|
| 340 |
+
tfms_ptb_xl = [ToTensor()]
|
| 341 |
+
|
| 342 |
+
ds_train = TimeseriesDatasetCrops(df_train, self.input_size, num_classes=self.num_classes,
|
| 343 |
+
chunk_length=self.chunk_length_train if self.chunkify_train else 0,
|
| 344 |
+
min_chunk_length=self.min_chunk_length,
|
| 345 |
+
stride=self.stride_length_train, transforms=tfms_ptb_xl,
|
| 346 |
+
annotation=False, col_lbl="label", npy_data=X_train)
|
| 347 |
+
ds_valid = TimeseriesDatasetCrops(df_valid, self.input_size, num_classes=self.num_classes,
|
| 348 |
+
chunk_length=self.chunk_length_valid if self.chunkify_valid else 0,
|
| 349 |
+
min_chunk_length=self.min_chunk_length,
|
| 350 |
+
stride=self.stride_length_valid, transforms=tfms_ptb_xl,
|
| 351 |
+
annotation=False, col_lbl="label", npy_data=X_val)
|
| 352 |
+
|
| 353 |
+
db = DataLoaders(ds_train, ds_valid)
|
| 354 |
+
|
| 355 |
+
if self.loss == "binary_cross_entropy":
|
| 356 |
+
loss = F.binary_cross_entropy_with_logits
|
| 357 |
+
elif self.loss == "cross_entropy":
|
| 358 |
+
loss = F.cross_entropy
|
| 359 |
+
elif self.loss == "mse":
|
| 360 |
+
loss = mse_flat
|
| 361 |
+
elif self.loss == "nll_regression":
|
| 362 |
+
loss = nll_regression
|
| 363 |
+
else:
|
| 364 |
+
print("loss not found")
|
| 365 |
+
assert (True)
|
| 366 |
+
|
| 367 |
+
self.input_channels = self.input_shape[-1]
|
| 368 |
+
metrics = []
|
| 369 |
+
|
| 370 |
+
print("model:", self.name)
|
| 371 |
+
# note: all models of a particular kind share the same prefix but potentially a different
|
| 372 |
+
# postfix such as _input256
|
| 373 |
+
num_classes = self.num_classes if num_classes is None else num_classes
|
| 374 |
+
# resnet resnet1d18,resnet1d34,resnet1d50,resnet1d101,resnet1d152,resnet1d_wang,resnet1d,wrn1d_22
|
| 375 |
+
if self.name.startswith("fastai_resnet1d18"):
|
| 376 |
+
model = resnet1d18(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
|
| 377 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 378 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 379 |
+
elif self.name.startswith("fastai_resnet1d34"):
|
| 380 |
+
model = resnet1d34(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
|
| 381 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 382 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 383 |
+
elif self.name.startswith("fastai_resnet1d50"):
|
| 384 |
+
model = resnet1d50(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
|
| 385 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 386 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 387 |
+
elif self.name.startswith("fastai_resnet1d101"):
|
| 388 |
+
model = resnet1d101(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
|
| 389 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 390 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 391 |
+
elif self.name.startswith("fastai_resnet1d152"):
|
| 392 |
+
model = resnet1d152(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
|
| 393 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 394 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 395 |
+
elif self.name.startswith("fastai_resnet1d_wang"):
|
| 396 |
+
model = resnet1d_wang(num_classes=num_classes, input_channels=self.input_channels,
|
| 397 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 398 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 399 |
+
elif self.name.startswith("fastai_wrn1d_22"):
|
| 400 |
+
model = wrn1d_22(num_classes=num_classes, input_channels=self.input_channels,
|
| 401 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 402 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 403 |
+
|
| 404 |
+
# xresnet ... (order important for string capture)
|
| 405 |
+
elif self.name.startswith("fastai_xresnet1d18_deeper"):
|
| 406 |
+
model = xresnet1d18_deeper(num_classes=num_classes, input_channels=self.input_channels,
|
| 407 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 408 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 409 |
+
elif self.name.startswith("fastai_xresnet1d34_deeper"):
|
| 410 |
+
model = xresnet1d34_deeper(num_classes=num_classes, input_channels=self.input_channels,
|
| 411 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 412 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 413 |
+
elif self.name.startswith("fastai_xresnet1d50_deeper"):
|
| 414 |
+
model = xresnet1d50_deeper(num_classes=num_classes, input_channels=self.input_channels,
|
| 415 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 416 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 417 |
+
elif self.name.startswith("fastai_xresnet1d18_deep"):
|
| 418 |
+
model = xresnet1d18_deep(num_classes=num_classes, input_channels=self.input_channels,
|
| 419 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 420 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 421 |
+
elif self.name.startswith("fastai_xresnet1d34_deep"):
|
| 422 |
+
model = xresnet1d34_deep(num_classes=num_classes, input_channels=self.input_channels,
|
| 423 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 424 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 425 |
+
elif self.name.startswith("fastai_xresnet1d50_deep"):
|
| 426 |
+
model = xresnet1d50_deep(num_classes=num_classes, input_channels=self.input_channels,
|
| 427 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 428 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 429 |
+
elif self.name.startswith("fastai_xresnet1d18"):
|
| 430 |
+
model = xresnet1d18(num_classes=num_classes, input_channels=self.input_channels,
|
| 431 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 432 |
+
elif self.name.startswith("fastai_xresnet1d34"):
|
| 433 |
+
model = xresnet1d34(num_classes=num_classes, input_channels=self.input_channels,
|
| 434 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 435 |
+
elif self.name.startswith("fastai_xresnet1d50"):
|
| 436 |
+
model = xresnet1d50(num_classes=num_classes, input_channels=self.input_channels,
|
| 437 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 438 |
+
elif self.name.startswith("fastai_xresnet1d101"):
|
| 439 |
+
model = xresnet1d101(num_classes=num_classes, input_channels=self.input_channels,
|
| 440 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 441 |
+
elif self.name.startswith("fastai_xresnet1d152"):
|
| 442 |
+
model = xresnet1d152(num_classes=num_classes, input_channels=self.input_channels,
|
| 443 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 444 |
+
|
| 445 |
+
# inception passing the default kernel size of 5 leads to a max kernel size of 40-1 in the inception model as
|
| 446 |
+
# proposed in the original paper
|
| 447 |
+
elif self.name == "fastai_inception1d_no_residual": # note: order important for string capture
|
| 448 |
+
model = inception1d(num_classes=num_classes, input_channels=self.input_channels,
|
| 449 |
+
use_residual=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
|
| 450 |
+
kernel_size=8 * self.kernel_size)
|
| 451 |
+
elif self.name.startswith("fastai_inception1d"):
|
| 452 |
+
model = inception1d(num_classes=num_classes, input_channels=self.input_channels,
|
| 453 |
+
use_residual=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
|
| 454 |
+
kernel_size=8 * self.kernel_size)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# BasicConv1d fcn,fcn_wang,schirrmeister,sen,basic1d
|
| 458 |
+
elif self.name.startswith("fastai_fcn_wang"): # note: order important for string capture
|
| 459 |
+
model = fcn_wang(num_classes=num_classes, input_channels=self.input_channels,
|
| 460 |
+
ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 461 |
+
elif self.name.startswith("fastai_fcn"):
|
| 462 |
+
model = fcn(num_classes=num_classes, input_channels=self.input_channels)
|
| 463 |
+
elif self.name.startswith("fastai_schirrmeister"):
|
| 464 |
+
model = schirrmeister(num_classes=num_classes, input_channels=self.input_channels,
|
| 465 |
+
ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 466 |
+
elif self.name.startswith("fastai_sen"):
|
| 467 |
+
model = sen(num_classes=num_classes, input_channels=self.input_channels, ps_head=self.ps_head,
|
| 468 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 469 |
+
elif self.name.startswith("fastai_basic1d"):
|
| 470 |
+
model = basic1d(num_classes=num_classes, input_channels=self.input_channels,
|
| 471 |
+
kernel_size=self.kernel_size, ps_head=self.ps_head,
|
| 472 |
+
lin_ftrs_head=self.lin_ftrs_head)
|
| 473 |
+
# RNN
|
| 474 |
+
elif self.name.startswith("fastai_lstm_bidir"):
|
| 475 |
+
model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
|
| 476 |
+
bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 477 |
+
elif self.name.startswith("fastai_gru_bidir"):
|
| 478 |
+
model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
|
| 479 |
+
bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 480 |
+
elif self.name.startswith("fastai_lstm"):
|
| 481 |
+
model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
|
| 482 |
+
bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 483 |
+
elif self.name.startswith("fastai_gru"):
|
| 484 |
+
model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
|
| 485 |
+
bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
|
| 486 |
+
else:
|
| 487 |
+
print("Model not found.")
|
| 488 |
+
assert True
|
| 489 |
+
|
| 490 |
+
learn = Learner(db, model, loss_func=loss, metrics=metrics, wd=self.wd, path=self.output_folder)
|
| 491 |
+
|
| 492 |
+
if self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru"):
|
| 493 |
+
learn.callback_fns.append(partial(GradientClipping, clip=0.25))
|
| 494 |
+
|
| 495 |
+
if self.early_stopping is not None:
|
| 496 |
+
# supported options: valid_loss, macro_auc, fmax
|
| 497 |
+
if self.early_stopping == "macro_auc" and self.loss != "mse" and self.loss != "nll_regression":
|
| 498 |
+
metric = MetricFunc(auc_metric, self.early_stopping,
|
| 499 |
+
one_hot_encode_target=False, argmax_pred=False, softmax_pred=False,
|
| 500 |
+
sigmoid_pred=True, flatten_target=False)
|
| 501 |
+
learn.metrics.append(metric)
|
| 502 |
+
learn.callback_fns.append(
|
| 503 |
+
partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
|
| 504 |
+
elif self.early_stopping == "fmax" and self.loss != "mse" and self.loss != "nll_regression":
|
| 505 |
+
metric = MetricFunc(fmax_metric, self.early_stopping,
|
| 506 |
+
one_hot_encode_target=False, argmax_pred=False, softmax_pred=False,
|
| 507 |
+
sigmoid_pred=True, flatten_target=False)
|
| 508 |
+
learn.metrics.append(metric)
|
| 509 |
+
learn.callback_fns.append(partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
|
| 510 |
+
elif self.early_stopping == "valid_loss":
|
| 511 |
+
learn.callback_fns.append(partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
|
| 512 |
+
|
| 513 |
+
return learn
|
models/inception1d.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from fastai.data.core import *
|
| 4 |
+
|
| 5 |
+
# Inception time inspired by https://github.com/hfawaz/InceptionTime/blob/master/classifiers/inception.py and https://github.com/tcapelle/TimeSeries_fastai/blob/master/inception.py
|
| 6 |
+
from models.basicconv1d import create_head1d
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def conv(in_planes, out_planes, kernel_size=3, stride=1):
|
| 10 |
+
"convolution with padding"
|
| 11 |
+
return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
| 12 |
+
padding=(kernel_size - 1) // 2, bias=False)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def noop(x): return x
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# class InceptionBlock1d(nn.Module):
|
| 19 |
+
# def __init__(self, ni, nb_filters, kss, stride=1, act='linear', bottleneck_size=32):
|
| 20 |
+
# super().__init__()
|
| 21 |
+
# self.bottleneck = conv(ni, bottleneck_size, 1, stride) if (bottleneck_size > 0) else noop
|
| 22 |
+
|
| 23 |
+
# self.convs = nn.ModuleList(
|
| 24 |
+
# [conv(bottleneck_size if (bottleneck_size > 0) else ni, nb_filters, ks) for ks in kss])
|
| 25 |
+
# self.conv_bottle = nn.Sequential(nn.MaxPool1d(3, stride, padding=1), conv(ni, nb_filters, 1))
|
| 26 |
+
# self.bn_relu = nn.Sequential(nn.BatchNorm1d((len(kss) + 1) * nb_filters), nn.ReLU())
|
| 27 |
+
|
| 28 |
+
# def forward(self, x):
|
| 29 |
+
# # print("block in",x.size())
|
| 30 |
+
# bottled = self.bottleneck(x)
|
| 31 |
+
# out = self.bn_relu(torch.cat([c(bottled) for c in self.convs] + [self.conv_bottle(x)], dim=1))
|
| 32 |
+
# return out
|
| 33 |
+
class InceptionBlock1d(nn.Module):
|
| 34 |
+
def __init__(self, ni, nb_filters, kss, stride=1, act='linear', bottleneck_size=32):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.bottleneck = conv(ni, bottleneck_size, 1, stride) if (bottleneck_size > 0) else noop
|
| 37 |
+
|
| 38 |
+
self.convs = nn.ModuleList(
|
| 39 |
+
[conv(bottleneck_size if (bottleneck_size > 0) else ni, nb_filters, ks) for ks in kss])
|
| 40 |
+
self.conv_bottle = nn.Sequential(nn.MaxPool1d(3, stride, padding=1), conv(ni, nb_filters, 1))
|
| 41 |
+
self.bn_relu = nn.Sequential(nn.BatchNorm1d((len(kss) + 1) * nb_filters), nn.ReLU())
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
bottled = self.bottleneck(x)
|
| 45 |
+
conv_outputs = [c(bottled) for c in self.convs]
|
| 46 |
+
bottle_output = self.conv_bottle(x)
|
| 47 |
+
out = self.bn_relu(torch.cat(conv_outputs + [bottle_output], dim=1))
|
| 48 |
+
return out
|
| 49 |
+
|
| 50 |
+
class Shortcut1d(nn.Module):
|
| 51 |
+
def __init__(self, ni, nf):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.act_fn = nn.ReLU(True)
|
| 54 |
+
self.conv = conv(ni, nf, 1)
|
| 55 |
+
self.bn = nn.BatchNorm1d(nf)
|
| 56 |
+
|
| 57 |
+
def forward(self, inp, out):
|
| 58 |
+
# print("sk",out.size(), inp.size(), self.conv(inp).size(), self.bn(self.conv(inp)).size)
|
| 59 |
+
# input()
|
| 60 |
+
return self.act_fn(out + self.bn(self.conv(inp)))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class InceptionBackbone(nn.Module):
|
| 64 |
+
def __init__(self, input_channels, kss, depth, bottleneck_size, nb_filters, use_residual):
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
self.depth = depth
|
| 68 |
+
assert ((depth % 3) == 0)
|
| 69 |
+
self.use_residual = use_residual
|
| 70 |
+
|
| 71 |
+
n_ks = len(kss) + 1
|
| 72 |
+
self.im = nn.ModuleList([InceptionBlock1d(input_channels if d == 0 else n_ks * nb_filters,
|
| 73 |
+
nb_filters=nb_filters, kss=kss,
|
| 74 |
+
bottleneck_size=bottleneck_size) for d in range(depth)])
|
| 75 |
+
self.sk = nn.ModuleList(
|
| 76 |
+
[Shortcut1d(input_channels if d == 0 else n_ks * nb_filters, n_ks * nb_filters) for d in
|
| 77 |
+
range(depth // 3)])
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
|
| 81 |
+
input_res = x
|
| 82 |
+
for d in range(self.depth):
|
| 83 |
+
x = self.im[d](x)
|
| 84 |
+
if self.use_residual and d % 3 == 2:
|
| 85 |
+
x = (self.sk[d // 3])(input_res, x)
|
| 86 |
+
input_res = x.clone()
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Inception1d(nn.Module):
|
| 91 |
+
"""inception time architecture"""
|
| 92 |
+
|
| 93 |
+
def __init__(self, num_classes=2, input_channels=8, kernel_size=40, depth=6, bottleneck_size=32, nb_filters=32,
|
| 94 |
+
use_residual=True, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu",
|
| 95 |
+
concat_pooling=True):
|
| 96 |
+
super().__init__()
|
| 97 |
+
assert (kernel_size >= 40)
|
| 98 |
+
kernel_size = [k - 1 if k % 2 == 0 else k for k in
|
| 99 |
+
[kernel_size, kernel_size // 2, kernel_size // 4]] # was 39,19,9
|
| 100 |
+
|
| 101 |
+
layers = [InceptionBackbone(input_channels=input_channels, kss=kernel_size, depth=depth,
|
| 102 |
+
bottleneck_size=bottleneck_size, nb_filters=nb_filters,
|
| 103 |
+
use_residual=use_residual)]
|
| 104 |
+
|
| 105 |
+
n_ks = len(kernel_size) + 1
|
| 106 |
+
# head
|
| 107 |
+
head = create_head1d(n_ks * nb_filters, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
|
| 108 |
+
bn_final=bn_final_head, bn=bn_head, act=act_head,
|
| 109 |
+
concat_pooling=concat_pooling)
|
| 110 |
+
layers.append(head)
|
| 111 |
+
# layers.append(AdaptiveConcatPool1d())
|
| 112 |
+
# layers.append(Flatten())
|
| 113 |
+
# layers.append(nn.Linear(2*n_ks*nb_filters, num_classes))
|
| 114 |
+
self.layers = nn.Sequential(*layers)
|
| 115 |
+
|
| 116 |
+
def forward(self, x):
|
| 117 |
+
return self.layers(x)
|
| 118 |
+
|
| 119 |
+
def get_layer_groups(self):
|
| 120 |
+
depth = self.layers[0].depth
|
| 121 |
+
if depth > 3:
|
| 122 |
+
return (self.layers[0].im[3:], self.layers[0].sk[1:]), self.layers[-1]
|
| 123 |
+
else:
|
| 124 |
+
return self.layers[-1]
|
| 125 |
+
|
| 126 |
+
def get_output_layer(self):
|
| 127 |
+
return self.layers[-1][-1]
|
| 128 |
+
|
| 129 |
+
def set_output_layer(self, x):
|
| 130 |
+
self.layers[-1][-1] = x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def inception1d(**kwargs):
|
| 134 |
+
"""
|
| 135 |
+
Constructs an Inception model
|
| 136 |
+
"""
|
| 137 |
+
return Inception1d(**kwargs)
|
models/resnet1d.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
# Standard resnet
|
| 5 |
+
from models.basicconv1d import create_head1d
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def conv(in_planes, out_planes, stride=1, kernel_size=3):
|
| 9 |
+
"""convolution with padding"""
|
| 10 |
+
return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
| 11 |
+
padding=(kernel_size - 1) // 2, bias=False)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BasicBlock1d(nn.Module):
|
| 15 |
+
expansion = 1
|
| 16 |
+
|
| 17 |
+
def __init__(self, inplanes, planes, stride=1, kernel_size=None, down_sample=None):
|
| 18 |
+
if kernel_size is None:
|
| 19 |
+
kernel_size = [3, 3]
|
| 20 |
+
super().__init__()
|
| 21 |
+
|
| 22 |
+
if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size // 2 + 1]
|
| 23 |
+
|
| 24 |
+
self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=kernel_size[0])
|
| 25 |
+
self.bn1 = nn.BatchNorm1d(planes)
|
| 26 |
+
self.relu = nn.ReLU(inplace=True)
|
| 27 |
+
self.conv2 = conv(planes, planes, kernel_size=kernel_size[1])
|
| 28 |
+
self.bn2 = nn.BatchNorm1d(planes)
|
| 29 |
+
self.down_sample = down_sample
|
| 30 |
+
self.stride = stride
|
| 31 |
+
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
residual = x
|
| 34 |
+
|
| 35 |
+
out = self.conv1(x)
|
| 36 |
+
out = self.bn1(out)
|
| 37 |
+
out = self.relu(out)
|
| 38 |
+
|
| 39 |
+
out = self.conv2(out)
|
| 40 |
+
out = self.bn2(out)
|
| 41 |
+
|
| 42 |
+
if self.down_sample is not None:
|
| 43 |
+
residual = self.down_sample(x)
|
| 44 |
+
|
| 45 |
+
out += residual
|
| 46 |
+
out = self.relu(out)
|
| 47 |
+
|
| 48 |
+
return out
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Bottleneck1d(nn.Module):
|
| 52 |
+
expansion = 4
|
| 53 |
+
|
| 54 |
+
def __init__(self, inplanes, planes, stride=1, kernel_size=3, down_sample=None):
|
| 55 |
+
super().__init__()
|
| 56 |
+
|
| 57 |
+
self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
|
| 58 |
+
self.bn1 = nn.BatchNorm1d(planes)
|
| 59 |
+
self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride,
|
| 60 |
+
padding=(kernel_size - 1) // 2, bias=False)
|
| 61 |
+
self.bn2 = nn.BatchNorm1d(planes)
|
| 62 |
+
self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
|
| 63 |
+
self.bn3 = nn.BatchNorm1d(planes * 4)
|
| 64 |
+
self.relu = nn.ReLU(inplace=True)
|
| 65 |
+
self.down_sample = down_sample
|
| 66 |
+
self.stride = stride
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
residual = x
|
| 70 |
+
|
| 71 |
+
out = self.conv1(x)
|
| 72 |
+
out = self.bn1(out)
|
| 73 |
+
out = self.relu(out)
|
| 74 |
+
|
| 75 |
+
out = self.conv2(out)
|
| 76 |
+
out = self.bn2(out)
|
| 77 |
+
out = self.relu(out)
|
| 78 |
+
|
| 79 |
+
out = self.conv3(out)
|
| 80 |
+
out = self.bn3(out)
|
| 81 |
+
|
| 82 |
+
if self.down_sample is not None:
|
| 83 |
+
residual = self.down_sample(x)
|
| 84 |
+
|
| 85 |
+
out += residual
|
| 86 |
+
out = self.relu(out)
|
| 87 |
+
|
| 88 |
+
return out
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ResNet1d(nn.Sequential):
|
| 92 |
+
"""1d adaptation of the torchvision resnet"""
|
| 93 |
+
|
| 94 |
+
def __init__(self, block, layers, kernel_size=3, num_classes=2, input_channels=3, inplanes=64, fix_feature_dim=True,
|
| 95 |
+
kernel_size_stem=None, stride_stem=2, pooling_stem=True, stride=2, lin_ftrs_head=None, ps_head=0.5,
|
| 96 |
+
bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
|
| 97 |
+
self.inplanes = inplanes
|
| 98 |
+
|
| 99 |
+
layers_tmp = []
|
| 100 |
+
|
| 101 |
+
if kernel_size_stem is None:
|
| 102 |
+
kernel_size_stem = kernel_size[0] if isinstance(kernel_size, list) else kernel_size
|
| 103 |
+
# stem
|
| 104 |
+
layers_tmp.append(nn.Conv1d(input_channels, inplanes, kernel_size=kernel_size_stem, stride=stride_stem,
|
| 105 |
+
padding=(kernel_size_stem - 1) // 2, bias=False))
|
| 106 |
+
layers_tmp.append(nn.BatchNorm1d(inplanes))
|
| 107 |
+
layers_tmp.append(nn.ReLU(inplace=True))
|
| 108 |
+
if pooling_stem is True:
|
| 109 |
+
layers_tmp.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))
|
| 110 |
+
# backbone
|
| 111 |
+
for i, l in enumerate(layers):
|
| 112 |
+
if i == 0:
|
| 113 |
+
layers_tmp.append(self._make_layer(block, inplanes, layers[0], kernel_size=kernel_size))
|
| 114 |
+
else:
|
| 115 |
+
layers_tmp.append(
|
| 116 |
+
self._make_layer(block, inplanes if fix_feature_dim else (2 ** i) * inplanes, layers[i],
|
| 117 |
+
stride=stride, kernel_size=kernel_size))
|
| 118 |
+
|
| 119 |
+
# head
|
| 120 |
+
# layers_tmp.append(nn.AdaptiveAvgPool1d(1))
|
| 121 |
+
# layers_tmp.append(Flatten())
|
| 122 |
+
# layers_tmp.append(nn.Linear((inplanes if fix_feature_dim else (2**len(layers)*inplanes)) * block.expansion, num_classes))
|
| 123 |
+
|
| 124 |
+
head = create_head1d(
|
| 125 |
+
(inplanes if fix_feature_dim else (2 ** len(layers) * inplanes)) * block.expansion, nc=num_classes,
|
| 126 |
+
lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head,
|
| 127 |
+
concat_pooling=concat_pooling)
|
| 128 |
+
layers_tmp.append(head)
|
| 129 |
+
|
| 130 |
+
super().__init__()
|
| 131 |
+
|
| 132 |
+
def _make_layer(self, block, planes, blocks, stride=1, kernel_size=3):
|
| 133 |
+
down_sample = None
|
| 134 |
+
|
| 135 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 136 |
+
down_sample = nn.Sequential(
|
| 137 |
+
nn.Conv1d(self.inplanes, planes * block.expansion,
|
| 138 |
+
kernel_size=1, stride=stride, bias=False),
|
| 139 |
+
nn.BatchNorm1d(planes * block.expansion),
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
layers = [block(self.inplanes, planes, stride, kernel_size, down_sample)]
|
| 143 |
+
self.inplanes = planes * block.expansion
|
| 144 |
+
for i in range(1, blocks):
|
| 145 |
+
layers.append(block(self.inplanes, planes))
|
| 146 |
+
|
| 147 |
+
return nn.Sequential(*layers)
|
| 148 |
+
|
| 149 |
+
def get_layer_groups(self):
|
| 150 |
+
return self[6], self[-1]
|
| 151 |
+
|
| 152 |
+
def get_output_layer(self):
|
| 153 |
+
return self[-1][-1]
|
| 154 |
+
|
| 155 |
+
def set_output_layer(self, x):
|
| 156 |
+
self[-1][-1] = x
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def resnet1d18(**kwargs):
|
| 160 |
+
"""
|
| 161 |
+
Constructs a ResNet-18 model.
|
| 162 |
+
"""
|
| 163 |
+
return ResNet1d(BasicBlock1d, [2, 2, 2, 2], **kwargs)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def resnet1d34(**kwargs):
|
| 167 |
+
"""
|
| 168 |
+
Constructs a ResNet-34 model.
|
| 169 |
+
"""
|
| 170 |
+
return ResNet1d(BasicBlock1d, [3, 4, 6, 3], **kwargs)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def resnet1d50(**kwargs):
|
| 174 |
+
"""
|
| 175 |
+
Constructs a ResNet-50 model.
|
| 176 |
+
"""
|
| 177 |
+
return ResNet1d(Bottleneck1d, [3, 4, 6, 3], **kwargs)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def resnet1d101(**kwargs):
|
| 181 |
+
"""
|
| 182 |
+
Constructs a ResNet-101 model.
|
| 183 |
+
"""
|
| 184 |
+
return ResNet1d(Bottleneck1d, [3, 4, 23, 3], **kwargs)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def resnet1d152(**kwargs):
|
| 188 |
+
"""
|
| 189 |
+
Constructs a ResNet-152 model.
|
| 190 |
+
"""
|
| 191 |
+
return ResNet1d(Bottleneck1d, [3, 8, 36, 3], **kwargs)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# original used kernel_size_stem = 8
|
| 195 |
+
def resnet1d_wang(**kwargs):
|
| 196 |
+
if not ("kernel_size" in kwargs.keys()):
|
| 197 |
+
kwargs["kernel_size"] = [5, 3]
|
| 198 |
+
if not ("kernel_size_stem" in kwargs.keys()):
|
| 199 |
+
kwargs["kernel_size_stem"] = 7
|
| 200 |
+
if not ("stride_stem" in kwargs.keys()):
|
| 201 |
+
kwargs["stride_stem"] = 1
|
| 202 |
+
if not ("pooling_stem" in kwargs.keys()):
|
| 203 |
+
kwargs["pooling_stem"] = False
|
| 204 |
+
if not ("inplanes" in kwargs.keys()):
|
| 205 |
+
kwargs["inplanes"] = 128
|
| 206 |
+
|
| 207 |
+
return ResNet1d(BasicBlock1d, [1, 1, 1], **kwargs)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def resnet1d(**kwargs):
|
| 211 |
+
"""
|
| 212 |
+
Constructs a custom ResNet model.
|
| 213 |
+
"""
|
| 214 |
+
return ResNet1d(BasicBlock1d, **kwargs)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# wide resnet adopted from fastai wrn
|
| 218 |
+
|
| 219 |
+
def noop(x): return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def conv1d(ni: int, nf: int, ks: int = 3, stride: int = 1, padding: int = None, bias=False) -> nn.Conv1d:
|
| 223 |
+
"Create `nn.Conv1d` layer: `ni` inputs, `nf` outputs, `ks` kernel size. `padding` defaults to `k//2`."
|
| 224 |
+
if padding is None: padding = ks // 2
|
| 225 |
+
return nn.Conv1d(ni, nf, kernel_size=ks, stride=stride, padding=padding, bias=bias)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _bn1d(ni, init_zero=False):
|
| 229 |
+
"Batchnorm layer with 0 initialization"
|
| 230 |
+
m = nn.BatchNorm1d(ni)
|
| 231 |
+
m.weight.data.fill_(0 if init_zero else 1)
|
| 232 |
+
m.bias.data.zero_()
|
| 233 |
+
return m
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def bn_relu_conv1d(ni, nf, ks, stride, init_zero=False):
|
| 237 |
+
bn_initzero = _bn1d(ni, init_zero=init_zero)
|
| 238 |
+
return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv1d(ni, nf, ks, stride))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class BasicBlock1dwrn(nn.Module):
|
| 242 |
+
def __init__(self, ni, nf, stride, drop_p=0.0, ks=3):
|
| 243 |
+
super().__init__()
|
| 244 |
+
if isinstance(ks, int):
|
| 245 |
+
ks = [ks, ks // 2 + 1]
|
| 246 |
+
self.bn = nn.BatchNorm1d(ni)
|
| 247 |
+
self.conv1 = conv1d(ni, nf, ks[0], stride)
|
| 248 |
+
self.conv2 = bn_relu_conv1d(nf, nf, ks[0], 1)
|
| 249 |
+
self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
|
| 250 |
+
self.shortcut = conv1d(ni, nf, ks[1], stride) if (
|
| 251 |
+
ni != nf or stride > 1) else noop # adapted to make it work for fix_feature_dim=True
|
| 252 |
+
|
| 253 |
+
def forward(self, x):
|
| 254 |
+
x2 = F.relu(self.bn(x), inplace=True)
|
| 255 |
+
r = self.shortcut(x2)
|
| 256 |
+
x = self.conv1(x2)
|
| 257 |
+
if self.drop: x = self.drop(x)
|
| 258 |
+
x = self.conv2(x) * 0.2
|
| 259 |
+
return x.add_(r)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _make_group(N, ni, nf, block, stride, drop_p, ks=3):
|
| 263 |
+
return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p, ks=ks) for i in range(N)]
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class WideResNet1d(nn.Sequential):
|
| 267 |
+
def __init__(self, input_channels: int, num_groups: int, N: int, num_classes: int, k: int = 1, drop_p: float = 0.0,
|
| 268 |
+
start_nf: int = 16, fix_feature_dim=True, kernel_size=5, lin_ftrs_head=None, ps_head=0.5,
|
| 269 |
+
bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
|
| 270 |
+
super().__init__()
|
| 271 |
+
n_channels = [start_nf]
|
| 272 |
+
|
| 273 |
+
for i in range(num_groups): n_channels.append(start_nf if fix_feature_dim else start_nf * (2 ** i) * k)
|
| 274 |
+
|
| 275 |
+
layers = [conv1d(input_channels, n_channels[0], 3, 1)] # conv1 stem
|
| 276 |
+
for i in range(num_groups):
|
| 277 |
+
layers += _make_group(N, n_channels[i], n_channels[i + 1], BasicBlock1dwrn,
|
| 278 |
+
(1 if i == 0 else 2), drop_p, ks=kernel_size)
|
| 279 |
+
|
| 280 |
+
# layers += [nn.BatchNorm1d(n_channels[-1]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool1d(1),
|
| 281 |
+
# Flatten(), nn.Linear(n_channels[-1], num_classes)]
|
| 282 |
+
head = create_head1d(n_channels[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
|
| 283 |
+
bn_final=bn_final_head, bn=bn_head, act=act_head,
|
| 284 |
+
concat_pooling=concat_pooling)
|
| 285 |
+
layers.append(head)
|
| 286 |
+
|
| 287 |
+
super().__init__()
|
| 288 |
+
|
| 289 |
+
def get_layer_groups(self):
|
| 290 |
+
return self[6], self[-1]
|
| 291 |
+
|
| 292 |
+
def get_output_layer(self):
|
| 293 |
+
return self[-1][-1]
|
| 294 |
+
|
| 295 |
+
def set_output_layer(self, x):
|
| 296 |
+
self[-1][-1] = x
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def wrn1d_22(**kwargs): return WideResNet1d(num_groups=3, N=3, k=6, drop_p=0., **kwargs)
|
models/rnn1d.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from fastai.layers import *
|
| 4 |
+
from fastai.data.core import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class AdaptiveConcatPoolRNN(nn.Module):
|
| 8 |
+
def __init__(self, bidirectional):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.bidirectional = bidirectional
|
| 11 |
+
|
| 12 |
+
def forward(self, x):
|
| 13 |
+
# input shape bs, ch, ts
|
| 14 |
+
t1 = nn.AdaptiveAvgPool1d(1)(x)
|
| 15 |
+
t2 = nn.AdaptiveMaxPool1d(1)(x)
|
| 16 |
+
|
| 17 |
+
if self.bidirectional is False:
|
| 18 |
+
t3 = x[:, :, -1]
|
| 19 |
+
else:
|
| 20 |
+
channels = x.size()[1]
|
| 21 |
+
t3 = torch.cat([x[:, :channels, -1], x[:, channels:, 0]], 1)
|
| 22 |
+
out = torch.cat([t1.squeeze(-1), t2.squeeze(-1), t3], 1) # output shape bs, 3*ch
|
| 23 |
+
return out
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RNN1d(nn.Sequential):
|
| 27 |
+
def __init__(self, input_channels, num_classes, lstm=True, hidden_dim=256, num_layers=2, bidirectional=False,
|
| 28 |
+
ps_head=0.5, act_head="relu", lin_ftrs_head=None, bn=True):
|
| 29 |
+
# bs, ch, ts -> ts, bs, ch
|
| 30 |
+
layers_tmp = [Lambda(lambda x: x.transpose(1, 2)), Lambda(lambda x: x.transpose(0, 1))]
|
| 31 |
+
# LSTM
|
| 32 |
+
if lstm:
|
| 33 |
+
layers_tmp.append(nn.LSTM(input_size=input_channels, hidden_size=hidden_dim, num_layers=num_layers,
|
| 34 |
+
bidirectional=bidirectional))
|
| 35 |
+
else:
|
| 36 |
+
layers_tmp.append(nn.GRU(input_size=input_channels, hidden_size=hidden_dim, num_layers=num_layers,
|
| 37 |
+
bidirectional=bidirectional))
|
| 38 |
+
# pooling
|
| 39 |
+
layers_tmp.append(Lambda(lambda x: x[0].transpose(0, 1)))
|
| 40 |
+
layers_tmp.append(Lambda(lambda x: x.transpose(1, 2)))
|
| 41 |
+
|
| 42 |
+
layers_head = [AdaptiveConcatPoolRNN(bidirectional)]
|
| 43 |
+
|
| 44 |
+
# classifier
|
| 45 |
+
nf = 3 * hidden_dim if bidirectional is False else 6 * hidden_dim
|
| 46 |
+
lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
|
| 47 |
+
ps_head = listify(ps_head)
|
| 48 |
+
if len(ps_head) == 1:
|
| 49 |
+
ps_head = [ps_head[0] / 2] * (len(lin_ftrs_head) - 2) + ps_head
|
| 50 |
+
actns = [nn.ReLU(inplace=True) if act_head == "relu" else nn.ELU(inplace=True)] * (
|
| 51 |
+
len(lin_ftrs_head) - 2) + [None]
|
| 52 |
+
|
| 53 |
+
for ni, no, p, actn in zip(lin_ftrs_head[:-1], lin_ftrs_head[1:], ps_head, actns):
|
| 54 |
+
layers_head += bn_drop_lin(ni, no, bn, p, actn)
|
| 55 |
+
layers_head = nn.Sequential(*layers_head)
|
| 56 |
+
layers_tmp.append(layers_head)
|
| 57 |
+
|
| 58 |
+
super().__init__()
|
| 59 |
+
|
| 60 |
+
def get_layer_groups(self):
|
| 61 |
+
return self[-1],
|
| 62 |
+
|
| 63 |
+
def get_output_layer(self):
|
| 64 |
+
return self[-1][-1]
|
| 65 |
+
|
| 66 |
+
def set_output_layer(self, x):
|
| 67 |
+
self[-1][-1] = x
|
models/wavelet.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.linear_model import LogisticRegression
|
| 2 |
+
from sklearn.multiclass import OneVsRestClassifier
|
| 3 |
+
from models.base_model import ClassificationModel
|
| 4 |
+
import pickle
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import numpy as np
|
| 7 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 8 |
+
import pywt
|
| 9 |
+
import scipy.stats
|
| 10 |
+
import multiprocessing
|
| 11 |
+
from collections import Counter
|
| 12 |
+
from keras.layers import Dropout, Dense, Input
|
| 13 |
+
from keras.models import Model
|
| 14 |
+
from keras.models import load_model
|
| 15 |
+
from keras.callbacks import ModelCheckpoint
|
| 16 |
+
from sklearn.preprocessing import StandardScaler
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def calculate_entropy(list_values):
|
| 20 |
+
counter_values = Counter(list_values).most_common()
|
| 21 |
+
probabilities = [elem[1] / len(list_values) for elem in counter_values]
|
| 22 |
+
entropy = scipy.stats.entropy(probabilities)
|
| 23 |
+
return entropy
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def calculate_statistics(list_values):
|
| 27 |
+
n5 = np.nanpercentile(list_values, 5)
|
| 28 |
+
n25 = np.nanpercentile(list_values, 25)
|
| 29 |
+
n75 = np.nanpercentile(list_values, 75)
|
| 30 |
+
n95 = np.nanpercentile(list_values, 95)
|
| 31 |
+
median = np.nanpercentile(list_values, 50)
|
| 32 |
+
mean = np.nanmean(list_values)
|
| 33 |
+
std = np.nanstd(list_values)
|
| 34 |
+
var = np.nanvar(list_values)
|
| 35 |
+
rms = np.nanmean(np.sqrt(list_values ** 2))
|
| 36 |
+
return [n5, n25, n75, n95, median, mean, std, var, rms]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def calculate_crossings(list_values):
|
| 40 |
+
zero_crossing_indices = np.nonzero(np.diff(np.array(list_values) > 0))[0]
|
| 41 |
+
no_zero_crossings = len(zero_crossing_indices)
|
| 42 |
+
mean_crossing_indices = np.nonzero(np.diff(np.array(list_values) > np.nanmean(list_values)))[0]
|
| 43 |
+
no_mean_crossings = len(mean_crossing_indices)
|
| 44 |
+
return [no_zero_crossings, no_mean_crossings]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_features(list_values):
|
| 48 |
+
entropy = calculate_entropy(list_values)
|
| 49 |
+
crossings = calculate_crossings(list_values)
|
| 50 |
+
statistics = calculate_statistics(list_values)
|
| 51 |
+
return [entropy] + crossings + statistics
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_single_ecg_features(signal, waveletname='db6'):
|
| 55 |
+
features = []
|
| 56 |
+
for channel in signal.T:
|
| 57 |
+
list_coeff = pywt.wavedec(channel, wavelet=waveletname, level=5)
|
| 58 |
+
channel_features = []
|
| 59 |
+
for coeff in list_coeff:
|
| 60 |
+
channel_features += get_features(coeff)
|
| 61 |
+
features.append(channel_features)
|
| 62 |
+
return np.array(features).flatten()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_ecg_features(ecg_data, parallel=True):
|
| 66 |
+
if parallel:
|
| 67 |
+
pool = multiprocessing.Pool(18)
|
| 68 |
+
return np.array(pool.map(get_single_ecg_features, ecg_data))
|
| 69 |
+
else:
|
| 70 |
+
list_features = []
|
| 71 |
+
for signal in tqdm(ecg_data):
|
| 72 |
+
features = get_single_ecg_features(signal)
|
| 73 |
+
list_features.append(features)
|
| 74 |
+
return np.array(list_features)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# for keras models
|
| 78 |
+
# def keras_macro_auroc(y_true, y_pred):
|
| 79 |
+
# return tf.py_func(macro_auroc, (y_true, y_pred), tf.double)
|
| 80 |
+
|
| 81 |
+
class WaveletModel(ClassificationModel):
|
| 82 |
+
def __init__(self, name, n_classes, freq, outputfolder, input_shape, regularizer_C=.001, classifier='RF'):
|
| 83 |
+
# Disclaimer: This model assumes equal shapes across all samples!
|
| 84 |
+
# standard parameters
|
| 85 |
+
super().__init__()
|
| 86 |
+
self.name = name
|
| 87 |
+
self.outputfolder = outputfolder
|
| 88 |
+
self.n_classes = n_classes
|
| 89 |
+
self.freq = freq
|
| 90 |
+
self.regularizer_C = regularizer_C
|
| 91 |
+
self.classifier = classifier
|
| 92 |
+
self.dropout = .25
|
| 93 |
+
self.activation = 'relu'
|
| 94 |
+
self.final_activation = 'sigmoid'
|
| 95 |
+
self.n_dense_dim = 128
|
| 96 |
+
self.epochs = 30
|
| 97 |
+
|
| 98 |
+
def fit(self, X_train, y_train, X_val, y_val):
|
| 99 |
+
XF_train = get_ecg_features(X_train)
|
| 100 |
+
XF_val = get_ecg_features(X_val)
|
| 101 |
+
|
| 102 |
+
if self.classifier == 'LR':
|
| 103 |
+
if self.n_classes > 1:
|
| 104 |
+
clf = OneVsRestClassifier(
|
| 105 |
+
LogisticRegression(C=self.regularizer_C, solver='lbfgs', max_iter=1000, n_jobs=-1))
|
| 106 |
+
else:
|
| 107 |
+
clf = LogisticRegression(C=self.regularizer_C, solver='lbfgs', max_iter=1000, n_jobs=-1)
|
| 108 |
+
clf.fit(XF_train, y_train)
|
| 109 |
+
pickle.dump(clf, open(self.outputfolder + 'clf.pkl', 'wb'))
|
| 110 |
+
elif self.classifier == 'RF':
|
| 111 |
+
clf = RandomForestClassifier(n_estimators=1000, n_jobs=16)
|
| 112 |
+
clf.fit(XF_train, y_train)
|
| 113 |
+
pickle.dump(clf, open(self.outputfolder + 'clf.pkl', 'wb'))
|
| 114 |
+
elif self.classifier == 'NN':
|
| 115 |
+
# standardize input data
|
| 116 |
+
ss = StandardScaler()
|
| 117 |
+
XFT_train = ss.fit_transform(XF_train)
|
| 118 |
+
XFT_val = ss.transform(XF_val)
|
| 119 |
+
pickle.dump(ss, open(self.outputfolder + 'ss.pkl', 'wb'))
|
| 120 |
+
# classification stage
|
| 121 |
+
input_x = Input(shape=(XFT_train.shape[1],))
|
| 122 |
+
x = Dense(self.n_dense_dim, activation=self.activation)(input_x)
|
| 123 |
+
x = Dropout(self.dropout)(x)
|
| 124 |
+
y = Dense(self.n_classes, activation=self.final_activation)(x)
|
| 125 |
+
self.model = Model(input_x, y)
|
| 126 |
+
|
| 127 |
+
self.model.compile(optimizer='adamax', loss='binary_crossentropy') # , metrics=[keras_macro_auroc])
|
| 128 |
+
# monitor validation error
|
| 129 |
+
mc_loss = ModelCheckpoint(self.outputfolder + 'best_loss_model.h5', monitor='val_loss', mode='min',
|
| 130 |
+
verbose=1, save_best_only=True)
|
| 131 |
+
# mc_score = ModelCheckpoint(self.output_folder +'best_score_model.h5', monitor='val_keras_macro_auroc', mode='max', verbose=1, save_best_only=True)
|
| 132 |
+
self.model.fit(XFT_train, y_train, validation_data=(XFT_val, y_val), epochs=self.epochs, batch_size=128,
|
| 133 |
+
callbacks=[mc_loss]) # , mc_score])
|
| 134 |
+
self.model.save(self.outputfolder + 'last_model.h5')
|
| 135 |
+
|
| 136 |
+
def predict(self, X):
|
| 137 |
+
XF = get_ecg_features(X)
|
| 138 |
+
if self.classifier == 'LR':
|
| 139 |
+
clf = pickle.load(open(self.outputfolder + 'clf.pkl', 'rb'))
|
| 140 |
+
if self.n_classes > 1:
|
| 141 |
+
return clf.predict_proba(XF)
|
| 142 |
+
else:
|
| 143 |
+
return clf.predict_proba(XF)[:, 1][:, np.newaxis]
|
| 144 |
+
elif self.classifier == 'RF':
|
| 145 |
+
clf = pickle.load(open(self.outputfolder + 'clf.pkl', 'rb'))
|
| 146 |
+
y_pred = clf.predict_proba(XF)
|
| 147 |
+
if self.n_classes > 1:
|
| 148 |
+
return np.array([yi[:, 1] for yi in y_pred]).T
|
| 149 |
+
else:
|
| 150 |
+
return y_pred[:, 1][:, np.newaxis]
|
| 151 |
+
elif self.classifier == 'NN':
|
| 152 |
+
ss = pickle.load(open(self.outputfolder + 'ss.pkl', 'rb')) #
|
| 153 |
+
XFT = ss.transform(XF)
|
| 154 |
+
model = load_model(
|
| 155 |
+
self.outputfolder + 'best_loss_model.h5')
|
| 156 |
+
# 'best_score_model.h5', custom_objects={
|
| 157 |
+
# 'keras_macro_auroc': keras_macro_auroc})
|
| 158 |
+
return model.predict(XFT)
|
models/xresnet1d.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from enum import Enum
|
| 4 |
+
import re
|
| 5 |
+
# delegates
|
| 6 |
+
import inspect
|
| 7 |
+
|
| 8 |
+
from torch.nn.utils import weight_norm, spectral_norm
|
| 9 |
+
|
| 10 |
+
from models.basicconv1d import create_head1d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def delegates(to=None, keep=False):
|
| 14 |
+
"""Decorator: replace `**kwargs` in signature with params from `to`"""
|
| 15 |
+
|
| 16 |
+
def _f(f):
|
| 17 |
+
if to is None:
|
| 18 |
+
to_f, from_f = f.__base__.__init__, f.__init__
|
| 19 |
+
else:
|
| 20 |
+
to_f, from_f = to, f
|
| 21 |
+
sig = inspect.signature(from_f)
|
| 22 |
+
sigd = dict(sig.parameters)
|
| 23 |
+
k = sigd.pop('kwargs')
|
| 24 |
+
s2 = {k: v for k, v in inspect.signature(to_f).parameters.items()
|
| 25 |
+
if v.default != inspect.Parameter.empty and k not in sigd}
|
| 26 |
+
sigd.update(s2)
|
| 27 |
+
if keep: sigd['kwargs'] = k
|
| 28 |
+
from_f.__signature__ = sig.replace(parameters=sigd.values())
|
| 29 |
+
return f
|
| 30 |
+
|
| 31 |
+
return _f
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def store_attr(self, nms):
|
| 35 |
+
"""Store params named in comma-separated `nms` from calling context into attrs in `self`"""
|
| 36 |
+
mod = inspect.currentframe().f_back.f_locals
|
| 37 |
+
for n in re.split(', *', nms): setattr(self, n, mod[n])
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _conv_func(ndim=2, transpose=False):
|
| 44 |
+
"""Return the proper conv `ndim` function, potentially `transposed`."""
|
| 45 |
+
assert 1 <= ndim <= 3
|
| 46 |
+
return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def init_default(m, func=nn.init.kaiming_normal_):
|
| 50 |
+
"""Initialize `m` weights with `func` and set `bias` to 0."""
|
| 51 |
+
if func and hasattr(m, 'weight'): func(m.weight)
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)
|
| 54 |
+
return m
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
|
| 58 |
+
"""Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."""
|
| 59 |
+
assert 1 <= ndim <= 3
|
| 60 |
+
bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
|
| 61 |
+
if bn.affine:
|
| 62 |
+
bn.bias.data.fill_(1e-3)
|
| 63 |
+
bn.weight.data.fill_(0. if zero else 1.)
|
| 64 |
+
return bn
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
|
| 68 |
+
"""BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."""
|
| 69 |
+
return _get_norm('BatchNorm', nf, ndim, zero=norm_type == NormType.BatchZero, **kwargs)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ConvLayer(nn.Sequential):
|
| 73 |
+
"""Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers."""
|
| 74 |
+
|
| 75 |
+
def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
|
| 76 |
+
act_cls=nn.ReLU, transpose=False, init=nn.init.kaiming_normal_, xtra=None, **kwargs):
|
| 77 |
+
if padding is None: padding = ((ks - 1) // 2 if not transpose else 0)
|
| 78 |
+
bn = norm_type in (NormType.Batch, NormType.BatchZero)
|
| 79 |
+
inn = norm_type in (NormType.Instance, NormType.InstanceZero)
|
| 80 |
+
if bias is None: bias = not (bn or inn)
|
| 81 |
+
conv_func = _conv_func(ndim, transpose=transpose)
|
| 82 |
+
conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs),
|
| 83 |
+
init)
|
| 84 |
+
if norm_type == NormType.Weight:
|
| 85 |
+
conv = weight_norm(conv)
|
| 86 |
+
elif norm_type == NormType.Spectral:
|
| 87 |
+
conv = spectral_norm(conv)
|
| 88 |
+
layers = [conv]
|
| 89 |
+
act_bn = []
|
| 90 |
+
if act_cls is not None: act_bn.append(act_cls())
|
| 91 |
+
if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
|
| 92 |
+
if inn: act_bn.append(nn.InstanceNorm2d(nf, norm_type=norm_type, ndim=ndim))
|
| 93 |
+
if bn_1st: act_bn.reverse()
|
| 94 |
+
layers += act_bn
|
| 95 |
+
if xtra: layers.append(xtra)
|
| 96 |
+
super().__init__()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def AdaptiveAvgPool(sz=1, ndim=2):
|
| 100 |
+
"""nn.AdaptiveAvgPool layer for `ndim`"""
|
| 101 |
+
assert 1 <= ndim <= 3
|
| 102 |
+
return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
|
| 106 |
+
"""nn.MaxPool layer for `ndim`"""
|
| 107 |
+
assert 1 <= ndim <= 3
|
| 108 |
+
return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
|
| 112 |
+
"""nn.AvgPool layer for `ndim`"""
|
| 113 |
+
assert 1 <= ndim <= 3
|
| 114 |
+
return getattr(nn, f"AvgPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ResBlock(nn.Module):
|
| 118 |
+
"Resnet block from `ni` to `nh` with `stride`"
|
| 119 |
+
|
| 120 |
+
@delegates(ConvLayer.__init__)
|
| 121 |
+
def __init__(self, expansion, ni, nf, stride=1, kernel_size=3, groups=1, reduction=None, nh1=None, nh2=None,
|
| 122 |
+
dw=False, g2=1,
|
| 123 |
+
sa=False, sym=False, norm_type=NormType.Batch, act_cls=nn.ReLU, ndim=2,
|
| 124 |
+
pool=AvgPool, pool_first=True, **kwargs):
|
| 125 |
+
super().__init__()
|
| 126 |
+
norm2 = (NormType.BatchZero if norm_type == NormType.Batch else
|
| 127 |
+
NormType.InstanceZero if norm_type == NormType.Instance else norm_type)
|
| 128 |
+
if nh2 is None: nh2 = nf
|
| 129 |
+
if nh1 is None: nh1 = nh2
|
| 130 |
+
nf, ni = nf * expansion, ni * expansion
|
| 131 |
+
k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
|
| 132 |
+
k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
|
| 133 |
+
layers = [ConvLayer(ni, nh2, kernel_size, stride=stride, groups=ni if dw else groups, **k0),
|
| 134 |
+
ConvLayer(nh2, nf, kernel_size, groups=g2, **k1)
|
| 135 |
+
] if expansion == 1 else [
|
| 136 |
+
ConvLayer(ni, nh1, 1, **k0),
|
| 137 |
+
ConvLayer(nh1, nh2, kernel_size, stride=stride, groups=nh1 if dw else groups, **k0),
|
| 138 |
+
ConvLayer(nh2, nf, 1, groups=g2, **k1)]
|
| 139 |
+
self.convs = nn.Sequential(*layers)
|
| 140 |
+
convpath = [self.convs]
|
| 141 |
+
if reduction: convpath.append(nn.SEModule(nf, reduction=reduction, act_cls=act_cls))
|
| 142 |
+
if sa: convpath.append(nn.SimpleSelfAttention(nf, ks=1, sym=sym))
|
| 143 |
+
self.convpath = nn.Sequential(*convpath)
|
| 144 |
+
idpath = []
|
| 145 |
+
if ni != nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
|
| 146 |
+
if stride != 1: idpath.insert((1, 0)[pool_first], pool(2, ndim=ndim, ceil_mode=True))
|
| 147 |
+
self.idpath = nn.Sequential(*idpath)
|
| 148 |
+
self.act = nn.ReLU(inplace=True) if act_cls is nn.ReLU else act_cls()
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
return self.act(self.convpath(x) + self.idpath(x))
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
######################### adapted from vison.models.xresnet
|
| 155 |
+
def init_cnn(m):
|
| 156 |
+
if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
|
| 157 |
+
if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight)
|
| 158 |
+
for l in m.children(): init_cnn(l)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class XResNet1d(nn.Sequential):
|
| 162 |
+
@delegates(ResBlock)
|
| 163 |
+
def __init__(self, block, expansion, layers, p=0.0, input_channels=3, num_classes=1000, stem_szs=(32, 32, 64),
|
| 164 |
+
kernel_size=5, kernel_size_stem=5,
|
| 165 |
+
widen=1.0, sa=False, act_cls=nn.ReLU, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False,
|
| 166 |
+
bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
|
| 167 |
+
store_attr(self, 'block,expansion,act_cls')
|
| 168 |
+
stem_szs = [input_channels, *stem_szs]
|
| 169 |
+
stem = [ConvLayer(stem_szs[i], stem_szs[i + 1], ks=kernel_size_stem, stride=2 if i == 0 else 1, act_cls=act_cls,
|
| 170 |
+
ndim=1)
|
| 171 |
+
for i in range(3)]
|
| 172 |
+
|
| 173 |
+
# block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
|
| 174 |
+
block_szs = [int(o * widen) for o in [64, 64, 64, 64] + [32] * (len(layers) - 4)]
|
| 175 |
+
block_szs = [64 // expansion] + block_szs
|
| 176 |
+
blocks = [self._make_layer(ni=block_szs[i], nf=block_szs[i + 1], blocks=l,
|
| 177 |
+
stride=1 if i == 0 else 2, kernel_size=kernel_size, sa=sa and i == len(layers) - 4,
|
| 178 |
+
ndim=1, **kwargs)
|
| 179 |
+
for i, l in enumerate(layers)]
|
| 180 |
+
|
| 181 |
+
head = create_head1d(block_szs[-1] * expansion, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
|
| 182 |
+
bn_final=bn_final_head, bn=bn_head, act=act_head,
|
| 183 |
+
concat_pooling=concat_pooling)
|
| 184 |
+
|
| 185 |
+
super().__init__(nn.MaxPool1d(kernel_size=3, stride=2, padding=1), head)
|
| 186 |
+
init_cnn(self)
|
| 187 |
+
|
| 188 |
+
def _make_layer(self, ni, nf, blocks, stride, kernel_size, sa, **kwargs):
|
| 189 |
+
return nn.Sequential(
|
| 190 |
+
*[self.block(self.expansion, ni if i == 0 else nf, nf, stride=stride if i == 0 else 1,
|
| 191 |
+
kernel_size=kernel_size, sa=sa and i == (blocks - 1), act_cls=self.act_cls, **kwargs)
|
| 192 |
+
for i in range(blocks)])
|
| 193 |
+
|
| 194 |
+
def get_layer_groups(self):
|
| 195 |
+
return self[3], self[-1]
|
| 196 |
+
|
| 197 |
+
def get_output_layer(self):
|
| 198 |
+
return self[-1][-1]
|
| 199 |
+
|
| 200 |
+
def set_output_layer(self, x):
|
| 201 |
+
self[-1][-1] = x
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# xresnets
|
| 205 |
+
def _xresnet1d(expansion, layers, **kwargs):
|
| 206 |
+
return XResNet1d(ResBlock, expansion, layers, **kwargs)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def xresnet1d18(**kwargs): return _xresnet1d(1, [2, 2, 2, 2], **kwargs)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def xresnet1d34(**kwargs): return _xresnet1d(1, [3, 4, 6, 3], **kwargs)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def xresnet1d50(**kwargs): return _xresnet1d(4, [3, 4, 6, 3], **kwargs)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def xresnet1d101(**kwargs): return _xresnet1d(4, [3, 4, 23, 3], **kwargs)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def xresnet1d152(**kwargs): return _xresnet1d(4, [3, 8, 36, 3], **kwargs)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def xresnet1d18_deep(**kwargs): return _xresnet1d(1, [2, 2, 2, 2, 1, 1], **kwargs)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def xresnet1d34_deep(**kwargs): return _xresnet1d(1, [3, 4, 6, 3, 1, 1], **kwargs)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def xresnet1d50_deep(**kwargs): return _xresnet1d(4, [3, 4, 6, 3, 1, 1], **kwargs)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def xresnet1d18_deeper(**kwargs): return _xresnet1d(1, [2, 2, 1, 1, 1, 1, 1, 1], **kwargs)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def xresnet1d34_deeper(**kwargs): return _xresnet1d(1, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def xresnet1d50_deeper(**kwargs): return _xresnet1d(4, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs)
|
requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytorch
|
| 2 |
+
fastai
|
| 3 |
+
blis==0.7.5
|
| 4 |
+
cycler==0.11.0
|
| 5 |
+
cymem==2.0.6
|
| 6 |
+
fastcore==1.3.27
|
| 7 |
+
jinja2==3.0.3
|
| 8 |
+
kiwisolver==1.3.2
|
| 9 |
+
markupsafe==2.0.1
|
| 10 |
+
murmurhash==1.0.6
|
| 11 |
+
pathy==0.6.1
|
| 12 |
+
pillow==8.4.0
|
| 13 |
+
preshed==3.0.6
|
| 14 |
+
pyparsing==3.0.5
|
| 15 |
+
smart-open==5.2.1
|
| 16 |
+
srsly==2.4.2
|
| 17 |
+
thinc==8.0.13
|
| 18 |
+
torchvision==0.11.1
|
| 19 |
+
wfdb==3.4.1
|
| 20 |
+
wget==3.2
|
| 21 |
+
scikit-image
|
| 22 |
+
pyWavelets
|
| 23 |
+
kereas
|
utilities/__pycache__/timeseries_utils.cpython-39.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
utilities/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
utilities/stratify.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def stratify_df(df, new_col_name, n_folds=10, nr_clean_folds=0):
|
| 6 |
+
# compute qualities as described in PTB-XL report
|
| 7 |
+
qualities = []
|
| 8 |
+
for i, row in df.iterrows():
|
| 9 |
+
q = 0
|
| 10 |
+
if 'validated_by_human' in df.columns:
|
| 11 |
+
if row.validated_by_human:
|
| 12 |
+
q = 1
|
| 13 |
+
qualities.append(q)
|
| 14 |
+
df['quality'] = qualities
|
| 15 |
+
|
| 16 |
+
# create stratified folds according to patients
|
| 17 |
+
pat_ids = np.array(sorted(list(set(df.patient_id.values))))
|
| 18 |
+
p_labels = []
|
| 19 |
+
p_qualities = []
|
| 20 |
+
ecgs_per_patient = []
|
| 21 |
+
|
| 22 |
+
for pid in tqdm(pat_ids):
|
| 23 |
+
sel = df[df.patient_id == pid]
|
| 24 |
+
l = np.concatenate([list(d.keys()) for d in sel.scp_codes.values])
|
| 25 |
+
if sel.sex.values[0] == 0:
|
| 26 |
+
gender = 'male'
|
| 27 |
+
else:
|
| 28 |
+
gender = 'female'
|
| 29 |
+
l = np.concatenate((l, [gender] * len(sel)))
|
| 30 |
+
for age in sel.age.values:
|
| 31 |
+
if age < 20:
|
| 32 |
+
l = np.concatenate((l, ['<20']))
|
| 33 |
+
elif 20 <= age < 40:
|
| 34 |
+
l = np.concatenate((l, ['20-40']))
|
| 35 |
+
elif 40 <= age < 60:
|
| 36 |
+
l = np.concatenate((l, ['40-60']))
|
| 37 |
+
elif 60 <= age < 80:
|
| 38 |
+
l = np.concatenate((l, ['60-80']))
|
| 39 |
+
elif age >= 80:
|
| 40 |
+
l = np.concatenate((l, ['>=80']))
|
| 41 |
+
p_labels.append(l)
|
| 42 |
+
ecgs_per_patient.append(len(sel))
|
| 43 |
+
p_qualities.append(sel.quality.min())
|
| 44 |
+
classes = sorted(list(set([item for sublist in p_labels for item in sublist])))
|
| 45 |
+
|
| 46 |
+
stratified_data_ids, stratified_data = stratify(p_labels, classes, [1 / n_folds] * n_folds, p_qualities,
|
| 47 |
+
ecgs_per_patient, nr_clean_folds)
|
| 48 |
+
|
| 49 |
+
df[new_col_name] = np.zeros(len(df)).astype(int)
|
| 50 |
+
for fold_i, fold_ids in tqdm(enumerate(stratified_data_ids)):
|
| 51 |
+
ipat_ids = [pat_ids[pid] for pid in fold_ids]
|
| 52 |
+
df[new_col_name][df.patient_id.isin(ipat_ids)] = fold_i + 1
|
| 53 |
+
|
| 54 |
+
return df
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def stratify(data, classes, ratios, qualities, ecgs_per_patient, nr_clean_folds=1):
|
| 58 |
+
"""Stratifying procedure. Modified from https://vict0rs.ch/2018/05/24/sample-multilabel-dataset/ (based on Sechidis 2011)
|
| 59 |
+
|
| 60 |
+
data is a list of lists: a list of labels, for each sample.
|
| 61 |
+
Each sample's labels should be ints, if they are one-hot encoded, use one_hot=True
|
| 62 |
+
|
| 63 |
+
classes is the list of classes each label can take
|
| 64 |
+
|
| 65 |
+
ratios is a list, summing to 1, of how the dataset should be split
|
| 66 |
+
|
| 67 |
+
qualities: quality per entry (only >0 can be assigned to clean folds; 4 will always be assigned to final fold)
|
| 68 |
+
|
| 69 |
+
ecgs_per_patient: list with number of ecgs per sample
|
| 70 |
+
|
| 71 |
+
nr_clean_folds: the last nr_clean_folds can only take clean entries
|
| 72 |
+
|
| 73 |
+
"""
|
| 74 |
+
np.random.seed(0) # fix the random seed
|
| 75 |
+
|
| 76 |
+
# data is now always a list of lists; len(data) is the number of patients; data[i] is the list of all labels for
|
| 77 |
+
# patient i (possibly multiple identical entries)
|
| 78 |
+
|
| 79 |
+
# size is the number of ecgs
|
| 80 |
+
size = np.sum(ecgs_per_patient)
|
| 81 |
+
|
| 82 |
+
# Organize data per label: for each label l, per_label_data[l] contains the list of patients
|
| 83 |
+
# in data which have this label (potentially multiple identical entries)
|
| 84 |
+
per_label_data = {c: [] for c in classes}
|
| 85 |
+
for i, d in enumerate(data):
|
| 86 |
+
for l in d:
|
| 87 |
+
per_label_data[l].append(i)
|
| 88 |
+
|
| 89 |
+
# In order not to compute lengths each time, they are tracked here.
|
| 90 |
+
subset_sizes = [r * size for r in ratios] # list of subset_sizes in terms of ecgs
|
| 91 |
+
per_label_subset_sizes = {c: [r * len(per_label_data[c]) for r in ratios] for c in
|
| 92 |
+
classes} # dictionary with label: list of subset sizes in terms of patients
|
| 93 |
+
|
| 94 |
+
# For each subset we want, the set of sample-ids which should end up in it
|
| 95 |
+
stratified_data_ids = [set() for _ in range(len(ratios))] # initialize empty
|
| 96 |
+
|
| 97 |
+
# For each sample in the data set
|
| 98 |
+
print("Assigning patients to folds...")
|
| 99 |
+
size_prev = size + 1 # just for output
|
| 100 |
+
while size > 0:
|
| 101 |
+
if int(size_prev / 1000) > int(size / 1000):
|
| 102 |
+
print("Remaining patients/ecgs to distribute:", size, "non-empty labels:",
|
| 103 |
+
np.sum([1 for l, label_data in per_label_data.items() if len(label_data) > 0]))
|
| 104 |
+
size_prev = size
|
| 105 |
+
# Compute |Di|
|
| 106 |
+
lengths = {
|
| 107 |
+
l: len(label_data)
|
| 108 |
+
for l, label_data in per_label_data.items()
|
| 109 |
+
} # dictionary label: number of ecgs with this label that have not been assigned to a fold yet
|
| 110 |
+
try:
|
| 111 |
+
# Find label of smallest |Di|
|
| 112 |
+
label = min({k: v for k, v in lengths.items() if v > 0}, key=lengths.get)
|
| 113 |
+
except ValueError:
|
| 114 |
+
# If the dictionary in `min` is empty we get a Value Error.
|
| 115 |
+
# This can happen if there are unlabeled samples.
|
| 116 |
+
# In this case, `size` would be > 0 but only samples without label would remain.
|
| 117 |
+
# "No label" could be a class in itself: it's up to you to format your data accordingly.
|
| 118 |
+
break
|
| 119 |
+
# For each patient with label `label` get patient and corresponding counts
|
| 120 |
+
unique_samples, unique_counts = np.unique(per_label_data[label], return_counts=True)
|
| 121 |
+
idxs_sorted = np.argsort(unique_counts, kind='stable')[::-1]
|
| 122 |
+
unique_samples = unique_samples[
|
| 123 |
+
idxs_sorted] # this is a list of all patient ids with this label sort by size descending
|
| 124 |
+
unique_counts = unique_counts[idxs_sorted] # these are the corresponding counts
|
| 125 |
+
|
| 126 |
+
# loop through all patient ids with this label
|
| 127 |
+
for current_id, current_count in zip(unique_samples, unique_counts):
|
| 128 |
+
|
| 129 |
+
subset_sizes_for_label = per_label_subset_sizes[label] # current subset sizes for the chosen label
|
| 130 |
+
|
| 131 |
+
# if quality is bad remove clean folds (i.e. sample cannot be assigned to clean folds)
|
| 132 |
+
if qualities[current_id] < 1:
|
| 133 |
+
subset_sizes_for_label = subset_sizes_for_label[:len(ratios) - nr_clean_folds]
|
| 134 |
+
|
| 135 |
+
# Find argmax clj i.e. subset in greatest need of the current label
|
| 136 |
+
largest_subsets = np.argwhere(subset_sizes_for_label == np.amax(subset_sizes_for_label)).flatten()
|
| 137 |
+
|
| 138 |
+
# if there is a single best choice: assign it
|
| 139 |
+
if len(largest_subsets) == 1:
|
| 140 |
+
subset = largest_subsets[0]
|
| 141 |
+
# If there is more than one such subset, find the one in greatest need of any label
|
| 142 |
+
else:
|
| 143 |
+
largest_subsets2 = np.argwhere(np.array(subset_sizes)[largest_subsets] == np.amax(
|
| 144 |
+
np.array(subset_sizes)[largest_subsets])).flatten()
|
| 145 |
+
subset = largest_subsets[np.random.choice(largest_subsets2)]
|
| 146 |
+
|
| 147 |
+
# Store the sample's id in the selected subset
|
| 148 |
+
stratified_data_ids[subset].add(current_id)
|
| 149 |
+
|
| 150 |
+
# There is current_count fewer samples to distribute
|
| 151 |
+
size -= ecgs_per_patient[current_id]
|
| 152 |
+
# The selected subset needs current_count fewer samples
|
| 153 |
+
subset_sizes[subset] -= ecgs_per_patient[current_id]
|
| 154 |
+
|
| 155 |
+
# In the selected subset, there is one more example for each label
|
| 156 |
+
# the current sample has
|
| 157 |
+
for l in data[current_id]:
|
| 158 |
+
per_label_subset_sizes[l][subset] -= 1
|
| 159 |
+
|
| 160 |
+
# Remove the sample from the dataset, meaning from all per_label dataset created
|
| 161 |
+
for x in per_label_data.keys():
|
| 162 |
+
per_label_data[x] = [y for y in per_label_data[x] if y != current_id]
|
| 163 |
+
|
| 164 |
+
# Create the stratified dataset as a list of subsets, each containing the original labels
|
| 165 |
+
stratified_data_ids = [sorted(strat) for strat in stratified_data_ids]
|
| 166 |
+
stratified_data = [
|
| 167 |
+
[data[i] for i in strat] for strat in stratified_data_ids
|
| 168 |
+
]
|
| 169 |
+
|
| 170 |
+
# Return both the stratified indexes, to be used to sample the `features` associated with your labels
|
| 171 |
+
# And the stratified labels dataset
|
| 172 |
+
|
| 173 |
+
return stratified_data_ids, stratified_data
|
utilities/timeseries_utils.py
ADDED
|
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.utils.data
|
| 4 |
+
from torch import nn
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from scipy.stats import iqr
|
| 7 |
+
import os
|
| 8 |
+
#Note: due to issues with the numpy rng for multiprocessing
|
| 9 |
+
#(https://github.com/pytorch/pytorch/issues/5059) that could be
|
| 10 |
+
#fixed by a custom worker_init_fn we use random throughout for convenience
|
| 11 |
+
import random
|
| 12 |
+
from skimage import transform
|
| 13 |
+
import warnings
|
| 14 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 15 |
+
from scipy.signal import butter, sosfilt, sosfiltfilt, sosfreqz
|
| 16 |
+
#https://stackoverflow.com/questions/12093594/how-to-implement-band-pass-butterworth-filter-with-scipy-signal-butter
|
| 17 |
+
|
| 18 |
+
def butter_filter(lowcut=10, highcut=20, fs=50, order=5, btype='band'): # image processing Butterworth filter
|
| 19 |
+
"""returns butterworth filter with given specifications"""
|
| 20 |
+
nyq = 0.5 * fs
|
| 21 |
+
low = lowcut / nyq
|
| 22 |
+
high = highcut / nyq
|
| 23 |
+
|
| 24 |
+
sos = butter(order, [low, high] if btype == "band" else (low if btype == "low" else high), analog=False,
|
| 25 |
+
btype=btype, output='sos')
|
| 26 |
+
return sos
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def butter_filter_frequency_response(filter):
|
| 30 |
+
"""returns frequency response of a given filter (result of call of butter_filter)"""
|
| 31 |
+
w, h = sosfreqz(filter)
|
| 32 |
+
# gain vs. freq(Hz)
|
| 33 |
+
# plt.plot((fs * 0.5 / np.pi) * w, abs(h))
|
| 34 |
+
return w, h
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def apply_butter_filter(data, filter, forwardbackward=True): # The function provides options for handling the edges of the signal.
|
| 38 |
+
"""pass filter from call of butter_filter to data (assuming time axis at dimension 0)"""
|
| 39 |
+
if forwardbackward:
|
| 40 |
+
return sosfiltfilt(filter, data, axis=0)
|
| 41 |
+
else:
|
| 42 |
+
data = sosfilt(filter, data, axis=0)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def dataset_add_chunk_col(df, col="data"):
|
| 46 |
+
"""add a chunk column to the dataset df"""
|
| 47 |
+
df["chunk"] = df.groupby(col).cumcount()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def dataset_add_length_col(df, col="data", data_folder=None):
|
| 51 |
+
"""add a length column to the dataset df"""
|
| 52 |
+
df[col + "_length"] = df[col].apply(lambda x: len(np.load(x if data_folder is None else data_folder / x)))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def dataset_add_labels_col(df, col="label", data_folder=None):
|
| 56 |
+
"""add a column with unique labels in column col"""
|
| 57 |
+
df[col + "_labels"] = df[col].apply(
|
| 58 |
+
lambda x: list(np.unique(np.load(x if data_folder is None else data_folder / x))))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def dataset_add_mean_col(df, col="data", axis=(0), data_folder=None):
|
| 62 |
+
"""adds a column with mean"""
|
| 63 |
+
df[col + "_mean"] = df[col].apply(
|
| 64 |
+
lambda x: np.mean(np.load(x if data_folder is None else data_folder / x), axis=axis))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def dataset_add_median_col(df, col="data", axis=(0), data_folder=None):
|
| 68 |
+
"""adds a column with median"""
|
| 69 |
+
df[col + "_median"] = df[col].apply(
|
| 70 |
+
lambda x: np.median(np.load(x if data_folder is None else data_folder / x), axis=axis))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def dataset_add_std_col(df, col="data", axis=(0), data_folder=None):
|
| 74 |
+
"""adds a column with mean"""
|
| 75 |
+
df[col + "_std"] = df[col].apply(
|
| 76 |
+
lambda x: np.std(np.load(x if data_folder is None else data_folder / x), axis=axis))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def dataset_add_iqr_col(df, col="data", axis=(0), data_folder=None):
|
| 80 |
+
"""adds a column with mean"""
|
| 81 |
+
df[col + "_iqr"] = df[col].apply(lambda x: iqr(np.load(x if data_folder is None else data_folder / x), axis=axis))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def dataset_get_stats(df, col="data", median=False):
|
| 85 |
+
"""creates weighted means and stds from mean, std and length cols of the df"""
|
| 86 |
+
mean = np.average(np.stack(df[col + ("_median" if median is True else "_mean")], axis=0), axis=0,
|
| 87 |
+
weights=np.array(df[col + "_length"]))
|
| 88 |
+
std = np.average(np.stack(df[col + ("_iqr" if median is True else "_std")], axis=0), axis=0,
|
| 89 |
+
weights=np.array(df[col + "_length"]))
|
| 90 |
+
return mean, std
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def npys_to_memmap(npys, target_filename, delete_npys=False):
|
| 94 |
+
memmap = None
|
| 95 |
+
start = []
|
| 96 |
+
length = []
|
| 97 |
+
files = []
|
| 98 |
+
ids = []
|
| 99 |
+
|
| 100 |
+
for idx, npy in enumerate(npys):
|
| 101 |
+
data = np.load(npy)
|
| 102 |
+
if memmap is None:
|
| 103 |
+
memmap = np.memmap(target_filename, dtype=data.dtype, mode='w+', shape=data.shape)
|
| 104 |
+
start.append(0)
|
| 105 |
+
length.append(data.shape[0])
|
| 106 |
+
else:
|
| 107 |
+
start.append(start[-1] + length[-1])
|
| 108 |
+
length.append(data.shape[0])
|
| 109 |
+
memmap = np.memmap(target_filename, dtype=data.dtype, mode='r+',
|
| 110 |
+
shape=tuple([start[-1] + length[-1]] + [l for l in data.shape[1:]]))
|
| 111 |
+
|
| 112 |
+
ids.append(idx)
|
| 113 |
+
memmap[start[-1]:start[-1] + length[-1]] = data[:]
|
| 114 |
+
memmap.flush()
|
| 115 |
+
if delete_npys is True:
|
| 116 |
+
npy.unlink()
|
| 117 |
+
del memmap
|
| 118 |
+
|
| 119 |
+
np.savez(target_filename.parent / (target_filename.stem + "_meta.npz"), start=start, length=length,
|
| 120 |
+
shape=[start[-1] + length[-1]] + [l for l in data.shape[1:]], dtype=data.dtype)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def reformat_as_memmap(df, target_filename, data_folder=None, annotation=False, delete_npys=False):
|
| 124 |
+
npys_data = []
|
| 125 |
+
npys_label = []
|
| 126 |
+
|
| 127 |
+
for id, row in df.iterrows():
|
| 128 |
+
npys_data.append(data_folder / row["data"] if data_folder is not None else row["data"])
|
| 129 |
+
if annotation:
|
| 130 |
+
npys_label.append(data_folder / row["label"] if data_folder is not None else row["label"])
|
| 131 |
+
|
| 132 |
+
npys_to_memmap(npys_data, target_filename, delete_npys=delete_npys)
|
| 133 |
+
if annotation:
|
| 134 |
+
npys_to_memmap(npys_label, target_filename.parent / (target_filename.stem + "_label.npy"),
|
| 135 |
+
delete_npys=delete_npys)
|
| 136 |
+
|
| 137 |
+
# replace data(filename) by integer
|
| 138 |
+
df_mapped = df.copy()
|
| 139 |
+
df_mapped["data_original"] = df_mapped.data
|
| 140 |
+
df_mapped["data"] = np.arange(len(df_mapped))
|
| 141 |
+
df_mapped.to_pickle(target_filename.parent / ("df_" + target_filename.stem + ".pkl"))
|
| 142 |
+
return df_mapped
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# TimeseriesDatasetCrops
|
| 146 |
+
|
| 147 |
+
class TimeseriesDatasetCrops(torch.utils.data.Dataset):
|
| 148 |
+
"""timeseries dataset with partial crops."""
|
| 149 |
+
|
| 150 |
+
def __init__(self, df, output_size, chunk_length, min_chunk_length, memmap_filename=None, npy_data=None,
|
| 151 |
+
random_crop=True, data_folder=None, num_classes=2, copies=0, col_lbl="label", stride=None, start_idx=0,
|
| 152 |
+
annotation=False, transforms=None):
|
| 153 |
+
"""
|
| 154 |
+
accepts three kinds of input:
|
| 155 |
+
1) filenames pointing to aligned numpy arrays [timesteps,channels,...] for data and either integer labels or filename pointing to numpy arrays[timesteps,...] e.g. for annotations
|
| 156 |
+
2) memmap_filename to memmap for data [concatenated,...] and labels- label column in df corresponds to index in this memmap
|
| 157 |
+
3) npy_data [samples,ts,...] (either path or np.array directly- also supporting variable length input) - label column in df corresponds to sampleid
|
| 158 |
+
|
| 159 |
+
transforms: list of callables (transformations) (applied in the specified order i.e. leftmost element first)
|
| 160 |
+
"""
|
| 161 |
+
if transforms is None:
|
| 162 |
+
transforms = []
|
| 163 |
+
assert not ((memmap_filename is not None) and (npy_data is not None))
|
| 164 |
+
# require integer entries if using memmap or npy
|
| 165 |
+
assert (memmap_filename is None and npy_data is None) or df.data.dtype == np.int64
|
| 166 |
+
|
| 167 |
+
self.timeseries_df = df
|
| 168 |
+
self.output_size = output_size
|
| 169 |
+
self.data_folder = data_folder
|
| 170 |
+
self.transforms = transforms
|
| 171 |
+
self.annotation = annotation
|
| 172 |
+
self.col_lbl = col_lbl
|
| 173 |
+
|
| 174 |
+
self.c = num_classes
|
| 175 |
+
|
| 176 |
+
self.mode = "files"
|
| 177 |
+
self.memmap_filename = memmap_filename
|
| 178 |
+
if memmap_filename is not None:
|
| 179 |
+
self.mode = "memmap"
|
| 180 |
+
memmap_meta = np.load(memmap_filename.parent / (memmap_filename.stem + "_meta.npz"))
|
| 181 |
+
self.memmap_start = memmap_meta["start"]
|
| 182 |
+
self.memmap_shape = tuple(memmap_meta["shape"])
|
| 183 |
+
self.memmap_length = memmap_meta["length"]
|
| 184 |
+
self.memmap_dtype = np.dtype(str(memmap_meta["dtype"]))
|
| 185 |
+
self.memmap_file_process_dict = {}
|
| 186 |
+
if annotation:
|
| 187 |
+
memmap_meta_label = np.load(memmap_filename.parent / (memmap_filename.stem + "_label_meta.npz"))
|
| 188 |
+
self.memmap_filename_label = memmap_filename.parent / (memmap_filename.stem + "_label.npy")
|
| 189 |
+
self.memmap_shape_label = tuple(memmap_meta_label["shape"])
|
| 190 |
+
self.memmap_file_process_dict_label = {}
|
| 191 |
+
self.memmap_dtype_label = np.dtype(str(memmap_meta_label["dtype"]))
|
| 192 |
+
elif npy_data is not None:
|
| 193 |
+
self.mode = "npy"
|
| 194 |
+
if isinstance(npy_data, np.ndarray) or isinstance(npy_data, list):
|
| 195 |
+
self.npy_data = np.array(npy_data)
|
| 196 |
+
assert (annotation is False)
|
| 197 |
+
else:
|
| 198 |
+
self.npy_data = np.load(npy_data)
|
| 199 |
+
if annotation:
|
| 200 |
+
self.npy_data_label = np.load(npy_data.parent / (npy_data.stem + "_label.npy"))
|
| 201 |
+
|
| 202 |
+
self.random_crop = random_crop
|
| 203 |
+
|
| 204 |
+
self.df_idx_mapping = []
|
| 205 |
+
self.start_idx_mapping = []
|
| 206 |
+
self.end_idx_mapping = []
|
| 207 |
+
|
| 208 |
+
for df_idx, (id, row) in enumerate(df.iterrows()):
|
| 209 |
+
if self.mode == "files":
|
| 210 |
+
data_length = row["data_length"]
|
| 211 |
+
elif self.mode == "memmap":
|
| 212 |
+
data_length = self.memmap_length[row["data"]]
|
| 213 |
+
else: # npy
|
| 214 |
+
data_length = len(self.npy_data[row["data"]])
|
| 215 |
+
|
| 216 |
+
if chunk_length == 0: # do not split
|
| 217 |
+
idx_start = [start_idx]
|
| 218 |
+
idx_end = [data_length]
|
| 219 |
+
else:
|
| 220 |
+
idx_start = list(range(start_idx, data_length, chunk_length if stride is None else stride))
|
| 221 |
+
idx_end = [min(l + chunk_length, data_length) for l in idx_start]
|
| 222 |
+
|
| 223 |
+
# remove final chunk(s) if too short
|
| 224 |
+
for i in range(len(idx_start)):
|
| 225 |
+
if idx_end[i] - idx_start[i] < min_chunk_length:
|
| 226 |
+
del idx_start[i:]
|
| 227 |
+
del idx_end[i:]
|
| 228 |
+
break
|
| 229 |
+
# append to lists
|
| 230 |
+
for _ in range(copies + 1):
|
| 231 |
+
for i_s, i_e in zip(idx_start, idx_end):
|
| 232 |
+
self.df_idx_mapping.append(df_idx)
|
| 233 |
+
self.start_idx_mapping.append(i_s)
|
| 234 |
+
self.end_idx_mapping.append(i_e)
|
| 235 |
+
|
| 236 |
+
def __len__(self):
|
| 237 |
+
return len(self.df_idx_mapping)
|
| 238 |
+
|
| 239 |
+
def __getitem__(self, idx):
|
| 240 |
+
df_idx = self.df_idx_mapping[idx]
|
| 241 |
+
start_idx = self.start_idx_mapping[idx]
|
| 242 |
+
end_idx = self.end_idx_mapping[idx]
|
| 243 |
+
# determine crop idxs
|
| 244 |
+
timesteps = end_idx - start_idx
|
| 245 |
+
assert (timesteps >= self.output_size)
|
| 246 |
+
if self.random_crop: # random crop
|
| 247 |
+
if timesteps == self.output_size:
|
| 248 |
+
start_idx_crop = start_idx
|
| 249 |
+
else:
|
| 250 |
+
start_idx_crop = start_idx + random.randint(0, timesteps - self.output_size - 1) # np.random.randint(0, timesteps - self.output_size)
|
| 251 |
+
else:
|
| 252 |
+
start_idx_crop = start_idx + (timesteps - self.output_size) // 2
|
| 253 |
+
end_idx_crop = start_idx_crop + self.output_size
|
| 254 |
+
|
| 255 |
+
# print(idx,start_idx,end_idx,start_idx_crop,end_idx_crop)
|
| 256 |
+
# load the actual data
|
| 257 |
+
if self.mode == "files": # from separate files
|
| 258 |
+
data_filename = self.timeseries_df.iloc[df_idx]["data"]
|
| 259 |
+
if self.data_folder is not None:
|
| 260 |
+
data_filename = self.data_folder / data_filename
|
| 261 |
+
data = np.load(data_filename)[
|
| 262 |
+
start_idx_crop:end_idx_crop] # data type has to be adjusted when saving to npy
|
| 263 |
+
|
| 264 |
+
ID = data_filename.stem
|
| 265 |
+
|
| 266 |
+
if self.annotation is True:
|
| 267 |
+
label_filename = self.timeseries_df.iloc[df_idx][self.col_lbl]
|
| 268 |
+
if self.data_folder is not None:
|
| 269 |
+
label_filename = self.data_folder / label_filename
|
| 270 |
+
label = np.load(label_filename)[
|
| 271 |
+
start_idx_crop:end_idx_crop] # data type has to be adjusted when saving to npy
|
| 272 |
+
else:
|
| 273 |
+
label = self.timeseries_df.iloc[df_idx][self.col_lbl] # input type has to be adjusted in the dataframe
|
| 274 |
+
elif self.mode == "memmap": # from one memmap file
|
| 275 |
+
ID = self.timeseries_df.iloc[df_idx]["data_original"].stem
|
| 276 |
+
memmap_idx = self.timeseries_df.iloc[df_idx][
|
| 277 |
+
"data"] # grab the actual index (Note the df to create the ds might be a subset of the original df used to create the memmap)
|
| 278 |
+
idx_offset = self.memmap_start[memmap_idx]
|
| 279 |
+
|
| 280 |
+
pid = os.getpid()
|
| 281 |
+
# print("idx",idx,"ID",ID,"idx_offset",idx_offset,"start_idx_crop",start_idx_crop,"df_idx", self.df_idx_mapping[idx],"pid",pid)
|
| 282 |
+
mem_file = self.memmap_file_process_dict.get(pid, None) # each process owns its handler.
|
| 283 |
+
if mem_file is None:
|
| 284 |
+
# print("memmap_shape", self.memmap_shape)
|
| 285 |
+
mem_file = np.memmap(self.memmap_filename, self.memmap_dtype, mode='r', shape=self.memmap_shape)
|
| 286 |
+
self.memmap_file_process_dict[pid] = mem_file
|
| 287 |
+
data = np.copy(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
|
| 288 |
+
# print(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
|
| 289 |
+
if self.annotation:
|
| 290 |
+
mem_file_label = self.memmap_file_process_dict_label.get(pid, None) # each process owns its handler.
|
| 291 |
+
if mem_file_label is None:
|
| 292 |
+
mem_file_label = np.memmap(self.memmap_filename_label, self.memmap_dtype, mode='r',
|
| 293 |
+
shape=self.memmap_shape_label)
|
| 294 |
+
self.memmap_file_process_dict_label[pid] = mem_file_label
|
| 295 |
+
label = np.copy(mem_file_label[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
|
| 296 |
+
else:
|
| 297 |
+
label = self.timeseries_df.iloc[df_idx][self.col_lbl]
|
| 298 |
+
else: # single npy array
|
| 299 |
+
ID = self.timeseries_df.iloc[df_idx]["data"]
|
| 300 |
+
|
| 301 |
+
data = self.npy_data[ID][start_idx_crop:end_idx_crop]
|
| 302 |
+
|
| 303 |
+
if self.annotation:
|
| 304 |
+
label = self.npy_data_label[ID][start_idx_crop:end_idx_crop]
|
| 305 |
+
else:
|
| 306 |
+
label = self.timeseries_df.iloc[df_idx][self.col_lbl]
|
| 307 |
+
sample = {'data': data, 'label': label, 'ID': ID}
|
| 308 |
+
|
| 309 |
+
for t in self.transforms:
|
| 310 |
+
sample = t(sample)
|
| 311 |
+
|
| 312 |
+
return sample
|
| 313 |
+
|
| 314 |
+
def get_sampling_weights(self, class_weight_dict, length_weighting=False, group_by_col=None):
|
| 315 |
+
assert (self.annotation is False)
|
| 316 |
+
assert (length_weighting is False or group_by_col is None)
|
| 317 |
+
weights = np.zeros(len(self.df_idx_mapping), dtype=np.float32)
|
| 318 |
+
length_per_class = {}
|
| 319 |
+
length_per_group = {}
|
| 320 |
+
for iw, (i, s, e) in enumerate(zip(self.df_idx_mapping, self.start_idx_mapping, self.end_idx_mapping)):
|
| 321 |
+
label = self.timeseries_df.iloc[i][self.col_lbl]
|
| 322 |
+
weight = class_weight_dict[label]
|
| 323 |
+
if length_weighting:
|
| 324 |
+
if label in length_per_class.keys():
|
| 325 |
+
length_per_class[label] += e - s
|
| 326 |
+
else:
|
| 327 |
+
length_per_class[label] = e - s
|
| 328 |
+
if group_by_col is not None:
|
| 329 |
+
group = self.timeseries_df.iloc[i][group_by_col]
|
| 330 |
+
if group in length_per_group.keys():
|
| 331 |
+
length_per_group[group] += e - s
|
| 332 |
+
else:
|
| 333 |
+
length_per_group[group] = e - s
|
| 334 |
+
weights[iw] = weight
|
| 335 |
+
|
| 336 |
+
if length_weighting: # need second pass to properly take into account the total length per class
|
| 337 |
+
for iw, (i, s, e) in enumerate(zip(self.df_idx_mapping, self.start_idx_mapping, self.end_idx_mapping)):
|
| 338 |
+
label = self.timeseries_df.iloc[i][self.col_lbl]
|
| 339 |
+
weights[iw] = (e - s) / length_per_class[label] * weights[iw]
|
| 340 |
+
if group_by_col is not None:
|
| 341 |
+
for iw, (i, s, e) in enumerate(zip(self.df_idx_mapping, self.start_idx_mapping, self.end_idx_mapping)):
|
| 342 |
+
group = self.timeseries_df.iloc[i][group_by_col]
|
| 343 |
+
weights[iw] = (e - s) / length_per_group[group] * weights[iw]
|
| 344 |
+
|
| 345 |
+
weights = weights / np.min(weights) # normalize smallest weight to 1
|
| 346 |
+
return weights
|
| 347 |
+
|
| 348 |
+
def get_id_mapping(self):
|
| 349 |
+
return self.df_idx_mapping
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class RandomCrop(object):
|
| 353 |
+
"""
|
| 354 |
+
Crop randomly the image in a sample (deprecated).
|
| 355 |
+
"""
|
| 356 |
+
|
| 357 |
+
def __init__(self, output_size, annotation=False):
|
| 358 |
+
self.output_size = output_size
|
| 359 |
+
self.annotation = annotation
|
| 360 |
+
|
| 361 |
+
def __call__(self, sample):
|
| 362 |
+
data, label, ID = sample['data'], sample['label'], sample['ID']
|
| 363 |
+
|
| 364 |
+
timesteps = len(data)
|
| 365 |
+
assert (timesteps >= self.output_size)
|
| 366 |
+
if timesteps == self.output_size:
|
| 367 |
+
start = 0
|
| 368 |
+
else:
|
| 369 |
+
start = random.randint(0, timesteps - self.output_size - 1) # np.random.randint(0, timesteps - self.output_size)
|
| 370 |
+
|
| 371 |
+
data = data[start: start + self.output_size]
|
| 372 |
+
if self.annotation:
|
| 373 |
+
label = label[start: start + self.output_size]
|
| 374 |
+
|
| 375 |
+
return {'data': data, 'label': label, "ID": ID}
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class CenterCrop(object):
|
| 379 |
+
"""
|
| 380 |
+
Center crop the image in a sample (deprecated).
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
def __init__(self, output_size, annotation=False):
|
| 384 |
+
self.output_size = output_size
|
| 385 |
+
self.annotation = annotation
|
| 386 |
+
|
| 387 |
+
def __call__(self, sample):
|
| 388 |
+
data, label, ID = sample['data'], sample['label'], sample['ID']
|
| 389 |
+
|
| 390 |
+
timesteps = len(data)
|
| 391 |
+
|
| 392 |
+
start = (timesteps - self.output_size) // 2
|
| 393 |
+
|
| 394 |
+
data = data[start: start + self.output_size]
|
| 395 |
+
if self.annotation:
|
| 396 |
+
label = label[start: start + self.output_size]
|
| 397 |
+
|
| 398 |
+
return {'data': data, 'label': label, "ID": ID}
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class GaussianNoise(object):
|
| 402 |
+
"""
|
| 403 |
+
Add gaussian noise to sample.
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
def __init__(self, scale=0.1):
|
| 407 |
+
self.scale = scale
|
| 408 |
+
|
| 409 |
+
def __call__(self, sample):
|
| 410 |
+
if self.scale == 0:
|
| 411 |
+
return sample
|
| 412 |
+
else:
|
| 413 |
+
data, label, ID = sample['data'], sample['label'], sample['ID']
|
| 414 |
+
data = data + np.reshape(np.array([random.gauss(0, self.scale) for _ in range(np.prod(data.shape))]),
|
| 415 |
+
data.shape) # np.random.normal(scale=self.scale,size=data.shape).astype(np.float32)
|
| 416 |
+
return {'data': data, 'label': label, "ID": ID}
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class Rescale(object):
|
| 420 |
+
"""Rescale by factor.
|
| 421 |
+
"""
|
| 422 |
+
|
| 423 |
+
def __init__(self, scale=0.5, interpolation_order=3):
|
| 424 |
+
self.scale = scale
|
| 425 |
+
self.interpolation_order = interpolation_order
|
| 426 |
+
|
| 427 |
+
def __call__(self, sample):
|
| 428 |
+
if self.scale == 1:
|
| 429 |
+
return sample
|
| 430 |
+
else:
|
| 431 |
+
data, label, ID = sample['data'], sample['label'], sample['ID']
|
| 432 |
+
timesteps_new = int(self.scale * len(data))
|
| 433 |
+
data = transform.resize(data, (timesteps_new, data.shape[1]), order=self.interpolation_order).astype(
|
| 434 |
+
np.float32)
|
| 435 |
+
return {'data': data, 'label': label, "ID": ID}
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
class ToTensor(object):
|
| 439 |
+
"""Convert ndarrays in sample to Tensors."""
|
| 440 |
+
|
| 441 |
+
def __init__(self, transpose_data1d=True):
|
| 442 |
+
self.transpose_data1d = transpose_data1d
|
| 443 |
+
|
| 444 |
+
def __call__(self, sample):
|
| 445 |
+
def _to_tensor(data, transpose_data1d=False):
|
| 446 |
+
if (
|
| 447 |
+
len(data.shape) == 2 and transpose_data1d is True): # swap channel and time axis for direct application of pytorch's 1d convs
|
| 448 |
+
data = data.transpose((1, 0))
|
| 449 |
+
if isinstance(data, np.ndarray):
|
| 450 |
+
return torch.from_numpy(data)
|
| 451 |
+
else: # default_collate will take care of it
|
| 452 |
+
return data
|
| 453 |
+
|
| 454 |
+
data, label, ID = sample['data'], sample['label'], sample['ID']
|
| 455 |
+
|
| 456 |
+
if not isinstance(data, tuple):
|
| 457 |
+
data = _to_tensor(data, self.transpose_data1d)
|
| 458 |
+
else:
|
| 459 |
+
data = tuple(_to_tensor(x, self.transpose_data1d) for x in data)
|
| 460 |
+
|
| 461 |
+
if not isinstance(label, tuple):
|
| 462 |
+
label = _to_tensor(label)
|
| 463 |
+
else:
|
| 464 |
+
label = tuple(_to_tensor(x) for x in label)
|
| 465 |
+
|
| 466 |
+
return data, label # returning as a tuple (potentially of lists)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class Normalize(object):
|
| 470 |
+
"""
|
| 471 |
+
Normalize using given stats.
|
| 472 |
+
"""
|
| 473 |
+
|
| 474 |
+
def __init__(self, stats_mean, stats_std, input=True, channels=None):
|
| 475 |
+
if channels is None:
|
| 476 |
+
channels = []
|
| 477 |
+
self.stats_mean = np.expand_dims(stats_mean.astype(np.float32), axis=0) if stats_mean is not None else None
|
| 478 |
+
self.stats_std = np.expand_dims(stats_std.astype(np.float32), axis=0) + 1e-8 if stats_std is not None else None
|
| 479 |
+
self.input = input
|
| 480 |
+
if len(channels) > 0:
|
| 481 |
+
for i in range(len(stats_mean)):
|
| 482 |
+
if not (i in channels):
|
| 483 |
+
self.stats_mean[:, i] = 0
|
| 484 |
+
self.stats_std[:, i] = 1
|
| 485 |
+
|
| 486 |
+
def __call__(self, sample):
|
| 487 |
+
if self.input:
|
| 488 |
+
data = sample['data']
|
| 489 |
+
else:
|
| 490 |
+
data = sample['label']
|
| 491 |
+
|
| 492 |
+
if self.stats_mean is not None:
|
| 493 |
+
data = data - self.stats_mean
|
| 494 |
+
if self.stats_std is not None:
|
| 495 |
+
data = data / self.stats_std
|
| 496 |
+
|
| 497 |
+
if self.input:
|
| 498 |
+
return {'data': data, 'label': sample['label'], "ID": sample['ID']}
|
| 499 |
+
else:
|
| 500 |
+
return {'data': sample['data'], 'label': data, "ID": sample['ID']}
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class ButterFilter(object):
|
| 504 |
+
"""
|
| 505 |
+
Normalize using given stats.
|
| 506 |
+
"""
|
| 507 |
+
|
| 508 |
+
def __init__(self, lowcut=50, highcut=50, fs=100, order=5, btype='band', forwardbackward=True, input=True):
|
| 509 |
+
self.filter = butter_filter(lowcut, highcut, fs, order, btype)
|
| 510 |
+
self.input = input
|
| 511 |
+
self.forwardbackward = forwardbackward
|
| 512 |
+
|
| 513 |
+
def __call__(self, sample):
|
| 514 |
+
if self.input:
|
| 515 |
+
data = sample['data']
|
| 516 |
+
else:
|
| 517 |
+
data = sample['label']
|
| 518 |
+
|
| 519 |
+
# check multiple axis
|
| 520 |
+
if self.forwardbackward:
|
| 521 |
+
data = sosfiltfilt(self.filter, data, axis=0)
|
| 522 |
+
else:
|
| 523 |
+
data = sosfilt(self.filter, data, axis=0)
|
| 524 |
+
|
| 525 |
+
if self.input:
|
| 526 |
+
return {'data': data, 'label': sample['label'], "ID": sample['ID']}
|
| 527 |
+
else:
|
| 528 |
+
return {'data': sample['data'], 'label': data, "ID": sample['ID']}
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class ChannelFilter(object):
|
| 532 |
+
"""
|
| 533 |
+
Select certain channels.
|
| 534 |
+
"""
|
| 535 |
+
|
| 536 |
+
def __init__(self, channels=None, input=True):
|
| 537 |
+
if channels is None:
|
| 538 |
+
channels = [0]
|
| 539 |
+
self.channels = channels
|
| 540 |
+
self.input = input
|
| 541 |
+
|
| 542 |
+
def __call__(self, sample):
|
| 543 |
+
if self.input:
|
| 544 |
+
return {'data': sample['data'][:, self.channels], 'label': sample['label'], "ID": sample['ID']}
|
| 545 |
+
else:
|
| 546 |
+
return {'data': sample['data'], 'label': sample['label'][:, self.channels], "ID": sample['ID']}
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
class Transform(object):
|
| 550 |
+
"""
|
| 551 |
+
Transforms data using a given function i.e. data_new = func(data) for input is True else label_new = func(label)
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
def __init__(self, func, input=False):
|
| 555 |
+
self.func = func
|
| 556 |
+
self.input = input
|
| 557 |
+
|
| 558 |
+
def __call__(self, sample):
|
| 559 |
+
if self.input:
|
| 560 |
+
return {'data': self.func(sample['data']), 'label': sample['label'], "ID": sample['ID']}
|
| 561 |
+
else:
|
| 562 |
+
return {'data': sample['data'], 'label': self.func(sample['label']), "ID": sample['ID']}
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
class TupleTransform(object):
|
| 566 |
+
"""
|
| 567 |
+
Transforms data using a given function (operating on both data and label and return a tuple) i.e. data_new, label_new = func(data_old, label_old)
|
| 568 |
+
"""
|
| 569 |
+
|
| 570 |
+
def __init__(self, func, input=False):
|
| 571 |
+
self.func = func
|
| 572 |
+
|
| 573 |
+
def __call__(self, sample):
|
| 574 |
+
data_new, label_new = self.func(sample['data'], sample['label'])
|
| 575 |
+
return {'data': data_new, 'label': label_new, "ID": sample['ID']}
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# MIL and ensemble models
|
| 579 |
+
|
| 580 |
+
def aggregate_predictions(preds, targs=None, idmap=None, aggregate_fn=np.mean, verbose=True):
|
| 581 |
+
"""
|
| 582 |
+
aggregates potentially multiple predictions per sample (can also pass targs for convenience)
|
| 583 |
+
idmap: idmap as returned by TimeSeriesCropsDataset's get_id_mapping
|
| 584 |
+
preds: ordered predictions as returned by learn.get_preds()
|
| 585 |
+
aggregate_fn: function that is used to aggregate multiple predictions per sample (most commonly np.amax or np.mean)
|
| 586 |
+
"""
|
| 587 |
+
if idmap is not None and len(idmap) != len(np.unique(idmap)):
|
| 588 |
+
if verbose:
|
| 589 |
+
print("aggregating predictions...")
|
| 590 |
+
preds_aggregated = []
|
| 591 |
+
targs_aggregated = []
|
| 592 |
+
for i in np.unique(idmap):
|
| 593 |
+
preds_local = preds[np.where(idmap == i)[0]]
|
| 594 |
+
preds_aggregated.append(aggregate_fn(preds_local, axis=0))
|
| 595 |
+
if targs is not None:
|
| 596 |
+
targs_local = targs[np.where(idmap == i)[0]]
|
| 597 |
+
assert (np.all(targs_local == targs_local[0])) # all labels have to agree
|
| 598 |
+
targs_aggregated.append(targs_local[0])
|
| 599 |
+
if targs is None:
|
| 600 |
+
return np.array(preds_aggregated)
|
| 601 |
+
else:
|
| 602 |
+
return np.array(preds_aggregated), np.array(targs_aggregated)
|
| 603 |
+
else:
|
| 604 |
+
if targs is None:
|
| 605 |
+
return preds
|
| 606 |
+
else:
|
| 607 |
+
return preds, targs
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class milwrapper(nn.Module):
|
| 611 |
+
def __init__(self, model, input_size, n, stride=None, softmax=True):
|
| 612 |
+
super().__init__()
|
| 613 |
+
self.n = n
|
| 614 |
+
self.input_size = input_size
|
| 615 |
+
self.model = model
|
| 616 |
+
self.softmax = softmax
|
| 617 |
+
self.stride = input_size if stride is None else stride
|
| 618 |
+
|
| 619 |
+
def forward(self, x):
|
| 620 |
+
# bs,ch,seq
|
| 621 |
+
for i in range(self.n):
|
| 622 |
+
pred_single = self.model(x[:, :, i * self.stride:i * self.stride + self.input_size])
|
| 623 |
+
pred_single = nn.functional.softmax(pred_single, dim=1)
|
| 624 |
+
if i == 0:
|
| 625 |
+
pred = pred_single
|
| 626 |
+
else:
|
| 627 |
+
pred += pred_single
|
| 628 |
+
return pred / self.n
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
class ensemblewrapper(nn.Module):
|
| 632 |
+
def __init__(self, model, checkpts):
|
| 633 |
+
super().__init__()
|
| 634 |
+
self.model = model
|
| 635 |
+
self.checkpts = checkpts
|
| 636 |
+
|
| 637 |
+
def forward(self, x):
|
| 638 |
+
# bs,ch,seq
|
| 639 |
+
for i, c in enumerate(self.checkpts):
|
| 640 |
+
state = torch.load(Path("./models/") / f'{c}.pth', map_location=x.device)
|
| 641 |
+
self.model.load_state_dict(state['model'], strict=True)
|
| 642 |
+
|
| 643 |
+
pred_single = self.model(x)
|
| 644 |
+
pred_single = nn.functional.softmax(pred_single, dim=1)
|
| 645 |
+
if (i == 0):
|
| 646 |
+
pred = pred_single
|
| 647 |
+
else:
|
| 648 |
+
pred += pred_single
|
| 649 |
+
return pred / len(self.checkpts)
|
utilities/utils.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import pickle
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import wfdb
|
| 8 |
+
import ast
|
| 9 |
+
from sklearn.metrics import roc_auc_score, roc_curve
|
| 10 |
+
from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# EVALUATION STUFF
|
| 14 |
+
def generate_results(idxs, y_true, y_pred, thresholds):
|
| 15 |
+
return evaluate_experiment(y_true[idxs], y_pred[idxs], thresholds)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def evaluate_experiment(y_true, y_pred, thresholds=None):
|
| 19 |
+
results = {}
|
| 20 |
+
|
| 21 |
+
if not thresholds is None:
|
| 22 |
+
# binary predictions
|
| 23 |
+
y_pred_binary = apply_thresholds(y_pred, thresholds)
|
| 24 |
+
# PhysioNet/CinC Challenges metrics
|
| 25 |
+
challenge_scores = challenge_metrics(y_true, y_pred_binary, beta1=2, beta2=2)
|
| 26 |
+
results['F_beta_macro'] = challenge_scores['F_beta_macro']
|
| 27 |
+
results['G_beta_macro'] = challenge_scores['G_beta_macro']
|
| 28 |
+
results['TP'] = challenge_scores['TP']
|
| 29 |
+
results['TN'] = challenge_scores['TN']
|
| 30 |
+
results['FP'] = challenge_scores['FP']
|
| 31 |
+
results['FN'] = challenge_scores['FN']
|
| 32 |
+
results['Accuracy'] = challenge_scores['Accuracy']
|
| 33 |
+
results['F1'] = challenge_scores['F1']
|
| 34 |
+
results['Precision'] = challenge_scores['Precision']
|
| 35 |
+
results['Recall'] = challenge_scores['Recall']
|
| 36 |
+
|
| 37 |
+
# label based metric
|
| 38 |
+
results['macro_auc'] = roc_auc_score(y_true, y_pred, average='macro')
|
| 39 |
+
|
| 40 |
+
df_result = pd.DataFrame(results, index=[0])
|
| 41 |
+
return df_result
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def challenge_metrics(y_true, y_pred, beta1=2, beta2=2, single=False):
|
| 45 |
+
f_beta = 0
|
| 46 |
+
g_beta = 0
|
| 47 |
+
TP, FP, TN, FN = 0., 0., 0., 0.
|
| 48 |
+
Accuracy = 0
|
| 49 |
+
Precision = 0
|
| 50 |
+
Recall = 0
|
| 51 |
+
F1 = 0
|
| 52 |
+
|
| 53 |
+
if single: # if evaluating single class in case of threshold-optimization
|
| 54 |
+
sample_weights = np.ones(y_true.sum(axis=1).shape)
|
| 55 |
+
else:
|
| 56 |
+
sample_weights = y_true.sum(axis=1)
|
| 57 |
+
for classi in range(y_true.shape[1]):
|
| 58 |
+
y_truei, y_predi = y_true[:, classi], y_pred[:, classi]
|
| 59 |
+
TP, FP, TN, FN = 0., 0., 0., 0.
|
| 60 |
+
for i in range(len(y_predi)):
|
| 61 |
+
sample_weight = sample_weights[i]
|
| 62 |
+
if y_truei[i] == y_predi[i] == 1:
|
| 63 |
+
TP += 1. / sample_weight
|
| 64 |
+
if (y_predi[i] == 1) and (y_truei[i] != y_predi[i]):
|
| 65 |
+
FP += 1. / sample_weight
|
| 66 |
+
if y_truei[i] == y_predi[i] == 0:
|
| 67 |
+
TN += 1. / sample_weight
|
| 68 |
+
if (y_predi[i] == 0) and (y_truei[i] != y_predi[i]):
|
| 69 |
+
FN += 1. / sample_weight
|
| 70 |
+
f_beta_i = ((1 + beta1 ** 2) * TP) / ((1 + beta1 ** 2) * TP + FP + (beta1 ** 2) * FN)
|
| 71 |
+
g_beta_i = TP / (TP + FP + beta2 * FN)
|
| 72 |
+
|
| 73 |
+
f_beta += f_beta_i
|
| 74 |
+
g_beta += g_beta_i
|
| 75 |
+
|
| 76 |
+
Accuracy = (TP + TN) / (FP + TP + TN + FN)
|
| 77 |
+
# Precision = TP / (TP + FP)
|
| 78 |
+
# Recall = TP / (TP + FN)
|
| 79 |
+
# F1 = 2*(Precision * Recall) / (Precision + Recall)
|
| 80 |
+
F1 = 2 * TP / 2 * TP + FP + FN
|
| 81 |
+
|
| 82 |
+
return {'F_beta_macro': f_beta / y_true.shape[1], 'G_beta_macro': g_beta / y_true.shape[1], 'TP': TP, 'FP': FP,
|
| 83 |
+
'TN': TN, 'FN': FN, 'Accuracy': Accuracy, 'F1': F1, 'Precision': Precision, 'Recall': Recall}
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_appropriate_bootstrap_samples(y_true, n_bootstraping_samples):
|
| 87 |
+
samples = []
|
| 88 |
+
while True:
|
| 89 |
+
ridxs = np.random.randint(0, len(y_true), len(y_true))
|
| 90 |
+
if y_true[ridxs].sum(axis=0).min() != 0:
|
| 91 |
+
samples.append(ridxs)
|
| 92 |
+
if len(samples) == n_bootstraping_samples:
|
| 93 |
+
break
|
| 94 |
+
return samples
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def find_optimal_cutoff_threshold(target, predicted):
|
| 98 |
+
"""
|
| 99 |
+
Find the optimal probability cutoff point for a classification model related to event rate
|
| 100 |
+
"""
|
| 101 |
+
fpr, tpr, threshold = roc_curve(target, predicted)
|
| 102 |
+
optimal_idx = np.argmax(tpr - fpr)
|
| 103 |
+
optimal_threshold = threshold[optimal_idx]
|
| 104 |
+
return optimal_threshold
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def find_optimal_cutoff_thresholds(y_true, y_pred):
|
| 108 |
+
return [find_optimal_cutoff_threshold(y_true[:, i], y_pred[:, i]) for i in range(y_true.shape[1])]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def find_optimal_cutoff_threshold_for_Gbeta(target, predicted, n_thresholds=100):
|
| 112 |
+
thresholds = np.linspace(0.00, 1, n_thresholds)
|
| 113 |
+
scores = [challenge_metrics(target, predicted > t, single=True)['G_beta_macro'] for t in thresholds]
|
| 114 |
+
optimal_idx = np.argmax(scores)
|
| 115 |
+
return thresholds[optimal_idx]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def find_optimal_cutoff_thresholds_for_Gbeta(y_true, y_pred):
|
| 119 |
+
print("optimize thresholds with respect to G_beta")
|
| 120 |
+
return [
|
| 121 |
+
find_optimal_cutoff_threshold_for_Gbeta(y_true[:, k][:, np.newaxis], y_pred[:, k][:, np.newaxis])
|
| 122 |
+
for k in tqdm(range(y_true.shape[1]))]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def apply_thresholds(preds, thresholds):
|
| 126 |
+
"""
|
| 127 |
+
apply class-wise thresholds to prediction score in order to get binary format.
|
| 128 |
+
BUT: if no score is above threshold, pick maximum. This is needed due to metric issues.
|
| 129 |
+
"""
|
| 130 |
+
tmp = []
|
| 131 |
+
for p in preds:
|
| 132 |
+
tmp_p = (p > thresholds).astype(int)
|
| 133 |
+
if np.sum(tmp_p) == 0:
|
| 134 |
+
tmp_p[np.argmax(p)] = 1
|
| 135 |
+
tmp.append(tmp_p)
|
| 136 |
+
tmp = np.array(tmp)
|
| 137 |
+
return tmp
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# DATA PROCESSING STUFF
|
| 141 |
+
def load_dataset(path, sampling_rate, release=False):
|
| 142 |
+
if path.split('/')[-2] == 'ptbxl':
|
| 143 |
+
# load and convert annotation data
|
| 144 |
+
Y = pd.read_csv(path + 'ptbxl_database.csv', index_col='ecg_id')
|
| 145 |
+
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
|
| 146 |
+
|
| 147 |
+
# Load raw signal data
|
| 148 |
+
X = load_raw_data_ptbxl(Y, sampling_rate, path)
|
| 149 |
+
|
| 150 |
+
elif path.split('/')[-2] == 'ICBEB':
|
| 151 |
+
# load and convert annotation data
|
| 152 |
+
Y = pd.read_csv(path + 'icbeb_database.csv', index_col='ecg_id')
|
| 153 |
+
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
|
| 154 |
+
|
| 155 |
+
# Load raw signal data
|
| 156 |
+
X = load_raw_data_icbeb(Y, sampling_rate, path)
|
| 157 |
+
|
| 158 |
+
return X, Y
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def load_raw_data_icbeb(df, sampling_rate, path):
|
| 162 |
+
if sampling_rate == 100:
|
| 163 |
+
if os.path.exists(path + 'raw100.npy'):
|
| 164 |
+
data = np.load(path + 'raw100.npy', allow_pickle=True)
|
| 165 |
+
else:
|
| 166 |
+
data = [wfdb.rdsamp(path + 'records100/' + str(f)) for f in tqdm(df.index)]
|
| 167 |
+
data = np.array([signal for signal, meta in data])
|
| 168 |
+
pickle.dump(data, open(path + 'raw100.npy', 'wb'), protocol=4)
|
| 169 |
+
elif sampling_rate == 500:
|
| 170 |
+
if os.path.exists(path + 'raw500.npy'):
|
| 171 |
+
data = np.load(path + 'raw500.npy', allow_pickle=True)
|
| 172 |
+
else:
|
| 173 |
+
data = [wfdb.rdsamp(path + 'records500/' + str(f)) for f in tqdm(df.index)]
|
| 174 |
+
data = np.array([signal for signal, meta in data])
|
| 175 |
+
pickle.dump(data, open(path + 'raw500.npy', 'wb'), protocol=4)
|
| 176 |
+
return data
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def load_raw_data_ptbxl(df, sampling_rate, path):
|
| 180 |
+
if sampling_rate == 100:
|
| 181 |
+
if os.path.exists(path + 'raw100.npy'):
|
| 182 |
+
data = np.load(path + 'raw100.npy', allow_pickle=True)
|
| 183 |
+
else:
|
| 184 |
+
data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_lr)]
|
| 185 |
+
data = np.array([signal for signal, meta in data])
|
| 186 |
+
pickle.dump(data, open(path + 'raw100.npy', 'wb'), protocol=4)
|
| 187 |
+
elif sampling_rate == 500:
|
| 188 |
+
if os.path.exists(path + 'raw500.npy'):
|
| 189 |
+
data = np.load(path + 'raw500.npy', allow_pickle=True)
|
| 190 |
+
else:
|
| 191 |
+
data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_hr)]
|
| 192 |
+
data = np.array([signal for signal, meta in data])
|
| 193 |
+
pickle.dump(data, open(path + 'raw500.npy', 'wb'), protocol=4)
|
| 194 |
+
return data
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def compute_label_aggregations(df, folder, ctype):
|
| 198 |
+
df['scp_codes_len'] = df.scp_codes.apply(lambda x: len(x))
|
| 199 |
+
|
| 200 |
+
aggregation_df = pd.read_csv(folder + 'scp_statements.csv', index_col=0)
|
| 201 |
+
|
| 202 |
+
if ctype in ['diagnostic', 'subdiagnostic', 'superdiagnostic']:
|
| 203 |
+
|
| 204 |
+
def aggregate_all_diagnostic(y_dic):
|
| 205 |
+
tmp = []
|
| 206 |
+
for key in y_dic.keys():
|
| 207 |
+
if key in diag_agg_df.index:
|
| 208 |
+
tmp.append(key)
|
| 209 |
+
return list(set(tmp))
|
| 210 |
+
|
| 211 |
+
def aggregate_subdiagnostic(y_dic):
|
| 212 |
+
tmp = []
|
| 213 |
+
for key in y_dic.keys():
|
| 214 |
+
if key in diag_agg_df.index:
|
| 215 |
+
c = diag_agg_df.loc[key].diagnostic_subclass
|
| 216 |
+
if str(c) != 'nan':
|
| 217 |
+
tmp.append(c)
|
| 218 |
+
return list(set(tmp))
|
| 219 |
+
|
| 220 |
+
def aggregate_diagnostic(y_dic):
|
| 221 |
+
tmp = []
|
| 222 |
+
for key in y_dic.keys():
|
| 223 |
+
if key in diag_agg_df.index:
|
| 224 |
+
c = diag_agg_df.loc[key].diagnostic_class
|
| 225 |
+
if str(c) != 'nan':
|
| 226 |
+
tmp.append(c)
|
| 227 |
+
return list(set(tmp))
|
| 228 |
+
|
| 229 |
+
diag_agg_df = aggregation_df[aggregation_df.diagnostic == 1.0]
|
| 230 |
+
if ctype == 'diagnostic':
|
| 231 |
+
df['diagnostic'] = df.scp_codes.apply(aggregate_all_diagnostic)
|
| 232 |
+
df['diagnostic_len'] = df.diagnostic.apply(lambda x: len(x))
|
| 233 |
+
elif ctype == 'subdiagnostic':
|
| 234 |
+
df['subdiagnostic'] = df.scp_codes.apply(aggregate_subdiagnostic)
|
| 235 |
+
df['subdiagnostic_len'] = df.subdiagnostic.apply(lambda x: len(x))
|
| 236 |
+
elif ctype == 'superdiagnostic':
|
| 237 |
+
df['superdiagnostic'] = df.scp_codes.apply(aggregate_diagnostic)
|
| 238 |
+
df['superdiagnostic_len'] = df.superdiagnostic.apply(lambda x: len(x))
|
| 239 |
+
elif ctype == 'form':
|
| 240 |
+
form_agg_df = aggregation_df[aggregation_df.form == 1.0]
|
| 241 |
+
|
| 242 |
+
def aggregate_form(y_dic):
|
| 243 |
+
tmp = []
|
| 244 |
+
for key in y_dic.keys():
|
| 245 |
+
if key in form_agg_df.index:
|
| 246 |
+
c = key
|
| 247 |
+
if str(c) != 'nan':
|
| 248 |
+
tmp.append(c)
|
| 249 |
+
return list(set(tmp))
|
| 250 |
+
|
| 251 |
+
df['form'] = df.scp_codes.apply(aggregate_form)
|
| 252 |
+
df['form_len'] = df.form.apply(lambda x: len(x))
|
| 253 |
+
elif ctype == 'rhythm':
|
| 254 |
+
rhythm_agg_df = aggregation_df[aggregation_df.rhythm == 1.0]
|
| 255 |
+
|
| 256 |
+
def aggregate_rhythm(y_dic):
|
| 257 |
+
tmp = []
|
| 258 |
+
for key in y_dic.keys():
|
| 259 |
+
if key in rhythm_agg_df.index:
|
| 260 |
+
c = key
|
| 261 |
+
if str(c) != 'nan':
|
| 262 |
+
tmp.append(c)
|
| 263 |
+
return list(set(tmp))
|
| 264 |
+
|
| 265 |
+
df['rhythm'] = df.scp_codes.apply(aggregate_rhythm)
|
| 266 |
+
df['rhythm_len'] = df.rhythm.apply(lambda x: len(x))
|
| 267 |
+
elif ctype == 'all':
|
| 268 |
+
df['all_scp'] = df.scp_codes.apply(lambda x: list(set(x.keys())))
|
| 269 |
+
|
| 270 |
+
return df
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def select_data(XX, YY, ctype, min_samples, output_folder):
|
| 274 |
+
# convert multi_label to multi-hot
|
| 275 |
+
mlb = MultiLabelBinarizer()
|
| 276 |
+
|
| 277 |
+
if ctype == 'diagnostic':
|
| 278 |
+
X = XX[YY.diagnostic_len > 0]
|
| 279 |
+
Y = YY[YY.diagnostic_len > 0]
|
| 280 |
+
mlb.fit(Y.diagnostic.values)
|
| 281 |
+
y = mlb.transform(Y.diagnostic.values)
|
| 282 |
+
elif ctype == 'subdiagnostic':
|
| 283 |
+
counts = pd.Series(np.concatenate(YY.subdiagnostic.values)).value_counts()
|
| 284 |
+
counts = counts[counts > min_samples]
|
| 285 |
+
YY.subdiagnostic = YY.subdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
|
| 286 |
+
YY['subdiagnostic_len'] = YY.subdiagnostic.apply(lambda x: len(x))
|
| 287 |
+
X = XX[YY.subdiagnostic_len > 0]
|
| 288 |
+
Y = YY[YY.subdiagnostic_len > 0]
|
| 289 |
+
mlb.fit(Y.subdiagnostic.values)
|
| 290 |
+
y = mlb.transform(Y.subdiagnostic.values)
|
| 291 |
+
elif ctype == 'superdiagnostic':
|
| 292 |
+
counts = pd.Series(np.concatenate(YY.superdiagnostic.values)).value_counts()
|
| 293 |
+
counts = counts[counts > min_samples]
|
| 294 |
+
YY.superdiagnostic = YY.superdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
|
| 295 |
+
YY['superdiagnostic_len'] = YY.superdiagnostic.apply(lambda x: len(x))
|
| 296 |
+
X = XX[YY.superdiagnostic_len > 0]
|
| 297 |
+
Y = YY[YY.superdiagnostic_len > 0]
|
| 298 |
+
mlb.fit(Y.superdiagnostic.values)
|
| 299 |
+
y = mlb.transform(Y.superdiagnostic.values)
|
| 300 |
+
elif ctype == 'form':
|
| 301 |
+
# filter
|
| 302 |
+
counts = pd.Series(np.concatenate(YY.form.values)).value_counts()
|
| 303 |
+
counts = counts[counts > min_samples]
|
| 304 |
+
YY.form = YY.form.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
|
| 305 |
+
YY['form_len'] = YY.form.apply(lambda x: len(x))
|
| 306 |
+
# select
|
| 307 |
+
X = XX[YY.form_len > 0]
|
| 308 |
+
Y = YY[YY.form_len > 0]
|
| 309 |
+
mlb.fit(Y.form.values)
|
| 310 |
+
y = mlb.transform(Y.form.values)
|
| 311 |
+
elif ctype == 'rhythm':
|
| 312 |
+
# filter
|
| 313 |
+
counts = pd.Series(np.concatenate(YY.rhythm.values)).value_counts()
|
| 314 |
+
counts = counts[counts > min_samples]
|
| 315 |
+
YY.rhythm = YY.rhythm.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
|
| 316 |
+
YY['rhythm_len'] = YY.rhythm.apply(lambda x: len(x))
|
| 317 |
+
# select
|
| 318 |
+
X = XX[YY.rhythm_len > 0]
|
| 319 |
+
Y = YY[YY.rhythm_len > 0]
|
| 320 |
+
mlb.fit(Y.rhythm.values)
|
| 321 |
+
y = mlb.transform(Y.rhythm.values)
|
| 322 |
+
elif ctype == 'all':
|
| 323 |
+
# filter
|
| 324 |
+
counts = pd.Series(np.concatenate(YY.all_scp.values)).value_counts()
|
| 325 |
+
counts = counts[counts > min_samples]
|
| 326 |
+
YY.all_scp = YY.all_scp.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
|
| 327 |
+
YY['all_scp_len'] = YY.all_scp.apply(lambda x: len(x))
|
| 328 |
+
# select
|
| 329 |
+
X = XX[YY.all_scp_len > 0]
|
| 330 |
+
Y = YY[YY.all_scp_len > 0]
|
| 331 |
+
mlb.fit(Y.all_scp.values)
|
| 332 |
+
y = mlb.transform(Y.all_scp.values)
|
| 333 |
+
else:
|
| 334 |
+
pass
|
| 335 |
+
|
| 336 |
+
# save Label_Binarizer
|
| 337 |
+
with open(output_folder + 'mlb.pkl', 'wb') as tokenizer:
|
| 338 |
+
pickle.dump(mlb, tokenizer)
|
| 339 |
+
|
| 340 |
+
return X, Y, y, mlb
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def preprocess_signals(X_train, X_validation, X_test, outputfolder):
|
| 344 |
+
# Standardize data such that mean 0 and variance 1
|
| 345 |
+
ss = StandardScaler()
|
| 346 |
+
ss.fit(np.vstack(X_train).flatten()[:, np.newaxis].astype(float))
|
| 347 |
+
|
| 348 |
+
# Save Standardize data
|
| 349 |
+
with open(outputfolder + 'standard_scaler.pkl', 'wb') as ss_file:
|
| 350 |
+
pickle.dump(ss, ss_file)
|
| 351 |
+
|
| 352 |
+
return apply_standardizer(X_train, ss), apply_standardizer(X_validation,
|
| 353 |
+
ss), apply_standardizer(
|
| 354 |
+
X_test, ss)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def apply_standardizer(X, ss):
|
| 358 |
+
X_tmp = []
|
| 359 |
+
for x in X:
|
| 360 |
+
x_shape = x.shape
|
| 361 |
+
X_tmp.append(ss.transform(x.flatten()[:, np.newaxis]).reshape(x_shape))
|
| 362 |
+
X_tmp = np.array(X_tmp)
|
| 363 |
+
return X_tmp
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# DOCUMENTATION STUFF
|
| 367 |
+
|
| 368 |
+
def generate_ptbxl_summary_table(selection=None, folder='/output/'):
|
| 369 |
+
exps = ['exp0', 'exp1', 'exp1.1', 'exp1.1.1', 'exp2', 'exp3']
|
| 370 |
+
metrics = ['macro_auc', 'Accuracy', 'TP', 'TN', 'FP', 'FN', 'Precision', 'Recall', 'F1']
|
| 371 |
+
# 0 1 2 3 4 5 6 7 8
|
| 372 |
+
|
| 373 |
+
# get models
|
| 374 |
+
models = {}
|
| 375 |
+
for i, exp in enumerate(exps):
|
| 376 |
+
if selection is None:
|
| 377 |
+
exp_models = [m.split('/')[-1] for m in glob.glob(folder + str(exp) + '/models/*')]
|
| 378 |
+
else:
|
| 379 |
+
exp_models = selection
|
| 380 |
+
if i == 0:
|
| 381 |
+
models = set(exp_models)
|
| 382 |
+
else:
|
| 383 |
+
models = models.union(set(exp_models))
|
| 384 |
+
|
| 385 |
+
results_dic = {'Method': [],
|
| 386 |
+
'exp0_macro_auc': [],
|
| 387 |
+
'exp1_macro_auc': [],
|
| 388 |
+
'exp1.1_macro_auc': [],
|
| 389 |
+
'exp1.1.1_macro_auc': [],
|
| 390 |
+
'exp2_macro_auc': [],
|
| 391 |
+
'exp3_macro_auc': [],
|
| 392 |
+
'exp0_Accuracy': [],
|
| 393 |
+
'exp1_Accuracy': [],
|
| 394 |
+
'exp1.1_Accuracy': [],
|
| 395 |
+
'exp1.1.1_Accuracy': [],
|
| 396 |
+
'exp2_Accuracy': [],
|
| 397 |
+
'exp3_Accuracy': [],
|
| 398 |
+
'exp0_F1': [],
|
| 399 |
+
'exp1_F1': [],
|
| 400 |
+
'exp1.1_F1': [],
|
| 401 |
+
'exp1.1.1_F1': [],
|
| 402 |
+
'exp2_F1': [],
|
| 403 |
+
'exp3_F1': [],
|
| 404 |
+
'exp0_Precision': [],
|
| 405 |
+
'exp1_Precision': [],
|
| 406 |
+
'exp1.1_Precision': [],
|
| 407 |
+
'exp1.1.1_Precision': [],
|
| 408 |
+
'exp2_Precision': [],
|
| 409 |
+
'exp3_Precision': [],
|
| 410 |
+
'exp0_Recall': [],
|
| 411 |
+
'exp1_Recall': [],
|
| 412 |
+
'exp1.1_Recall': [],
|
| 413 |
+
'exp1.1.1_Recall': [],
|
| 414 |
+
'exp2_Recall': [],
|
| 415 |
+
'exp3_Recall': [],
|
| 416 |
+
'exp0_TP': [],
|
| 417 |
+
'exp1_TP': [],
|
| 418 |
+
'exp1.1_TP': [],
|
| 419 |
+
'exp1.1.1_TP': [],
|
| 420 |
+
'exp2_TP': [],
|
| 421 |
+
'exp3_TP': [],
|
| 422 |
+
'exp0_TN': [],
|
| 423 |
+
'exp1_TN': [],
|
| 424 |
+
'exp1.1_TN': [],
|
| 425 |
+
'exp1.1.1_TN': [],
|
| 426 |
+
'exp2_TN': [],
|
| 427 |
+
'exp3_TN': [],
|
| 428 |
+
'exp0_FP': [],
|
| 429 |
+
'exp1_FP': [],
|
| 430 |
+
'exp1.1_FP': [],
|
| 431 |
+
'exp1.1.1_FP': [],
|
| 432 |
+
'exp2_FP': [],
|
| 433 |
+
'exp3_FP': [],
|
| 434 |
+
'exp0_FN': [],
|
| 435 |
+
'exp1_FN': [],
|
| 436 |
+
'exp1.1_FN': [],
|
| 437 |
+
'exp1.1.1_FN': [],
|
| 438 |
+
'exp2_FN': [],
|
| 439 |
+
'exp3_FN': []
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
for m in models:
|
| 443 |
+
results_dic['Method'].append(m)
|
| 444 |
+
|
| 445 |
+
for e in exps:
|
| 446 |
+
|
| 447 |
+
try:
|
| 448 |
+
me_res = pd.read_csv(folder + str(e) + '/models/' + str(m) + '/results/te_results.csv', index_col=0)
|
| 449 |
+
|
| 450 |
+
mean1 = me_res.loc['point'][metrics[0]]
|
| 451 |
+
unc1 = max(me_res.loc['upper'][metrics[0]] - me_res.loc['point'][metrics[0]],
|
| 452 |
+
me_res.loc['point'][metrics[0]] - me_res.loc['lower'][metrics[0]])
|
| 453 |
+
|
| 454 |
+
acc = me_res.loc['point'][metrics[1]]
|
| 455 |
+
f1 = me_res.loc['point'][metrics[8]]
|
| 456 |
+
precision = me_res.loc['point'][metrics[6]]
|
| 457 |
+
recall = me_res.loc['point'][metrics[7]]
|
| 458 |
+
tp = me_res.loc['point'][metrics[2]]
|
| 459 |
+
tn = me_res.loc['point'][metrics[3]]
|
| 460 |
+
fp = me_res.loc['point'][metrics[4]]
|
| 461 |
+
fn = me_res.loc['point'][metrics[5]]
|
| 462 |
+
|
| 463 |
+
results_dic[e + '_macro_auc'].append("%.3f(%.2d)" % (np.round(mean1, 3), int(unc1 * 1000)))
|
| 464 |
+
results_dic[e + '_Accuracy'].append("%.3f" % acc)
|
| 465 |
+
results_dic[e + '_F1'].append("%.3f" % f1)
|
| 466 |
+
results_dic[e + '_Precision'].append("%.3f" % precision)
|
| 467 |
+
results_dic[e + '_Recall'].append("%.3f" % recall)
|
| 468 |
+
results_dic[e + '_TP'].append("%.3f" % tp)
|
| 469 |
+
results_dic[e + '_TN'].append("%.3f" % tn)
|
| 470 |
+
results_dic[e + '_FP'].append("%.3f" % fp)
|
| 471 |
+
results_dic[e + '_FN'].append("%.3f" % fn)
|
| 472 |
+
|
| 473 |
+
except FileNotFoundError:
|
| 474 |
+
results_dic[e + '_macro_auc'].append("--")
|
| 475 |
+
results_dic[e + '_Accuracy'].append("--")
|
| 476 |
+
results_dic[e + '_F1'].append("--")
|
| 477 |
+
results_dic[e + '_Precision'].append("--")
|
| 478 |
+
results_dic[e + '_Recall'].append("--")
|
| 479 |
+
results_dic[e + '_TP'].append("--")
|
| 480 |
+
results_dic[e + '_TN'].append("--")
|
| 481 |
+
results_dic[e + '_FP'].append("--")
|
| 482 |
+
results_dic[e + '_FN'].append("--")
|
| 483 |
+
|
| 484 |
+
df = pd.DataFrame(results_dic)
|
| 485 |
+
df_index = df[df.Method.isin(['naive', 'ensemble'])]
|
| 486 |
+
df_rest = df[~df.Method.isin(['naive', 'ensemble'])]
|
| 487 |
+
df = pd.concat([df_rest, df_index])
|
| 488 |
+
df.to_csv(folder + 'results_ptbxl.csv')
|
| 489 |
+
|
| 490 |
+
titles = [
|
| 491 |
+
'### 1. PTB-XL: all statements',
|
| 492 |
+
'### 2. PTB-XL: diagnostic statements',
|
| 493 |
+
'### 3. PTB-XL: Diagnostic subclasses',
|
| 494 |
+
'### 4. PTB-XL: Diagnostic superclasses',
|
| 495 |
+
'### 5. PTB-XL: Form statements',
|
| 496 |
+
'### 6. PTB-XL: Rhythm statements'
|
| 497 |
+
]
|
| 498 |
+
|
| 499 |
+
# helper output function for markdown tables
|
| 500 |
+
our_work = 'https://arxiv.org/abs/2004.13701'
|
| 501 |
+
our_repo = 'https://github.com/helme/ecg_ptbxl_benchmarking/'
|
| 502 |
+
md_source = ''
|
| 503 |
+
for i, e in enumerate(exps):
|
| 504 |
+
md_source += '\n ' + titles[i] + ' \n \n'
|
| 505 |
+
md_source += '| Model | AUC |\n'
|
| 506 |
+
|
| 507 |
+
for row in df_rest[['Method', e + '_AUC']].sort_values(e + '_AUC', ascending=False).values:
|
| 508 |
+
md_source += '| ' + row[0].replace('fastai_', '') + ' | ' + row[1] + ' |\n'
|
| 509 |
+
print(md_source)
|