cuongnx2001 committed on
Commit
264b4c4
·
verified ·
1 Parent(s): fa68e3c

Upload 34 files

README.md ADDED
@@ -0,0 +1,153 @@
+ # Automated ECG Interpretation
+
+ [![Contributors][contributors-shield]][contributors-url]
+ [![GitHub forks](https://img.shields.io/github/forks/AutoECG/Automated-ECG-Interpretation?color=lightgray&style=flat-square)](https://github.com/AutoECG/Automated-ECG-Interpretation/network)
+ [![GitHub stars](https://img.shields.io/github/stars/AutoECG/Automated-ECG-Interpretation?color=yellow&style=flat-square)](https://github.com/AutoECG/Automated-ECG-Interpretation/stargazers)
+ [![GitHub issues](https://img.shields.io/github/issues/AutoECG/Automated-ECG-Interpretation?color=red&style=flat-square)](https://github.com/AutoECG/Automated-ECG-Interpretation/issues)
+
+ <br>
+
+ <div align="center">
+ <img src="https://user-images.githubusercontent.com/46399191/191921241-495090db-a088-46b6-bd09-0f7f21170b0a.png" height="350"/>
+ </div>
+
+ ## Summary
+
+ Electrocardiography (ECG) is a key diagnostic tool for assessing a patient's cardiac condition. Automatic ECG interpretation algorithms, deployed as diagnostic support systems, promise substantial relief for medical personnel, if only because of the sheer number of ECGs taken routinely. However, developing such algorithms requires large training datasets and clear benchmark procedures.
+
+ ## Data Description
+
+ The [PTB-XL ECG dataset](https://physionet.org/content/ptb-xl/1.0.1/) is a large dataset of 21837 clinical 12-lead ECGs of 10 seconds length from 18885 patients. The raw waveform data was annotated by up to two cardiologists, who assigned potentially multiple ECG statements to each record. In total, 71 different ECG statements conform to the SCP-ECG standard and cover diagnostic, form, and rhythm statements. Combined with the extensive annotation, this makes the dataset a rich resource for training and evaluating automatic ECG interpretation algorithms. The dataset is complemented by extensive metadata on demographics, infarction characteristics, likelihoods for diagnostic ECG statements, and annotated signal properties.
+
+ In general, the dataset is organized as follows:
+
+ ```
+ ptbxl
+ ├── ptbxl_database.csv
+ ├── scp_statements.csv
+ ├── records100
+ │   ├── 00000
+ │   │   ├── 00001_lr.dat
+ │   │   ├── 00001_lr.hea
+ │   │   ├── ...
+ │   │   ├── 00999_lr.dat
+ │   │   └── 00999_lr.hea
+ │   ├── ...
+ │   └── 21000
+ │       ├── 21001_lr.dat
+ │       ├── 21001_lr.hea
+ │       ├── ...
+ │       ├── 21837_lr.dat
+ │       └── 21837_lr.hea
+ └── records500
+     ├── 00000
+     │   ├── 00001_hr.dat
+     │   ├── 00001_hr.hea
+     │   ├── ...
+     │   ├── 00999_hr.dat
+     │   └── 00999_hr.hea
+     ├── ...
+     └── 21000
+         ├── 21001_hr.dat
+         ├── 21001_hr.hea
+         ├── ...
+         ├── 21837_hr.dat
+         └── 21837_hr.hea
+ ```
+
+ The dataset comprises 21837 clinical 12-lead ECG records of 10 seconds length from 18885 patients, of whom 52% are male and 48% are female, with ages covering the whole range from 0 to 95 years (median 62, interquartile range 22). The value of the dataset results from its comprehensive collection of many different co-occurring pathologies, but also from its large proportion of healthy control samples.
+
+ | Records | Superclass | Description |
+ |:---|:---|:---|
+ | 9528 | NORM | Normal ECG |
+ | 5486 | MI | Myocardial Infarction |
+ | 5250 | STTC | ST/T Change |
+ | 4907 | CD | Conduction Disturbance |
+ | 2655 | HYP | Hypertrophy |
+
+ The waveform files are stored in WaveForm DataBase (WFDB) format with 16-bit precision at a resolution of 1μV/LSB and a sampling frequency of 500 Hz (`records500/`), alongside downsampled versions of the waveform data at a sampling frequency of 100 Hz (`records100/`).
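The two sampling frequencies translate directly into array shapes. A minimal sketch with a synthetic signal (the real files would be read with the `wfdb` package, whose `rdsamp` call returns the signal as a `(samples, leads)` array; the decimation below is only illustrative, the published 100 Hz files were produced with proper resampling):

```python
import numpy as np

# A 10-second, 12-lead record: 5000 samples per lead at 500 Hz (records500/),
# 1000 samples per lead at 100 Hz (records100/).
fs_hr, fs_lr, duration, n_leads = 500, 100, 10, 12
sig_hr = np.zeros((duration * fs_hr, n_leads))  # stand-in for a loaded _hr record

# Naive decimation illustrating the 5:1 relation between the two versions.
sig_lr = sig_hr[:: fs_hr // fs_lr]

print(sig_hr.shape, sig_lr.shape)  # (5000, 12) (1000, 12)
```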
+
+ All relevant metadata is stored in `ptbxl_database.csv`, with one row per record identified by `ecg_id`; it contains 28 columns.
+
+ All information related to the annotation scheme used is stored in a dedicated `scp_statements.csv`, which was enriched with mappings to other annotation standards.
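For illustration, the `scp_codes` column of `ptbxl_database.csv` stores each record's SCP statements as a stringified dict mapping statement codes to likelihoods, which is typically decoded with `ast.literal_eval`. A sketch using a hypothetical two-row stand-in for the real file:

```python
import ast
import pandas as pd

# Hypothetical two-record stand-in for ptbxl_database.csv.
df = pd.DataFrame(
    {"ecg_id": [1, 2],
     "scp_codes": ["{'NORM': 100.0, 'SR': 0.0}", "{'IMI': 50.0}"]}
).set_index("ecg_id")

# The column arrives as strings; turn each entry into a real dict.
df["scp_codes"] = df["scp_codes"].apply(ast.literal_eval)

print(df.loc[1, "scp_codes"])  # {'NORM': 100.0, 'SR': 0.0}
```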
+
+ ## Setup
+
+ ### Install dependencies
+ Install the dependencies (wfdb, pytorch, torchvision, cudatoolkit, fastai, fastprogress) by creating a conda environment:
+
+     conda env create -f requirements.yml
+     conda activate autoecg_env
+
+ ### Get data
+ Download the dataset (PTB-XL) via the following bash script:
+
+     get_dataset.sh
+
+ This script first downloads [PTB-XL from PhysioNet](https://physionet.org/content/ptb-xl/) and stores it in `data/ptbxl/`.
+
+ ## Usage
+
+     python main.py
+
+ This will perform all experiments for inception1d.
+ Depending on the execution environment, this may take up to several hours.
+ Once finished, all trained models, predictions, and results are stored in `output/`,
+ where a sub-folder is created for each experiment, each with `data/`, `models/` and `results/` sub-folders.
+
+ | Model | AUC &darr; | Experiment |
+ |:---|:---|:---|
+ | inception1d | 0.927(00) | All statements |
+ | inception1d | 0.929(00) | Diagnostic statements |
+ | inception1d | 0.926(00) | Diagnostic subclasses |
+ | inception1d | 0.919(00) | Diagnostic superclasses |
+ | inception1d | 0.883(00) | Form statements |
+ | inception1d | 0.949(00) | Rhythm statements |
+
+ ### Download model and results
+
+ We also provide a [compressed zip archive](https://drive.google.com/drive/folders/17za6IanRm7rpb1ZGHLQ80mJvBj_53LXJ?usp=sharing) containing the `output` folder corresponding to our runs, including trained models and predictions.
+
+ ## Results for Inception1d Model
+
+ | Experiment name | Accuracy | Precision | Recall | F1_Score | Specificity |
+ | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
+ | All | 0.9792 | 0.8949 | 0.1408 | 0.4824 | 0.9921 |
+ | Diagnostic | 0.9806 | 0.8440 | 0.1556 | 0.4746 | 0.9952 |
+ | Sub-Diagnostic | 0.9660 | 0.8315 | 0.3021 | 0.5119 | 0.9887 |
+ | Super-Diagnostic | 0.8847 | 0.7938 | 0.6757 | 0.7157 | 0.9251 |
+ | Form | 0.9452 | 0.5619 | 0.1420 | 0.3843 | 0.9916 |
+ | Rhythm | 0.9844 | 0.7676 | 0.4489 | 0.7290 | 0.9722 |
+
+ For further evaluation details (confusion matrices, ROC curves) and visualizations, see: [Model Evaluation](https://github.com/AutoECG/Automated-ECG-Interpretation/blob/main/evaluation/Model_Evaluation.ipynb)
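The table's columns reduce to confusion counts over (record, label) pairs once the predicted probabilities are thresholded. A sketch of micro-averaged precision, recall, and specificity on synthetic arrays (not the repository's actual outputs; the notebook's exact averaging may differ):

```python
import numpy as np

# Synthetic stand-in: 4 records, 3 labels; probabilities thresholded at 0.5.
y_true = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
y_prob = np.array([[0.9, 0.6, 0.1], [0.3, 0.8, 0.2],
                   [0.7, 0.6, 0.1], [0.2, 0.1, 0.9]])
y_pred = (y_prob >= 0.5).astype(int)

tp = np.sum((y_pred == 1) & (y_true == 1))
fp = np.sum((y_pred == 1) & (y_true == 0))
fn = np.sum((y_pred == 0) & (y_true == 1))
tn = np.sum((y_pred == 0) & (y_true == 0))

precision = tp / (tp + fp)
recall = tp / (tp + fn)        # sensitivity
specificity = tn / (tn + fp)   # true-negative rate, as in the table above
```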
+
+ ## Contribution
+
+ Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
+
+ If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement".
+ Don't forget to give the project a star! Thanks again!
+
+ 1. [Fork the Project](https://github.com/AutoECG/Automated-ECG-Interpretation/fork)
+ 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
+ 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
+ 4. Push to the Branch (`git push origin feature/AmazingFeature`)
+ 5. Open a Pull Request
+
+ ## Future Works
+
+ 1. Model deployment.
+ 2. Continue preprocessing new ECG data from hospitals to test model reliability and accuracy.
+ 3. Explore parsing options for XML ECG files from different ECG machine versions.
+
+ ## Contact
+
+ Feel free to reach out to us:
+ - DM [Zaki Kurdya](https://twitter.com/ZakiKurdya)
+ - DM [Zeina Saadeddin](https://twitter.com/jszeina)
+ - DM [Salam Thabit](https://twitter.com/salamThabetDo)
+
+ <!-- MARKDOWN LINKS -->
+ [contributors-shield]: https://img.shields.io/github/contributors/AutoECG/Automated-ECG-Interpretation.svg?style=flat-square&color=blue
+ [contributors-url]: https://github.com/AutoECG/Automated-ECG-Interpretation/graphs/contributors
configurations/__pycache__/fastai_configs.cpython-310.pyc ADDED
Binary file (3.67 kB).

configurations/__pycache__/fastai_configs.cpython-39.pyc ADDED
Binary file (3.66 kB).
 
configurations/fastai_configs.py ADDED
@@ -0,0 +1,137 @@
+ conf_fastai_resnet1d18 = {'model_name': 'fastai_resnet1d18', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_resnet1d34 = {'model_name': 'fastai_resnet1d34', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_resnet1d50 = {'model_name': 'fastai_resnet1d50', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_resnet1d101 = {'model_name': 'fastai_resnet1d101', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_resnet1d152 = {'model_name': 'fastai_resnet1d152', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_resnet1d_wang = {'model_name': 'fastai_resnet1d_wang', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_wrn1d_22 = {'model_name': 'fastai_wrn1d_22', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d18 = {'model_name': 'fastai_xresnet1d18', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d34 = {'model_name': 'fastai_xresnet1d34', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d50 = {'model_name': 'fastai_xresnet1d50', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ # more xresnet1d50 variants
+ conf_fastai_xresnet1d50_ep30 = {'model_name': 'fastai_xresnet1d50_ep30', 'model_type': 'FastaiModel',
+     'parameters': dict(epochs=30)}
+
+ conf_fastai_xresnet1d50_validloss_ep30 = {'model_name': 'fastai_xresnet1d50_validloss_ep30',
+     'model_type': 'FastaiModel',
+     'parameters': dict(early_stopping="valid_loss", epochs=30)}
+
+ conf_fastai_xresnet1d50_macroauc_ep30 = {'model_name': 'fastai_xresnet1d50_macroauc_ep30', 'model_type': 'FastaiModel',
+     'parameters': dict(early_stopping="macro_auc", epochs=30)}
+
+ conf_fastai_xresnet1d50_fmax_ep30 = {'model_name': 'fastai_xresnet1d50_fmax_ep30', 'model_type': 'FastaiModel',
+     'parameters': dict(early_stopping="fmax", epochs=30)}
+
+ conf_fastai_xresnet1d50_ep50 = {'model_name': 'fastai_xresnet1d50_ep50', 'model_type': 'FastaiModel',
+     'parameters': dict(epochs=50)}
+
+ conf_fastai_xresnet1d50_validloss_ep50 = {'model_name': 'fastai_xresnet1d50_validloss_ep50',
+     'model_type': 'FastaiModel',
+     'parameters': dict(early_stopping="valid_loss", epochs=50)}
+
+ conf_fastai_xresnet1d50_macroauc_ep50 = {'model_name': 'fastai_xresnet1d50_macroauc_ep50', 'model_type': 'FastaiModel',
+     'parameters': dict(early_stopping="macro_auc", epochs=50)}
+
+ conf_fastai_xresnet1d50_fmax_ep50 = {'model_name': 'fastai_xresnet1d50_fmax_ep50', 'model_type': 'FastaiModel',
+     'parameters': dict(early_stopping="fmax", epochs=50)}
+
+ conf_fastai_xresnet1d101 = {'model_name': 'fastai_xresnet1d101', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d152 = {'model_name': 'fastai_xresnet1d152', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d18_deep = {'model_name': 'fastai_xresnet1d18_deep', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d34_deep = {'model_name': 'fastai_xresnet1d34_deep', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d50_deep = {'model_name': 'fastai_xresnet1d50_deep', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d18_deeper = {'model_name': 'fastai_xresnet1d18_deeper', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d34_deeper = {'model_name': 'fastai_xresnet1d34_deeper', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_xresnet1d50_deeper = {'model_name': 'fastai_xresnet1d50_deeper', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_inception1d = {'model_name': 'fastai_inception1d', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_inception1d_input256 = {'model_name': 'fastai_inception1d_input256', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=256)}
+
+ conf_fastai_inception1d_input512 = {'model_name': 'fastai_inception1d_input512', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=512)}
+
+ conf_fastai_inception1d_input1000 = {'model_name': 'fastai_inception1d_input1000', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=1000)}
+
+ conf_fastai_inception1d_no_residual = {'model_name': 'fastai_inception1d_no_residual', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_fcn = {'model_name': 'fastai_fcn', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_fcn_wang = {'model_name': 'fastai_fcn_wang', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_schirrmeister = {'model_name': 'fastai_schirrmeister', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_sen = {'model_name': 'fastai_sen', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_basic1d = {'model_name': 'fastai_basic1d', 'model_type': 'FastaiModel',
+     'parameters': dict()}
+
+ conf_fastai_lstm = {'model_name': 'fastai_lstm', 'model_type': 'FastaiModel',
+     'parameters': dict(lr=1e-3)}
+
+ conf_fastai_gru = {'model_name': 'fastai_gru', 'model_type': 'FastaiModel',
+     'parameters': dict(lr=1e-3)}
+
+ conf_fastai_lstm_bidir = {'model_name': 'fastai_lstm_bidir', 'model_type': 'FastaiModel',
+     'parameters': dict(lr=1e-3)}
+
+ conf_fastai_gru_bidir = {'model_name': 'fastai_gru_bidir', 'model_type': 'FastaiModel',
+     'parameters': dict(lr=1e-3)}
+
+ conf_fastai_lstm_input1000 = {'model_name': 'fastai_lstm_input1000', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=1000, lr=1e-3)}
+
+ conf_fastai_gru_input1000 = {'model_name': 'fastai_gru_input1000', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=1000, lr=1e-3)}
+
+ conf_fastai_schirrmeister_input500 = {'model_name': 'fastai_schirrmeister_input500', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=500)}
+
+ conf_fastai_inception1d_input500 = {'model_name': 'fastai_inception1d_input500', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=500)}
+
+ conf_fastai_fcn_wang_input500 = {'model_name': 'fastai_fcn_wang_input500', 'model_type': 'FastaiModel',
+     'parameters': dict(input_size=500)}
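Each entry above is a plain dict that the experiment driver unpacks: `model_name` and `model_type` select the wrapper class, and `'parameters'` is forwarded as keyword arguments. A minimal sketch of that pattern (the `build_model` helper is illustrative, not part of the repo):

```python
conf = {'model_name': 'fastai_xresnet1d50_ep30', 'model_type': 'FastaiModel',
        'parameters': dict(epochs=30)}

def build_model(conf, **defaults):
    # Config parameters override any caller-supplied defaults.
    params = {**defaults, **conf['parameters']}
    return conf['model_name'], params

name, params = build_model(conf, epochs=50, lr=1e-3)
print(name, params)  # fastai_xresnet1d50_ep30 {'epochs': 30, 'lr': 0.001}
```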
configurations/wavelet_configs.py ADDED
@@ -0,0 +1,17 @@
+ conf_wavelet_standard_lr = {'model_name': 'Wavelet+LR', 'model_type': 'WAVELET',
+     'parameters': dict(
+         regularizer_C=.001,
+         classifier='LR'
+     )}
+
+ conf_wavelet_standard_rf = {'model_name': 'Wavelet+RF', 'model_type': 'WAVELET',
+     'parameters': dict(
+         regularizer_C=.001,
+         classifier='RF'
+     )}
+
+ conf_wavelet_standard_nn = {'model_name': 'Wavelet+NN', 'model_type': 'WAVELET',
+     'parameters': dict(
+         regularizer_C=.001,
+         classifier='NN'
+     )}
evaluation/Model_Evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiments/__pycache__/scp_experiment.cpython-310.pyc ADDED
Binary file (5.81 kB).

experiments/__pycache__/scp_experiment.cpython-39.pyc ADDED
Binary file (5.85 kB).
 
experiments/scp_experiment.py ADDED
@@ -0,0 +1,227 @@
+ import multiprocessing
+ from itertools import repeat
+
+ from models import fastaiModel
+ from models.wavelet import WaveletModel
+ from utilities.utils import *
+
+
+ class SCPExperiment:
+     """
+     Experiment on SCP-ECG statements.
+     All experiments based on SCP are performed and evaluated the same way.
+     """
+
+     def __init__(self, experiment_name, task, data_folder, output_folder, models,
+                  sampling_frequency=100, min_samples=0, train_fold=8, val_fold=9,
+                  test_fold=10, folds_type='strat'):
+         self.models = models
+         self.min_samples = min_samples
+         self.task = task
+         self.train_fold = train_fold
+         self.val_fold = val_fold
+         self.test_fold = test_fold
+         self.folds_type = folds_type
+         self.experiment_name = experiment_name
+         self.output_folder = output_folder
+         self.data_folder = data_folder
+         self.sampling_frequency = sampling_frequency
+
+         # create folder structure if needed
+         if not os.path.exists(self.output_folder + self.experiment_name):
+             os.makedirs(self.output_folder + self.experiment_name)
+         if not os.path.exists(self.output_folder + self.experiment_name + '/results/'):
+             os.makedirs(self.output_folder + self.experiment_name + '/results/')
+         if not os.path.exists(self.output_folder + self.experiment_name + '/models/'):
+             os.makedirs(self.output_folder + self.experiment_name + '/models/')
+         if not os.path.exists(self.output_folder + self.experiment_name + '/data/'):
+             os.makedirs(self.output_folder + self.experiment_name + '/data/')
+
+     def prepare(self):
+         # Load PTB-XL data
+         self.data, self.raw_labels = load_dataset(self.data_folder, self.sampling_frequency)
+
+         # Preprocess label data
+         self.labels = compute_label_aggregations(self.raw_labels, self.data_folder, self.task)
+
+         # Select relevant data and convert to one-hot
+         self.data, self.labels, self.Y, _ = select_data(self.data, self.labels, self.task, self.min_samples,
+                                                         self.output_folder + self.experiment_name + '/data/')
+         self.input_shape = self.data[0].shape
+
+         # 10th fold for testing (9th for now)
+         self.X_test = self.data[self.labels.strat_fold == self.test_fold]
+         self.y_test = self.Y[self.labels.strat_fold == self.test_fold]
+         # 9th fold for validation (8th for now)
+         self.X_val = self.data[self.labels.strat_fold == self.val_fold]
+         self.y_val = self.Y[self.labels.strat_fold == self.val_fold]
+         # rest for training
+         self.X_train = self.data[self.labels.strat_fold <= self.train_fold]
+         self.y_train = self.Y[self.labels.strat_fold <= self.train_fold]
+
+         # Preprocess signal data
+         self.X_train, self.X_val, self.X_test = preprocess_signals(self.X_train, self.X_val, self.X_test,
+                                                                    self.output_folder + self.experiment_name + '/data/')
+         self.n_classes = self.y_train.shape[1]
+
+         # save train and test labels
+         self.y_train.dump(self.output_folder + self.experiment_name + '/data/y_train.npy')
+         self.y_val.dump(self.output_folder + self.experiment_name + '/data/y_val.npy')
+         self.y_test.dump(self.output_folder + self.experiment_name + '/data/y_test.npy')
+
+         model_name = 'naive'
+         # create most naive predictions via simple mean in training
+         mpath = self.output_folder + self.experiment_name + '/models/' + model_name + '/'
+         # create folder for model outputs
+         if not os.path.exists(mpath):
+             os.makedirs(mpath)
+         if not os.path.exists(mpath + 'results/'):
+             os.makedirs(mpath + 'results/')
+
+         mean_y = np.mean(self.y_train, axis=0)
+         np.array([mean_y] * len(self.y_train)).dump(mpath + 'y_train_pred.npy')
+         np.array([mean_y] * len(self.y_test)).dump(mpath + 'y_test_pred.npy')
+         np.array([mean_y] * len(self.y_val)).dump(mpath + 'y_val_pred.npy')
+
+     def perform(self):
+         for model_description in self.models:
+             model_name = model_description['model_name']
+             model_type = model_description['model_type']
+             model_params = model_description['parameters']
+
+             mpath = self.output_folder + self.experiment_name + '/models/' + model_name + '/'
+             # create folder for model outputs
+             if not os.path.exists(mpath):
+                 os.makedirs(mpath)
+             if not os.path.exists(mpath + 'results/'):
+                 os.makedirs(mpath + 'results/')
+
+             n_classes = self.Y.shape[1]
+             # load respective model
+             if model_type == 'WAVELET':
+                 model = WaveletModel(model_name, n_classes, self.sampling_frequency, mpath, self.input_shape,
+                                      **model_params)
+             elif model_type == "FastaiModel":
+                 model = fastaiModel.FastaiModel(model_name, n_classes, self.sampling_frequency, mpath,
+                                                 self.input_shape, **model_params)
+             else:
+                 raise ValueError('Unknown model_type: ' + model_type)
+             # Print to check
+             print("Shape of input", self.X_train.shape)
+             # fit model
+             model.fit(self.X_train, self.y_train, self.X_val, self.y_val)
+             # predict and dump
+             model.predict(self.X_train).dump(mpath + 'y_train_pred.npy')
+             model.predict(self.X_val).dump(mpath + 'y_val_pred.npy')
+             model.predict(self.X_test).dump(mpath + 'y_test_pred.npy')
+
+         model_name = 'ensemble'
+         # create ensemble predictions via simple mean across model predictions (except naive predictions)
+         ensemblepath = self.output_folder + self.experiment_name + '/models/' + model_name + '/'
+         # create folder for model outputs
+         if not os.path.exists(ensemblepath):
+             os.makedirs(ensemblepath)
+         if not os.path.exists(ensemblepath + 'results/'):
+             os.makedirs(ensemblepath + 'results/')
+         # load all predictions
+         ensemble_train, ensemble_val, ensemble_test = [], [], []
+         for model_description in os.listdir(self.output_folder + self.experiment_name + '/models/'):
+             if model_description not in ['ensemble', 'naive']:
+                 mpath = self.output_folder + self.experiment_name + '/models/' + model_description + '/'
+                 ensemble_train.append(np.load(mpath + 'y_train_pred.npy', allow_pickle=True))
+                 ensemble_val.append(np.load(mpath + 'y_val_pred.npy', allow_pickle=True))
+                 ensemble_test.append(np.load(mpath + 'y_test_pred.npy', allow_pickle=True))
+         # dump mean predictions
+         np.array(ensemble_train).mean(axis=0).dump(ensemblepath + 'y_train_pred.npy')
+         np.array(ensemble_test).mean(axis=0).dump(ensemblepath + 'y_test_pred.npy')
+         np.array(ensemble_val).mean(axis=0).dump(ensemblepath + 'y_val_pred.npy')
+
+     def evaluate(self, n_bootstraping_samples=100, n_jobs=20, bootstrap_eval=False, dumped_bootstraps=True):
+         # get labels
+         global train_samples, val_samples
+         y_train = np.load(self.output_folder + self.experiment_name + '/data/y_train.npy', allow_pickle=True)
+         y_val = np.load(self.output_folder + self.experiment_name + '/data/y_val.npy', allow_pickle=True)
+         y_test = np.load(self.output_folder + self.experiment_name + '/data/y_test.npy', allow_pickle=True)
+
+         # if bootstrapping then generate appropriate samples for each
+         if bootstrap_eval:
+             if not dumped_bootstraps:
+                 train_samples = np.array(get_appropriate_bootstrap_samples(y_train, n_bootstraping_samples))
+                 test_samples = np.array(get_appropriate_bootstrap_samples(y_test, n_bootstraping_samples))
+                 val_samples = np.array(get_appropriate_bootstrap_samples(y_val, n_bootstraping_samples))
+             else:
+                 test_samples = np.load(self.output_folder + self.experiment_name + '/test_bootstrap_ids.npy',
+                                        allow_pickle=True)
+         else:
+             train_samples = np.array([range(len(y_train))])
+             test_samples = np.array([range(len(y_test))])
+             val_samples = np.array([range(len(y_val))])
+
+         # store samples for future evaluations
+         train_samples.dump(self.output_folder + self.experiment_name + '/train_bootstrap_ids.npy')
+         test_samples.dump(self.output_folder + self.experiment_name + '/test_bootstrap_ids.npy')
+         val_samples.dump(self.output_folder + self.experiment_name + '/val_bootstrap_ids.npy')
+
+         # iterate over all models fitted so far
+         for m in sorted(os.listdir(self.output_folder + self.experiment_name + '/models')):
+             print(m)
+             mpath = self.output_folder + self.experiment_name + '/models/' + m + '/'
+             rpath = self.output_folder + self.experiment_name + '/models/' + m + '/results/'
+
+             # load predictions
+             y_train_pred = np.load(mpath + 'y_train_pred.npy', allow_pickle=True)
+             y_val_pred = np.load(mpath + 'y_val_pred.npy', allow_pickle=True)
+             y_test_pred = np.load(mpath + 'y_test_pred.npy', allow_pickle=True)
+
+             if self.experiment_name == 'exp_ICBEB':
+                 # compute classwise thresholds such that recall-focused Gbeta is optimized
+                 thresholds = find_optimal_cutoff_thresholds_for_Gbeta(y_train, y_train_pred)
+             else:
+                 thresholds = None
+
+             pool = multiprocessing.Pool(n_jobs)
+
+             tr_df = pd.concat(pool.starmap(generate_results,
+                                            zip(train_samples, repeat(y_train), repeat(y_train_pred),
+                                                repeat(thresholds))))
+             tr_df_point = generate_results(range(len(y_train)), y_train, y_train_pred, thresholds)
+             tr_df_result = pd.DataFrame(
+                 np.array([
+                     tr_df_point.mean().values,
+                     tr_df.mean().values,
+                     tr_df.quantile(0.05).values,
+                     tr_df.quantile(0.95).values]),
+                 columns=tr_df.columns,
+                 index=['point', 'mean', 'lower', 'upper'])
+
+             te_df = pd.concat(pool.starmap(generate_results,
+                                            zip(test_samples, repeat(y_test), repeat(y_test_pred), repeat(thresholds))))
+             te_df_point = generate_results(range(len(y_test)), y_test, y_test_pred, thresholds)
+             te_df_result = pd.DataFrame(
+                 np.array([
+                     te_df_point.mean().values,
+                     te_df.mean().values,
+                     te_df.quantile(0.05).values,
+                     te_df.quantile(0.95).values]),
+                 columns=te_df.columns,
+                 index=['point', 'mean', 'lower', 'upper'])
+
+             val_df = pd.concat(pool.starmap(generate_results,
+                                             zip(val_samples, repeat(y_val), repeat(y_val_pred), repeat(thresholds))))
+             val_df_point = generate_results(range(len(y_val)), y_val, y_val_pred, thresholds)
+             val_df_result = pd.DataFrame(
+                 np.array([
+                     val_df_point.mean().values,
+                     val_df.mean().values,
+                     val_df.quantile(0.05).values,
+                     val_df.quantile(0.95).values]),
+                 columns=val_df.columns,
+                 index=['point', 'mean', 'lower', 'upper'])
+
+             pool.close()
+
+             # dump results
+             tr_df_result.to_csv(rpath + 'tr_results.csv')
+             val_df_result.to_csv(rpath + 'val_results.csv')
+             te_df_result.to_csv(rpath + 'te_results.csv')
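`evaluate()` summarizes each metric with a point estimate plus the empirical 5%/95% quantiles over bootstrap samples. The interval construction in isolation, with synthetic scores standing in for the per-bootstrap metric values:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(0.92, 0.01, size=100)  # stand-in: one AUC value per bootstrap sample

point = scores.mean()
lower, upper = np.quantile(scores, [0.05, 0.95])  # 90% percentile interval
```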
exploratory_data_analysis/AutoECG_EDA.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
main.py ADDED
@@ -0,0 +1,30 @@
+ # model configs
+ from configurations.fastai_configs import conf_fastai_inception1d
+ from experiments.scp_experiment import SCPExperiment
+ from utilities.utils import generate_ptbxl_summary_table
+
+
+ def main():
+     data_folder = 'data/ptbxl/'
+     output_folder = 'output/'
+
+     models = [conf_fastai_inception1d]
+
+     # STANDARD SCP EXPERIMENTS ON PTB-XL
+     experiments = [
+         ('exp1.1', 'subdiagnostic')
+     ]
+
+     for name, task in experiments:
+         e = SCPExperiment(name, task, data_folder, output_folder, models)
+         e.prepare()
+         e.perform()
+         e.evaluate()
+
+     # generate summary table
+     generate_ptbxl_summary_table()
+
+
+ if __name__ == "__main__":
+     main()
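For reference, the `experiments` list can be extended to sweep the other PTB-XL tasks from the README's AUC table. The experiment labels below are illustrative and the task strings are assumed to follow the label-aggregation naming used by `utilities.utils` (verify against `compute_label_aggregations` before running):

```python
# Hypothetical full sweep; each tuple is (experiment_name, task).
experiments = [
    ('exp0', 'all'),
    ('exp1', 'diagnostic'),
    ('exp1.1', 'subdiagnostic'),
    ('exp1.1.1', 'superdiagnostic'),
    ('exp2', 'form'),
    ('exp3', 'rhythm'),
]
```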
models/__pycache__/base_model.cpython-39.pyc ADDED
Binary file (760 Bytes).

models/__pycache__/basicconv1d.cpython-39.pyc ADDED
Binary file (9.01 kB).

models/__pycache__/fastaiModel.cpython-310.pyc ADDED
Binary file (13.6 kB).

models/__pycache__/fastaiModel.cpython-39.pyc ADDED
Binary file (14.3 kB).

models/__pycache__/inception1d.cpython-39.pyc ADDED
Binary file (5.52 kB).

models/__pycache__/resnet1d.cpython-39.pyc ADDED
Binary file (9.51 kB).

models/__pycache__/rnn1d.cpython-39.pyc ADDED
Binary file (2.99 kB).

models/__pycache__/wavelet.cpython-39.pyc ADDED
Binary file (5.4 kB).

models/__pycache__/xresnet1d.cpython-39.pyc ADDED
Binary file (11.2 kB).
 
models/base_model.py ADDED
@@ -0,0 +1,10 @@
+ class ClassificationModel(object):
+
+     def __init__(self):
+         pass
+
+     def fit(self, X_train, y_train, X_val, y_val):
+         pass
+
+     def predict(self, X, full_sequence=True):
+         pass
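Every model wrapper in `models/` implements this interface. As an illustration, a hypothetical minimal subclass (not part of the repo) reproducing the 'naive' baseline that `scp_experiment.prepare()` dumps, which predicts the training-label mean for every input:

```python
import numpy as np

class ClassificationModel(object):  # interface, as in models/base_model.py
    def fit(self, X_train, y_train, X_val, y_val):
        pass
    def predict(self, X, full_sequence=True):
        pass

class MeanPriorModel(ClassificationModel):
    """Predicts the per-class training prevalence for every input."""
    def fit(self, X_train, y_train, X_val, y_val):
        self.prior = np.asarray(y_train).mean(axis=0)
    def predict(self, X, full_sequence=True):
        return np.tile(self.prior, (len(X), 1))

m = MeanPriorModel()
m.fit(None, np.array([[1, 0], [1, 1]]), None, None)
print(m.predict(np.zeros((3, 5))))  # three copies of [1.0, 0.5]
```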
models/basicconv1d.py ADDED
@@ -0,0 +1,240 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from fastai.layers import *
6
+ from fastai.data.core import *
7
+ from typing import Optional, Collection, Union
8
+ from collections.abc import Iterable
9
+
10
+ '''
11
+ This layer creates a convolution kernel that is convolved with the layer input
12
+ over a single spatial (or temporal) dimension to produce a tensor of outputs.
13
+ If use_bias is True, a bias vector is created and added to the outputs.
14
+ Finally, if activation is not None, it is applied to the outputs as well.
15
+ https://keras.io/api/layers/convolution_layers/convolution1d/
16
+ '''
17
+ def listify(o):
18
+ if o is None: return []
19
+ if isinstance(o, list): return o
20
+ if isinstance(o, str): return [o]
21
+ if isinstance(o, Iterable): return list(o)
22
+ return [o]
23
+ import torch.nn as nn
24
+
25
+ def bn_drop_lin(ni, no, bn=True, p=0., actn=None):
26
+ layers = []
27
+ if bn: layers.append(nn.BatchNorm1d(ni))
28
+ if p != 0.: layers.append(nn.Dropout(p))
29
+ layers.append(nn.Linear(ni, no))
30
+ if actn is not None: layers.append(actn)
31
+ return layers
32
+
33
+ def _conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
34
+ lst = []
35
+ if (drop_p > 0):
36
+ lst.append(nn.Dropout(drop_p))
37
+ lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
38
+ dilation=dilation, bias=not bn))
39
+ if bn:
40
+ lst.append(nn.BatchNorm1d(out_planes))
41
+ if act == "relu":
42
+ lst.append(nn.ReLU(True))
43
+ if act == "elu":
44
+ lst.append(nn.ELU(True))
45
+ if act == "prelu":
46
+ lst.append(nn.PReLU(True))
47
+ return nn.Sequential(*lst)
+
+
+ def _fc(in_planes, out_planes, act="relu", bn=True):
+     lst = [nn.Linear(in_planes, out_planes, bias=not bn)]
+     if bn:
+         lst.append(nn.BatchNorm1d(out_planes))
+     if act == "relu":
+         lst.append(nn.ReLU(True))
+     if act == "elu":
+         lst.append(nn.ELU(True))
+     if act == "prelu":
+         lst.append(nn.PReLU())  # PReLU takes num_parameters, not an inplace flag
+     return nn.Sequential(*lst)
+
+
+ def cd_adaptive_concat_pool(relevant, irrelevant, module):
+     mpr, mpi = module.mp.attrib(relevant, irrelevant)
+     apr, api = module.ap.attrib(relevant, irrelevant)
+     return torch.cat([mpr, apr], 1), torch.cat([mpi, api], 1)
+
+
+ def attrib_adaptive_concat_pool(self, relevant, irrelevant):
+     return cd_adaptive_concat_pool(relevant, irrelevant, self)
+
+
+ class AdaptiveConcatPool1d(nn.Module):
+     """Layer that concatenates `AdaptiveMaxPool1d` and `AdaptiveAvgPool1d`."""
+
+     def __init__(self, sz: Optional[int] = None):
+         """Output will be 2*sz, or 2 if sz is None."""
+         super().__init__()
+         sz = sz or 1
+         self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
+
+     def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
+
+     def attrib(self, relevant, irrelevant):
+         return attrib_adaptive_concat_pool(self, relevant, irrelevant)
+
+
+ class SqueezeExcite1d(nn.Module):
+     """Squeeze-and-excite block as used, for example, in LSTM-FCN."""
+
+     def __init__(self, channels, reduction=16):
+         super().__init__()
+         channels_reduced = channels // reduction
+         self.w1 = torch.nn.Parameter(torch.randn(channels_reduced, channels).unsqueeze(0))
+         self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))
+
+     def forward(self, x):
+         # input is (bs, ch, seq)
+         z = torch.mean(x, dim=2, keepdim=True)  # (bs, ch, 1)
+         intermed = F.relu(torch.matmul(self.w1, z))  # (1, ch_red, ch) x (bs, ch, 1) = (bs, ch_red, 1)
+         s = torch.sigmoid(torch.matmul(self.w2, intermed))  # (1, ch, ch_red) x (bs, ch_red, 1) = (bs, ch, 1)
+         return s * x  # (bs, ch, 1) broadcasts over (bs, ch, seq)
+
+
+ def weight_init(m):
+     """call weight initialization for model n via n.apply(weight_init)"""
+     if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
+         nn.init.kaiming_normal_(m.weight)
+         if m.bias is not None:
+             nn.init.zeros_(m.bias)
+     if isinstance(m, nn.BatchNorm1d):
+         nn.init.constant_(m.weight, 1)
+         nn.init.constant_(m.bias, 0)
+     if isinstance(m, SqueezeExcite1d):
+         # size() is a method, not a subscriptable attribute
+         stdv1 = math.sqrt(2. / m.w1.size(0))
+         nn.init.normal_(m.w1, 0., stdv1)
+         stdv2 = math.sqrt(1. / m.w2.size(1))
+         nn.init.normal_(m.w2, 0., stdv2)
+
+
+ def create_head1d(nf: int, nc: int, lin_ftrs: Optional[Collection[int]] = None, ps: Union[float, Collection[float]] = 0.5,
+                   bn_final: bool = False, bn: bool = True, act="relu", concat_pooling=True):
+     """Model head that takes `nf` features, runs through `lin_ftrs`, and outputs `nc` classes; bn and act added here."""
+     lin_ftrs = [2 * nf if concat_pooling else nf, nc] if lin_ftrs is None else \
+         [2 * nf if concat_pooling else nf] + lin_ftrs + [nc]  # was [nf, 512, nc]
+     ps = listify(ps)
+     if len(ps) == 1: ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
+     actns = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
+     layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()]
+     for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
+         layers += bn_drop_lin(ni, no, bn, p, actn)
+     if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
+     return nn.Sequential(*layers)
+
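The first lines of `create_head1d` turn `nf`, `lin_ftrs`, and a scalar dropout `ps` into per-layer sizes and probabilities: concat pooling doubles the incoming features, and hidden layers get half the final dropout. That bookkeeping is pure Python and can be checked in isolation (hypothetical helper and inputs, no torch needed):

```python
def head_sizes(nf, nc, lin_ftrs=None, ps=0.5, concat_pooling=True):
    # mirrors the size/dropout bookkeeping at the top of create_head1d
    first = 2 * nf if concat_pooling else nf  # concat pooling doubles the features
    lin_ftrs = [first, nc] if lin_ftrs is None else [first] + lin_ftrs + [nc]
    ps = ps if isinstance(ps, list) else [ps]
    if len(ps) == 1:
        # hidden layers get half the final dropout probability
        ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
    return lin_ftrs, ps

# e.g. a 128-feature backbone, one hidden layer, 71 PTB-XL statements
print(head_sizes(128, 71, lin_ftrs=[128]))  # ([256, 128, 71], [0.25, 0.5])
```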
+
+ # basic convolutional architecture
+
+ class BasicConv1d(nn.Sequential):
+     """basic conv1d"""
+
+     def __init__(self, filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1,
+                  squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,
+                  split_first_layer=False, drop_p=0., lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
+                  act_head="relu", concat_pooling=True):
+         if filters is None:
+             filters = [128, 128, 128, 128]
+         layers = []
+         if isinstance(kernel_size, int):
+             kernel_size = [kernel_size] * len(filters)
+         for i in range(len(filters)):
+             layers_tmp = [_conv1d(input_channels if i == 0 else filters[i - 1], filters[i], kernel_size=kernel_size[i],
+                                   stride=(1 if (split_first_layer is True and i == 0) else stride), dilation=dilation,
+                                   act="none" if ((headless is True and i == len(filters) - 1) or (
+                                           split_first_layer is True and i == 0)) else act,
+                                   bn=False if (headless is True and i == len(filters) - 1) else bn,
+                                   drop_p=(0. if i == 0 else drop_p))]
+
+             if split_first_layer is True and i == 0:
+                 layers_tmp.append(_conv1d(filters[0], filters[0], kernel_size=1, stride=1, act=act, bn=bn, drop_p=0.))
+                 # layers_tmp.append(nn.Linear(filters[0], filters[0], bias=not bn))
+                 # layers_tmp.append(_fc(filters[0], filters[0], act=act, bn=bn))
+             if pool > 0 and i < len(filters) - 1:
+                 layers_tmp.append(nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2))
+             if squeeze_excite_reduction > 0:
+                 layers_tmp.append(SqueezeExcite1d(filters[i], squeeze_excite_reduction))
+             layers.append(nn.Sequential(*layers_tmp))
+
+         # headless variant: layers.append(nn.AdaptiveAvgPool1d(1)); layers.append(nn.Linear(filters[-1], num_classes))
+         # inplace=True leads to a runtime error for ReLU + dropout, see
+         # https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
+         self.headless = headless
+         if headless is True:
+             head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten())
+         else:
+             head = create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
+                                  bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
+         layers.append(head)
+
+         super().__init__(*layers)
+
+     def get_layer_groups(self):
+         return self[2], self[-1]
+
+     def get_output_layer(self):
+         if self.headless is False:
+             return self[-1][-1]
+         else:
+             return None
+
+     def set_output_layer(self, x):
+         if self.headless is False:
+             self[-1][-1] = x
+
+
+ # convenience functions for basic convolutional architectures
+ def fcn(filters=None, num_classes=2, input_channels=8):
+     if filters is None:
+         filters = [128] * 5
+     filters_in = filters + [num_classes]
+     return BasicConv1d(filters=filters_in, kernel_size=3, stride=1, pool=2, pool_stride=2,
+                        input_channels=input_channels, act="relu", bn=True, headless=True)
+
+
+ def fcn_wang(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
+              act_head="relu", concat_pooling=True):
+     return BasicConv1d(filters=[128, 256, 128], kernel_size=[8, 5, 3], stride=1, pool=0, pool_stride=2,
+                        num_classes=num_classes, input_channels=input_channels, act="relu", bn=True,
+                        lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head,
+                        act_head=act_head, concat_pooling=concat_pooling)
+
+
+ def schirrmeister(num_classes=2, input_channels=8, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True,
+                   act_head="relu", concat_pooling=True):
+     return BasicConv1d(filters=[25, 50, 100, 200], kernel_size=10, stride=3, pool=3, pool_stride=1,
+                        num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, headless=False,
+                        split_first_layer=True, drop_p=0.5, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head,
+                        bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
+
+
+ def sen(filters=None, num_classes=2, input_channels=8, squeeze_excite_reduction=16, drop_p=0., lin_ftrs_head=None,
+         ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
+     if filters is None:
+         filters = [128] * 5
+     return BasicConv1d(filters=filters, kernel_size=3, stride=2, pool=0, pool_stride=0, input_channels=input_channels,
+                        act="relu", bn=True, num_classes=num_classes, squeeze_excite_reduction=squeeze_excite_reduction,
+                        drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
+                        bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
+
+
+ def basic1d(filters=None, kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0,
+             num_classes=2, input_channels=8, act="relu", bn=True, headless=False, drop_p=0., lin_ftrs_head=None,
+             ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
+     if filters is None:
+         filters = [128] * 5
+     return BasicConv1d(filters=filters, kernel_size=kernel_size, stride=stride, dilation=dilation, pool=pool,
+                        pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction,
+                        num_classes=num_classes, input_channels=input_channels, act=act, bn=bn, headless=headless,
+                        drop_p=drop_p, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
+                        bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
models/fastaiModel.py ADDED
@@ -0,0 +1,513 @@
+ from fastai.data.core import *
+ from fastai.learner import *
+ from fastai.callback.schedule import *
+ from fastai.torch_core import *
+ from fastai.callback.tracker import SaveModelCallback
+ # from fastai.callback.gradient import GradientClipping
+ from pathlib import Path
+ from functools import partial
+ import math
+ # from fastai.callback import GradientClipping
+ import torch
+ from fastai.tabular.core import range_of
+ import numpy as np
+ import matplotlib
+ import matplotlib.pyplot as plt
+ from fastai.callback.core import Callback
+ from fastai.data.core import DataLoaders
+ import torch.nn.functional as F
+ # from fastai.metrics import add_metrics
+ import torch.nn as nn
+ from fastcore.utils import ifnone
+ import pandas as pd
+ from models.base_model import ClassificationModel
+ from models.basicconv1d import weight_init, fcn_wang, fcn, schirrmeister, sen, basic1d
+ from models.inception1d import inception1d
+ from models.resnet1d import resnet1d18, resnet1d34, resnet1d50, resnet1d101, resnet1d152, resnet1d_wang, \
+     wrn1d_22
+ from models.rnn1d import RNN1d
+ from utilities.timeseries_utils import TimeseriesDatasetCrops, ToTensor, aggregate_predictions
+ from models.xresnet1d import xresnet1d18_deeper, xresnet1d34_deeper, xresnet1d50_deeper, xresnet1d18_deep, \
+     xresnet1d34_deep, xresnet1d50_deep, xresnet1d18, xresnet1d34, xresnet1d101, xresnet1d50, xresnet1d152
+ from utilities.utils import evaluate_experiment
+
+
+ def add_metrics(last_metrics, new_metric):
+     """
+     Adds a new metric to the list of last metrics.
+
+     Args:
+         last_metrics (list): List of previous metrics.
+         new_metric (float or list): New metric(s) to add.
+
+     Returns:
+         list: Updated list of metrics.
+     """
+     if isinstance(new_metric, list):
+         return last_metrics + new_metric
+     else:
+         return last_metrics + [new_metric]
+
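Since `add_metrics` is plain Python, its contract is easy to pin down: scalars are appended, lists are concatenated, and the input list is never mutated. A standalone sketch:

```python
def add_metrics(last_metrics, new_metric):
    # scalars are appended; lists are concatenated; inputs are not mutated
    if isinstance(new_metric, list):
        return last_metrics + new_metric
    return last_metrics + [new_metric]

history = [0.71]
history2 = add_metrics(history, 0.74)            # scalar -> appended
history3 = add_metrics(history2, [0.75, 0.81])   # list -> concatenated
print(history, history2, history3)
```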
+ class MetricFunc(Callback):
+     """Obtains score using user-supplied function func (potentially ignoring targets with ignore_idx)"""
+
+     def __init__(self, func, name="MetricFunc", ignore_idx=None, one_hot_encode_target=True, argmax_pred=False,
+                  softmax_pred=True, flatten_target=True, sigmoid_pred=False, metric_component=None):
+         super().__init__()
+         self.metric_complete = None  # computed from the accumulated predictions in on_epoch_end
+         self.y_true = None
+         self.y_pred = None
+         self.func = func
+         self.ignore_idx = ignore_idx
+         self.one_hot_encode_target = one_hot_encode_target
+         self.argmax_pred = argmax_pred
+         self.softmax_pred = softmax_pred
+         self.flatten_target = flatten_target
+         self.sigmoid_pred = sigmoid_pred
+         self.metric_component = metric_component
+         self.name = name
+
+     def on_epoch_begin(self, **kwargs):
+         # reset the accumulation buffers at the start of every epoch
+         self.y_true = None
+         self.y_pred = None
+
+     def on_batch_end(self, last_output, last_target, **kwargs):
+         # flatten everything (to make it also work for annotation tasks)
+         y_pred_flat = last_output.view((-1, last_output.size()[-1]))
+
+         if self.flatten_target:
+             last_target = last_target.view(-1)
+         y_true_flat = last_target
+
+         # optionally take argmax of predictions
+         if self.argmax_pred is True:
+             y_pred_flat = y_pred_flat.argmax(dim=1)
+         elif self.softmax_pred is True:
+             y_pred_flat = F.softmax(y_pred_flat, dim=1)
+         elif self.sigmoid_pred is True:
+             y_pred_flat = torch.sigmoid(y_pred_flat)
+
+         # potentially remove ignore_idx entries
+         if self.ignore_idx is not None:
+             selected_indices = (y_true_flat != self.ignore_idx).nonzero().squeeze()
+             y_pred_flat = y_pred_flat[selected_indices]
+             y_true_flat = y_true_flat[selected_indices]
+
+         y_pred_flat = to_np(y_pred_flat)
+         y_true_flat = to_np(y_true_flat)
+
+         if self.one_hot_encode_target is True:
+             # one-hot encode the integer targets (numpy has no one_hot helper)
+             y_true_flat = np.eye(last_output.size()[-1], dtype=np.float32)[y_true_flat.astype(np.int64)]
+
+         if self.y_pred is None:
+             self.y_pred = y_pred_flat
+             self.y_true = y_true_flat
+         else:
+             self.y_pred = np.concatenate([self.y_pred, y_pred_flat], axis=0)
+             self.y_true = np.concatenate([self.y_true, y_true_flat], axis=0)
+
+     def on_epoch_end(self, last_metrics, **kwargs):
+         # compute the full metric (possibly multiple components) from the accumulated epoch predictions
+         self.metric_complete = self.func(self.y_true, self.y_pred)
+         if self.metric_component is not None:
+             return add_metrics(last_metrics, self.metric_complete[self.metric_component])
+         else:
+             return add_metrics(last_metrics, self.metric_complete)
+
+
+ def fmax_metric(targs, preds):
+     return evaluate_experiment(targs, preds)["Fmax"]
+
+
+ def auc_metric(targs, preds):
+     return evaluate_experiment(targs, preds)["macro_auc"]
+
+
+ def mse_flat(preds, targs):
+     return torch.mean(torch.pow(preds.view(-1) - targs.view(-1), 2))
+
+
+ def nll_regression(preds, targs):
+     # preds: bs, 2 (predicted mean and log-variance)
+     # targs: bs, 1
+     preds_mean = preds[:, 0]
+     # warning: output goes through an exponential map to ensure positivity
+     preds_var = torch.clamp(torch.exp(preds[:, 1]), 1e-4, 1e10)
+     # print(to_np(preds_mean)[0], to_np(targs)[0, 0], to_np(torch.sqrt(preds_var))[0])
+     return torch.mean(torch.log(2 * math.pi * preds_var) / 2) + torch.mean(
+         torch.pow(preds_mean - targs[:, 0], 2) / 2 / preds_var)
+
+
+ def nll_regression_init(m):
+     assert isinstance(m, nn.Linear)
+     nn.init.normal_(m.weight, 0., 0.001)
+     nn.init.constant_(m.bias, 4)
+
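`nll_regression` is the negative log-likelihood of a Gaussian whose mean and log-variance are predicted per sample, NLL = log(2*pi*var)/2 + (mean - y)^2/(2*var), averaged over the batch. The per-sample identity can be checked in pure Python (illustrative values only):

```python
import math

def gaussian_nll(mean, log_var, target):
    # per-sample Gaussian negative log-likelihood, as in nll_regression
    var = min(max(math.exp(log_var), 1e-4), 1e10)  # same clamp as the loss above
    return math.log(2 * math.pi * var) / 2 + (mean - target) ** 2 / (2 * var)

# a perfect mean prediction still pays the entropy term log(2*pi)/2 ~ 0.9189
print(gaussian_nll(0.0, 0.0, 0.0))
# a larger predicted variance discounts a given squared error
assert gaussian_nll(0.0, 2.0, 3.0) < gaussian_nll(0.0, 0.0, 3.0)
```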
+
+ def lr_find_plot(learner, path, filename="lr_find", n_skip=10, n_skip_end=2):
+     """
+     Saves the lr_find plot as a file (normally only shown as jupyter output):
+     loss against learning rate on a log-scale x-axis.
+     """
+     learner.lr_find()
+
+     backend_old = matplotlib.get_backend()
+     plt.switch_backend('agg')
+     plt.ylabel("loss")
+     plt.xlabel("learning rate (log scale)")
+     losses = [to_np(x) for x in learner.recorder.losses[n_skip:-(n_skip_end + 1)]]
+     # print(learner.recorder.val_losses)
+     # val_losses = [to_np(x) for x in learner.recorder.val_losses[n_skip:-(n_skip_end + 1)]]
+
+     plt.plot(learner.recorder.lrs[n_skip:-(n_skip_end + 1)], losses)
+     # plt.plot(learner.recorder.lrs[n_skip:-(n_skip_end + 1)], val_losses)
+
+     plt.xscale('log')
+     plt.savefig(str(path / (filename + '.png')))
+     plt.switch_backend(backend_old)
+
+
+ def losses_plot(learner, path, filename="losses", last: int = None):
+     """
+     Saves the training/validation loss curves as a file (normally only shown as
+     jupyter output), plotted against the number of batches processed.
+     """
+     backend_old = matplotlib.get_backend()
+     plt.switch_backend('agg')
+     plt.ylabel("loss")
+     plt.xlabel("Batches processed")
+
+     last = ifnone(last, len(learner.recorder.nb_batches))
+     l_b = np.sum(learner.recorder.nb_batches[-last:])
+     iterations = range_of(learner.recorder.losses)[-l_b:]
+     plt.plot(iterations, learner.recorder.losses[-l_b:], label='Train')
+     val_iter = learner.recorder.nb_batches[-last:]
+     val_iter = np.cumsum(val_iter) + np.sum(learner.recorder.nb_batches[:-last])
+     plt.plot(val_iter, learner.recorder.val_losses[-last:], label='Validation')
+     plt.legend()
+
+     plt.savefig(str(path / (filename + '.png')))
+     plt.switch_backend(backend_old)
+
+
+ class FastaiModel(ClassificationModel):
+     def __init__(self, name, n_classes, freq, output_folder, input_shape, pretrained=False, input_size=2.5,
+                  input_channels=12, chunkify_train=False, chunkify_valid=True, bs=128, ps_head=0.5, lin_ftrs_head=None,
+                  wd=1e-2, epochs=50, lr=1e-2, kernel_size=5, loss="binary_cross_entropy", pretrained_folder=None,
+                  n_classes_pretrained=None, gradual_unfreezing=True, discriminative_lrs=True, epochs_finetuning=30,
+                  early_stopping=None, aggregate_fn="max", concat_train_val=False):
+         super().__init__()
+
+         if lin_ftrs_head is None:
+             lin_ftrs_head = [128]
+         self.name = name
+         self.num_classes = n_classes if loss != "nll_regression" else 2
+         self.target_fs = freq
+         self.output_folder = Path(output_folder)
+
+         self.input_size = int(input_size * self.target_fs)
+         self.input_channels = input_channels
+
+         self.chunkify_train = chunkify_train
+         self.chunkify_valid = chunkify_valid
+
+         self.chunk_length_train = 2 * self.input_size  # target_fs * 6
+         self.chunk_length_valid = self.input_size
+
+         self.min_chunk_length = self.input_size  # chunk_length
+
+         self.stride_length_train = self.input_size  # chunk_length_train // 8
+         self.stride_length_valid = self.input_size // 2  # chunk_length_valid
+
+         self.copies_valid = 0  # >0 should only be used with chunkify_valid=False
+
+         self.bs = bs
+         self.ps_head = ps_head
+         self.lin_ftrs_head = lin_ftrs_head
+         self.wd = wd
+         self.epochs = epochs
+         self.lr = lr
+         self.kernel_size = kernel_size
+         self.loss = loss
+         self.input_shape = input_shape
+
+         if pretrained:
+             if pretrained_folder is None:
+                 pretrained_folder = Path('../output/exp0/models/' + name.split("_pretrained")[0] + '/')
+                 # pretrained_folder = Path('/output/exp0/models/' + name.split("_pretrained")[0] + '/')
+
+             if n_classes_pretrained is None:
+                 n_classes_pretrained = 71
+
+         self.pretrained_folder = None if pretrained_folder is None else Path(pretrained_folder)
+         self.n_classes_pretrained = n_classes_pretrained
+         self.discriminative_lrs = discriminative_lrs
+         self.gradual_unfreezing = gradual_unfreezing
+         self.epochs_finetuning = epochs_finetuning
+
+         self.early_stopping = early_stopping
+         self.aggregate_fn = aggregate_fn
+         self.concat_train_val = concat_train_val
+
+     def fit(self, X_train, y_train, X_val, y_val):
+         # convert everything to float32
+         X_train = [l.astype(np.float32) for l in X_train]
+         X_val = [l.astype(np.float32) for l in X_val]
+         y_train = [l.astype(np.float32) for l in y_train]
+         y_val = [l.astype(np.float32) for l in y_val]
+
+         if self.concat_train_val:
+             X_train += X_val
+             y_train += y_val
+
+         if self.pretrained_folder is None:  # from scratch
+             print("Training from scratch...")
+             learn = self._get_learner(X_train, y_train, X_val, y_val)
+
+             # if self.discriminative_lrs:
+             #     layer_groups = learn.model.get_layer_groups()
+             #     learn.split(layer_groups)
+             learn.model.apply(weight_init)
+
+             # initialization for regression output
+             if self.loss == "nll_regression" or self.loss == "mse":
+                 output_layer_new = learn.model.get_output_layer()
+                 output_layer_new.apply(nll_regression_init)
+                 learn.model.set_output_layer(output_layer_new)
+
+             lr_find_plot(learn, self.output_folder)
+             learn.fit_one_cycle(self.epochs, self.lr)  # slice(self.lr) if self.discriminative_lrs else self.lr
+             losses_plot(learn, self.output_folder)
+         else:  # finetuning
+             print("Finetuning...")
+             # create learner
+             learn = self._get_learner(X_train, y_train, X_val, y_val, self.n_classes_pretrained)
+
+             # load pretrained model
+             learn.path = self.pretrained_folder
+             learn.load(self.pretrained_folder.stem)
+             learn.path = self.output_folder
+
+             # exchange top layer
+             output_layer = learn.model.get_output_layer()
+             output_layer_new = nn.Linear(output_layer.in_features, self.num_classes).cuda()
+             apply_init(output_layer_new, nn.init.kaiming_normal_)
+             learn.model.set_output_layer(output_layer_new)
+
+             # layer groups
+             if self.discriminative_lrs:
+                 layer_groups = learn.model.get_layer_groups()
+                 learn.split(layer_groups)
+
+             learn.train_bn = True  # make sure bn layers stay in train mode
+
+             # train
+             lr = self.lr
+             if self.gradual_unfreezing:
+                 assert self.discriminative_lrs is True
+                 learn.freeze()
+                 lr_find_plot(learn, self.output_folder, "lr_find0")
+                 learn.fit_one_cycle(self.epochs_finetuning, lr)
+                 losses_plot(learn, self.output_folder, "losses0")
+                 # for n in range(len(layer_groups)):
+                 #     learn.freeze_to(-n - 1)
+                 #     lr_find_plot(learn, self.output_folder, "lr_find" + str(n))
+                 #     learn.fit_one_cycle(self.epochs_gradual_unfreezing, slice(lr))
+                 #     losses_plot(learn, self.output_folder, "losses" + str(n))
+                 #     if n == 0:  # reduce lr after first step
+                 #         lr /= 10.
+                 #     if n > 0 and (self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru")):
+                 #         lr /= 10  # reduce lr further for RNNs
+
+             learn.unfreeze()
+             lr_find_plot(learn, self.output_folder, "lr_find" + str(len(layer_groups)))
+             learn.fit_one_cycle(self.epochs_finetuning, slice(lr / 1000, lr / 10))
+             losses_plot(learn, self.output_folder, "losses" + str(len(layer_groups)))
+
+         learn.save(self.name)  # even with early stopping the best model will have been loaded again
+
+     def predict(self, X):
+         X = [l.astype(np.float32) for l in X]
+         y_dummy = [np.ones(self.num_classes, dtype=np.float32) for _ in range(len(X))]
+
+         learn = self._get_learner(X, y_dummy, X, y_dummy)
+         learn.load(self.name)
+
+         preds, targs = learn.get_preds()
+         preds = to_np(preds)
+
+         idmap = learn.data.valid_ds.get_id_mapping()
+
+         return aggregate_predictions(preds, idmap=idmap,
+                                      aggregate_fn=np.mean if self.aggregate_fn == "mean" else np.amax)
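Because the dataset is chunked into overlapping crops, `predict` produces one score per crop; `aggregate_predictions` then reduces the crop-level scores of each record with `max` or `mean` using `idmap` (crop index to record index). A pure-Python sketch of that reduction, with hypothetical scores (the real helper lives in `utilities.timeseries_utils`):

```python
def aggregate_by_id(preds, idmap, reduce=max):
    # group crop-level scores by record id, then reduce each group
    groups = {}
    for score, rec_id in zip(preds, idmap):
        groups.setdefault(rec_id, []).append(score)
    return [reduce(groups[rec_id]) for rec_id in sorted(groups)]

crop_scores = [0.25, 0.75, 0.5, 0.0]  # four crops...
idmap = [0, 0, 1, 1]                  # ...from two records
print(aggregate_by_id(crop_scores, idmap))  # max per record: [0.75, 0.5]
print(aggregate_by_id(crop_scores, idmap,
                      reduce=lambda g: sum(g) / len(g)))  # mean per record: [0.5, 0.25]
```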
+
+     def _get_learner(self, X_train, y_train, X_val, y_val, num_classes=None):
+         df_train = pd.DataFrame({"data": range(len(X_train)), "label": y_train})
+         df_valid = pd.DataFrame({"data": range(len(X_val)), "label": y_val})
+
+         tfms_ptb_xl = [ToTensor()]
+
+         ds_train = TimeseriesDatasetCrops(df_train, self.input_size, num_classes=self.num_classes,
+                                           chunk_length=self.chunk_length_train if self.chunkify_train else 0,
+                                           min_chunk_length=self.min_chunk_length,
+                                           stride=self.stride_length_train, transforms=tfms_ptb_xl,
+                                           annotation=False, col_lbl="label", npy_data=X_train)
+         ds_valid = TimeseriesDatasetCrops(df_valid, self.input_size, num_classes=self.num_classes,
+                                           chunk_length=self.chunk_length_valid if self.chunkify_valid else 0,
+                                           min_chunk_length=self.min_chunk_length,
+                                           stride=self.stride_length_valid, transforms=tfms_ptb_xl,
+                                           annotation=False, col_lbl="label", npy_data=X_val)
+
+         db = DataLoaders(ds_train, ds_valid)
+
+         if self.loss == "binary_cross_entropy":
+             loss = F.binary_cross_entropy_with_logits
+         elif self.loss == "cross_entropy":
+             loss = F.cross_entropy
+         elif self.loss == "mse":
+             loss = mse_flat
+         elif self.loss == "nll_regression":
+             loss = nll_regression
+         else:
+             raise ValueError("loss not found")
+
+         self.input_channels = self.input_shape[-1]
+         metrics = []
+
+         print("model:", self.name)
+         # note: all models of a particular kind share the same prefix but potentially a different
+         # postfix such as _input256
+         num_classes = self.num_classes if num_classes is None else num_classes
+         # resnet: resnet1d18, resnet1d34, resnet1d50, resnet1d101, resnet1d152, resnet1d_wang, wrn1d_22
+         if self.name.startswith("fastai_resnet1d18"):
+             model = resnet1d18(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
+                                kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_resnet1d34"):
+             model = resnet1d34(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
+                                kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_resnet1d50"):
+             model = resnet1d50(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
+                                kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_resnet1d101"):
+             model = resnet1d101(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
+                                 kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                 lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_resnet1d152"):
+             model = resnet1d152(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
+                                 kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                 lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_resnet1d_wang"):
+             model = resnet1d_wang(num_classes=num_classes, input_channels=self.input_channels,
+                                   kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                   lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_wrn1d_22"):
+             model = wrn1d_22(num_classes=num_classes, input_channels=self.input_channels,
+                              kernel_size=self.kernel_size, ps_head=self.ps_head,
+                              lin_ftrs_head=self.lin_ftrs_head)
+
+         # xresnet ... (order important for string capture)
+         elif self.name.startswith("fastai_xresnet1d18_deeper"):
+             model = xresnet1d18_deeper(num_classes=num_classes, input_channels=self.input_channels,
+                                        kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                        lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d34_deeper"):
+             model = xresnet1d34_deeper(num_classes=num_classes, input_channels=self.input_channels,
+                                        kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                        lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d50_deeper"):
+             model = xresnet1d50_deeper(num_classes=num_classes, input_channels=self.input_channels,
+                                        kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                        lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d18_deep"):
+             model = xresnet1d18_deep(num_classes=num_classes, input_channels=self.input_channels,
+                                      kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                      lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d34_deep"):
+             model = xresnet1d34_deep(num_classes=num_classes, input_channels=self.input_channels,
+                                      kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                      lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d50_deep"):
+             model = xresnet1d50_deep(num_classes=num_classes, input_channels=self.input_channels,
+                                      kernel_size=self.kernel_size, ps_head=self.ps_head,
+                                      lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d18"):
+             model = xresnet1d18(num_classes=num_classes, input_channels=self.input_channels,
+                                 kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d34"):
+             model = xresnet1d34(num_classes=num_classes, input_channels=self.input_channels,
+                                 kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d50"):
+             model = xresnet1d50(num_classes=num_classes, input_channels=self.input_channels,
+                                 kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d101"):
+             model = xresnet1d101(num_classes=num_classes, input_channels=self.input_channels,
+                                  kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_xresnet1d152"):
+             model = xresnet1d152(num_classes=num_classes, input_channels=self.input_channels,
+                                  kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+
+         # inception: passing the default kernel_size of 5 leads to a maximum kernel size of 40-1
+         # in the inception model, as proposed in the original paper
+         elif self.name == "fastai_inception1d_no_residual":  # note: order important for string capture
+             model = inception1d(num_classes=num_classes, input_channels=self.input_channels,
+                                 use_residual=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
+                                 kernel_size=8 * self.kernel_size)
+         elif self.name.startswith("fastai_inception1d"):
+             model = inception1d(num_classes=num_classes, input_channels=self.input_channels,
+                                 use_residual=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
+                                 kernel_size=8 * self.kernel_size)
+
+
+         # BasicConv1d: fcn, fcn_wang, schirrmeister, sen, basic1d
+         elif self.name.startswith("fastai_fcn_wang"):  # note: order important for string capture
+             model = fcn_wang(num_classes=num_classes, input_channels=self.input_channels,
+                              ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_fcn"):
+             model = fcn(num_classes=num_classes, input_channels=self.input_channels)
+         elif self.name.startswith("fastai_schirrmeister"):
+             model = schirrmeister(num_classes=num_classes, input_channels=self.input_channels,
+                                   ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_sen"):
+             model = sen(num_classes=num_classes, input_channels=self.input_channels, ps_head=self.ps_head,
+                         lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_basic1d"):
+             model = basic1d(num_classes=num_classes, input_channels=self.input_channels,
+                             kernel_size=self.kernel_size, ps_head=self.ps_head,
+                             lin_ftrs_head=self.lin_ftrs_head)
+         # RNN
+         elif self.name.startswith("fastai_lstm_bidir"):
+             model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
+                           bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_gru_bidir"):
+             model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
+                           bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_lstm"):
+             model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
+                           bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         elif self.name.startswith("fastai_gru"):
+             model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
+                           bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
+         else:
+             raise ValueError("Model not found.")
489
+
490
+ learn = Learner(db, model, loss_func=loss, metrics=metrics, wd=self.wd, path=self.output_folder)
491
+
492
+ if self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru"):
493
+ learn.callback_fns.append(partial(GradientClipping, clip=0.25))
494
+
495
+ if self.early_stopping is not None:
496
+ # supported options: valid_loss, macro_auc, fmax
497
+ if self.early_stopping == "macro_auc" and self.loss != "mse" and self.loss != "nll_regression":
498
+ metric = MetricFunc(auc_metric, self.early_stopping,
499
+ one_hot_encode_target=False, argmax_pred=False, softmax_pred=False,
500
+ sigmoid_pred=True, flatten_target=False)
501
+ learn.metrics.append(metric)
502
+ learn.callback_fns.append(
503
+ partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
504
+ elif self.early_stopping == "fmax" and self.loss != "mse" and self.loss != "nll_regression":
505
+ metric = MetricFunc(fmax_metric, self.early_stopping,
506
+ one_hot_encode_target=False, argmax_pred=False, softmax_pred=False,
507
+ sigmoid_pred=True, flatten_target=False)
508
+ learn.metrics.append(metric)
509
+ learn.callback_fns.append(partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
510
+ elif self.early_stopping == "valid_loss":
511
+ learn.callback_fns.append(partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
512
+
513
+ return learn
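The `# note: order important for string capture` comments above matter because the dispatch uses `str.startswith`: a shorter prefix such as `"fastai_inception1d"` also matches the longer `"fastai_inception1d_no_residual"`, so the more specific name must be tested first. A minimal standalone sketch of that ordering (the `pick_model` helper is hypothetical, not part of the repository):

```python
def pick_model(name):
    # More specific prefixes must be checked before their substrings,
    # mirroring the elif ordering in the fastai model factory.
    if name == "fastai_inception1d_no_residual":
        return "inception1d(use_residual=False)"
    elif name.startswith("fastai_inception1d"):
        return "inception1d(use_residual=True)"
    elif name.startswith("fastai_fcn_wang"):
        return "fcn_wang"
    elif name.startswith("fastai_fcn"):
        return "fcn"
    return None

print(pick_model("fastai_inception1d_no_residual"))  # inception1d(use_residual=False)
print(pick_model("fastai_fcn_wang"))                 # fcn_wang
```

Swapping the `fcn` and `fcn_wang` branches would silently route every `fastai_fcn_wang` run to the plain `fcn` model.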
models/inception1d.py ADDED
@@ -0,0 +1,137 @@
+ import torch
+ import torch.nn as nn
+ from fastai.data.core import *
+
+ # InceptionTime, inspired by https://github.com/hfawaz/InceptionTime/blob/master/classifiers/inception.py
+ # and https://github.com/tcapelle/TimeSeries_fastai/blob/master/inception.py
+ from models.basicconv1d import create_head1d
+
+
+ def conv(in_planes, out_planes, kernel_size=3, stride=1):
+     """convolution with padding"""
+     return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                      padding=(kernel_size - 1) // 2, bias=False)
+
+
+ def noop(x): return x
+
+
+ class InceptionBlock1d(nn.Module):
+     def __init__(self, ni, nb_filters, kss, stride=1, act='linear', bottleneck_size=32):
+         super().__init__()
+         self.bottleneck = conv(ni, bottleneck_size, 1, stride) if (bottleneck_size > 0) else noop
+
+         self.convs = nn.ModuleList(
+             [conv(bottleneck_size if (bottleneck_size > 0) else ni, nb_filters, ks) for ks in kss])
+         self.conv_bottle = nn.Sequential(nn.MaxPool1d(3, stride, padding=1), conv(ni, nb_filters, 1))
+         self.bn_relu = nn.Sequential(nn.BatchNorm1d((len(kss) + 1) * nb_filters), nn.ReLU())
+
+     def forward(self, x):
+         bottled = self.bottleneck(x)
+         conv_outputs = [c(bottled) for c in self.convs]
+         bottle_output = self.conv_bottle(x)
+         out = self.bn_relu(torch.cat(conv_outputs + [bottle_output], dim=1))
+         return out
+
+
+ class Shortcut1d(nn.Module):
+     def __init__(self, ni, nf):
+         super().__init__()
+         self.act_fn = nn.ReLU(True)
+         self.conv = conv(ni, nf, 1)
+         self.bn = nn.BatchNorm1d(nf)
+
+     def forward(self, inp, out):
+         return self.act_fn(out + self.bn(self.conv(inp)))
+
+
+ class InceptionBackbone(nn.Module):
+     def __init__(self, input_channels, kss, depth, bottleneck_size, nb_filters, use_residual):
+         super().__init__()
+
+         self.depth = depth
+         assert (depth % 3) == 0
+         self.use_residual = use_residual
+
+         n_ks = len(kss) + 1
+         self.im = nn.ModuleList([InceptionBlock1d(input_channels if d == 0 else n_ks * nb_filters,
+                                                   nb_filters=nb_filters, kss=kss,
+                                                   bottleneck_size=bottleneck_size) for d in range(depth)])
+         self.sk = nn.ModuleList(
+             [Shortcut1d(input_channels if d == 0 else n_ks * nb_filters, n_ks * nb_filters)
+              for d in range(depth // 3)])
+
+     def forward(self, x):
+         input_res = x
+         for d in range(self.depth):
+             x = self.im[d](x)
+             if self.use_residual and d % 3 == 2:
+                 x = (self.sk[d // 3])(input_res, x)
+                 input_res = x.clone()
+         return x
+
+
+ class Inception1d(nn.Module):
+     """InceptionTime architecture"""
+
+     def __init__(self, num_classes=2, input_channels=8, kernel_size=40, depth=6, bottleneck_size=32,
+                  nb_filters=32, use_residual=True, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False,
+                  bn_head=True, act_head="relu", concat_pooling=True):
+         super().__init__()
+         assert kernel_size >= 40
+         kernel_size = [k - 1 if k % 2 == 0 else k for k in
+                        [kernel_size, kernel_size // 2, kernel_size // 4]]  # was 39, 19, 9
+
+         layers = [InceptionBackbone(input_channels=input_channels, kss=kernel_size, depth=depth,
+                                     bottleneck_size=bottleneck_size, nb_filters=nb_filters,
+                                     use_residual=use_residual)]
+
+         n_ks = len(kernel_size) + 1
+         # head
+         head = create_head1d(n_ks * nb_filters, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
+                              bn_final=bn_final_head, bn=bn_head, act=act_head,
+                              concat_pooling=concat_pooling)
+         layers.append(head)
+         self.layers = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.layers(x)
+
+     def get_layer_groups(self):
+         depth = self.layers[0].depth
+         if depth > 3:
+             return (self.layers[0].im[3:], self.layers[0].sk[1:]), self.layers[-1]
+         else:
+             return self.layers[-1]
+
+     def get_output_layer(self):
+         return self.layers[-1][-1]
+
+     def set_output_layer(self, x):
+         self.layers[-1][-1] = x
+
+
+ def inception1d(**kwargs):
+     """Constructs an InceptionTime model."""
+     return Inception1d(**kwargs)
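`Inception1d` derives its three branch kernel sizes from the single `kernel_size` argument (40 by default, matching the `8 * self.kernel_size` passed in by the fastai model factory), forcing each to be odd so that `(k - 1) // 2` padding preserves the sequence length. A standalone sketch of just that computation:

```python
def inception_kernel_sizes(kernel_size=40):
    # Mirrors the kss computation in Inception1d.__init__: three branch
    # kernel sizes (full, half, quarter), each decremented to the nearest
    # odd value so 'same' padding stays symmetric.
    assert kernel_size >= 40
    return [k - 1 if k % 2 == 0 else k for k in
            [kernel_size, kernel_size // 2, kernel_size // 4]]

print(inception_kernel_sizes(40))  # [39, 19, 9]
```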
models/resnet1d.py ADDED
@@ -0,0 +1,299 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # Standard resnet
+ from models.basicconv1d import create_head1d
+
+
+ def conv(in_planes, out_planes, stride=1, kernel_size=3):
+     """convolution with padding"""
+     return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                      padding=(kernel_size - 1) // 2, bias=False)
+
+
+ class BasicBlock1d(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, kernel_size=None, down_sample=None):
+         super().__init__()
+         if kernel_size is None:
+             kernel_size = [3, 3]
+         if isinstance(kernel_size, int):
+             kernel_size = [kernel_size, kernel_size // 2 + 1]
+
+         self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=kernel_size[0])
+         self.bn1 = nn.BatchNorm1d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv(planes, planes, kernel_size=kernel_size[1])
+         self.bn2 = nn.BatchNorm1d(planes)
+         self.down_sample = down_sample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.down_sample is not None:
+             residual = self.down_sample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck1d(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, kernel_size=3, down_sample=None):
+         super().__init__()
+
+         self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm1d(planes)
+         self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride,
+                                padding=(kernel_size - 1) // 2, bias=False)
+         self.bn2 = nn.BatchNorm1d(planes)
+         self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm1d(planes * 4)
+         self.relu = nn.ReLU(inplace=True)
+         self.down_sample = down_sample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.down_sample is not None:
+             residual = self.down_sample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet1d(nn.Sequential):
+     """1d adaptation of the torchvision resnet"""
+
+     def __init__(self, block, layers, kernel_size=3, num_classes=2, input_channels=3, inplanes=64,
+                  fix_feature_dim=True, kernel_size_stem=None, stride_stem=2, pooling_stem=True, stride=2,
+                  lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu",
+                  concat_pooling=True):
+         self.inplanes = inplanes
+
+         layers_tmp = []
+
+         if kernel_size_stem is None:
+             kernel_size_stem = kernel_size[0] if isinstance(kernel_size, list) else kernel_size
+         # stem
+         layers_tmp.append(nn.Conv1d(input_channels, inplanes, kernel_size=kernel_size_stem,
+                                     stride=stride_stem, padding=(kernel_size_stem - 1) // 2, bias=False))
+         layers_tmp.append(nn.BatchNorm1d(inplanes))
+         layers_tmp.append(nn.ReLU(inplace=True))
+         if pooling_stem is True:
+             layers_tmp.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))
+         # backbone
+         for i in range(len(layers)):
+             if i == 0:
+                 layers_tmp.append(self._make_layer(block, inplanes, layers[0], kernel_size=kernel_size))
+             else:
+                 layers_tmp.append(
+                     self._make_layer(block, inplanes if fix_feature_dim else (2 ** i) * inplanes, layers[i],
+                                      stride=stride, kernel_size=kernel_size))
+
+         # head
+         head = create_head1d(
+             (inplanes if fix_feature_dim else (2 ** len(layers) * inplanes)) * block.expansion,
+             nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head,
+             act=act_head, concat_pooling=concat_pooling)
+         layers_tmp.append(head)
+
+         super().__init__(*layers_tmp)
+
+     def _make_layer(self, block, planes, blocks, stride=1, kernel_size=3):
+         down_sample = None
+
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             down_sample = nn.Sequential(
+                 nn.Conv1d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm1d(planes * block.expansion),
+             )
+
+         layers = [block(self.inplanes, planes, stride, kernel_size, down_sample)]
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def get_layer_groups(self):
+         return self[6], self[-1]
+
+     def get_output_layer(self):
+         return self[-1][-1]
+
+     def set_output_layer(self, x):
+         self[-1][-1] = x
+
+
+ def resnet1d18(**kwargs):
+     """Constructs a ResNet-18 model."""
+     return ResNet1d(BasicBlock1d, [2, 2, 2, 2], **kwargs)
+
+
+ def resnet1d34(**kwargs):
+     """Constructs a ResNet-34 model."""
+     return ResNet1d(BasicBlock1d, [3, 4, 6, 3], **kwargs)
+
+
+ def resnet1d50(**kwargs):
+     """Constructs a ResNet-50 model."""
+     return ResNet1d(Bottleneck1d, [3, 4, 6, 3], **kwargs)
+
+
+ def resnet1d101(**kwargs):
+     """Constructs a ResNet-101 model."""
+     return ResNet1d(Bottleneck1d, [3, 4, 23, 3], **kwargs)
+
+
+ def resnet1d152(**kwargs):
+     """Constructs a ResNet-152 model."""
+     return ResNet1d(Bottleneck1d, [3, 8, 36, 3], **kwargs)
+
+
+ # original used kernel_size_stem = 8
+ def resnet1d_wang(**kwargs):
+     kwargs.setdefault("kernel_size", [5, 3])
+     kwargs.setdefault("kernel_size_stem", 7)
+     kwargs.setdefault("stride_stem", 1)
+     kwargs.setdefault("pooling_stem", False)
+     kwargs.setdefault("inplanes", 128)
+     return ResNet1d(BasicBlock1d, [1, 1, 1], **kwargs)
+
+
+ def resnet1d(**kwargs):
+     """Constructs a custom ResNet model."""
+     return ResNet1d(BasicBlock1d, **kwargs)
+
+
+ # wide resnet adopted from fastai wrn
+
+ def noop(x): return x
+
+
+ def conv1d(ni: int, nf: int, ks: int = 3, stride: int = 1, padding: int = None, bias=False) -> nn.Conv1d:
+     """Create `nn.Conv1d` layer: `ni` inputs, `nf` outputs, `ks` kernel size. `padding` defaults to `ks//2`."""
+     if padding is None:
+         padding = ks // 2
+     return nn.Conv1d(ni, nf, kernel_size=ks, stride=stride, padding=padding, bias=bias)
+
+
+ def _bn1d(ni, init_zero=False):
+     """Batchnorm layer with optional 0 initialization"""
+     m = nn.BatchNorm1d(ni)
+     m.weight.data.fill_(0 if init_zero else 1)
+     m.bias.data.zero_()
+     return m
+
+
+ def bn_relu_conv1d(ni, nf, ks, stride, init_zero=False):
+     bn_initzero = _bn1d(ni, init_zero=init_zero)
+     return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv1d(ni, nf, ks, stride))
+
+
+ class BasicBlock1dwrn(nn.Module):
+     def __init__(self, ni, nf, stride, drop_p=0.0, ks=3):
+         super().__init__()
+         if isinstance(ks, int):
+             ks = [ks, ks // 2 + 1]
+         self.bn = nn.BatchNorm1d(ni)
+         self.conv1 = conv1d(ni, nf, ks[0], stride)
+         self.conv2 = bn_relu_conv1d(nf, nf, ks[0], 1)
+         self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
+         # adapted to make it work for fix_feature_dim=True
+         self.shortcut = conv1d(ni, nf, ks[1], stride) if (ni != nf or stride > 1) else noop
+
+     def forward(self, x):
+         x2 = F.relu(self.bn(x), inplace=True)
+         r = self.shortcut(x2)
+         x = self.conv1(x2)
+         if self.drop:
+             x = self.drop(x)
+         x = self.conv2(x) * 0.2
+         return x.add_(r)
+
+
+ def _make_group(N, ni, nf, block, stride, drop_p, ks=3):
+     return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p, ks=ks) for i in range(N)]
+
+
+ class WideResNet1d(nn.Sequential):
+     def __init__(self, input_channels: int, num_groups: int, N: int, num_classes: int, k: int = 1,
+                  drop_p: float = 0.0, start_nf: int = 16, fix_feature_dim=True, kernel_size=5,
+                  lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu",
+                  concat_pooling=True):
+         n_channels = [start_nf]
+
+         for i in range(num_groups):
+             n_channels.append(start_nf if fix_feature_dim else start_nf * (2 ** i) * k)
+
+         layers = [conv1d(input_channels, n_channels[0], 3, 1)]  # conv1 stem
+         for i in range(num_groups):
+             layers += _make_group(N, n_channels[i], n_channels[i + 1], BasicBlock1dwrn,
+                                   (1 if i == 0 else 2), drop_p, ks=kernel_size)
+
+         head = create_head1d(n_channels[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
+                              bn_final=bn_final_head, bn=bn_head, act=act_head,
+                              concat_pooling=concat_pooling)
+         layers.append(head)
+
+         super().__init__(*layers)
+
+     def get_layer_groups(self):
+         return self[6], self[-1]
+
+     def get_output_layer(self):
+         return self[-1][-1]
+
+     def set_output_layer(self, x):
+         self[-1][-1] = x
+
+
+ def wrn1d_22(**kwargs): return WideResNet1d(num_groups=3, N=3, k=6, drop_p=0., **kwargs)
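The backbone loop in `ResNet1d.__init__` either keeps the channel count constant (`fix_feature_dim=True`, as in `resnet1d_wang`) or doubles it per stage as in the torchvision ResNet. A standalone sketch of that width schedule (the `backbone_widths` helper is illustrative, not part of the repository):

```python
def backbone_widths(inplanes, num_stages, fix_feature_dim=True):
    # Output channels per backbone stage, as wired in ResNet1d.__init__:
    # constant when fix_feature_dim is set, doubling per stage otherwise.
    return [inplanes if fix_feature_dim else (2 ** i) * inplanes
            for i in range(num_stages)]

print(backbone_widths(64, 4, fix_feature_dim=False))  # [64, 128, 256, 512]
print(backbone_widths(128, 3, fix_feature_dim=True))  # [128, 128, 128]
```

The second call reproduces the `resnet1d_wang` configuration (three stages, 128 channels throughout).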
models/rnn1d.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ import torch.nn as nn
+ from fastai.layers import *
+ from fastai.data.core import *
+
+
+ class AdaptiveConcatPoolRNN(nn.Module):
+     def __init__(self, bidirectional):
+         super().__init__()
+         self.bidirectional = bidirectional
+
+     def forward(self, x):
+         # input shape bs, ch, ts
+         t1 = nn.AdaptiveAvgPool1d(1)(x)
+         t2 = nn.AdaptiveMaxPool1d(1)(x)
+
+         if self.bidirectional is False:
+             t3 = x[:, :, -1]
+         else:
+             # forward half of the hidden state at the last timestep,
+             # backward half at the first timestep
+             channels = x.size()[1]
+             t3 = torch.cat([x[:, :channels // 2, -1], x[:, channels // 2:, 0]], 1)
+         out = torch.cat([t1.squeeze(-1), t2.squeeze(-1), t3], 1)  # output shape bs, 3*ch
+         return out
+
+
+ class RNN1d(nn.Sequential):
+     def __init__(self, input_channels, num_classes, lstm=True, hidden_dim=256, num_layers=2,
+                  bidirectional=False, ps_head=0.5, act_head="relu", lin_ftrs_head=None, bn=True):
+         # bs, ch, ts -> ts, bs, ch
+         layers_tmp = [Lambda(lambda x: x.transpose(1, 2)), Lambda(lambda x: x.transpose(0, 1))]
+         # LSTM or GRU
+         if lstm:
+             layers_tmp.append(nn.LSTM(input_size=input_channels, hidden_size=hidden_dim,
+                                       num_layers=num_layers, bidirectional=bidirectional))
+         else:
+             layers_tmp.append(nn.GRU(input_size=input_channels, hidden_size=hidden_dim,
+                                      num_layers=num_layers, bidirectional=bidirectional))
+         # pooling
+         layers_tmp.append(Lambda(lambda x: x[0].transpose(0, 1)))
+         layers_tmp.append(Lambda(lambda x: x.transpose(1, 2)))
+
+         layers_head = [AdaptiveConcatPoolRNN(bidirectional)]
+
+         # classifier
+         nf = 3 * hidden_dim if bidirectional is False else 6 * hidden_dim
+         lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
+         ps_head = listify(ps_head)
+         if len(ps_head) == 1:
+             ps_head = [ps_head[0] / 2] * (len(lin_ftrs_head) - 2) + ps_head
+         actns = [nn.ReLU(inplace=True) if act_head == "relu" else nn.ELU(inplace=True)] * (
+                 len(lin_ftrs_head) - 2) + [None]
+
+         for ni, no, p, actn in zip(lin_ftrs_head[:-1], lin_ftrs_head[1:], ps_head, actns):
+             layers_head += bn_drop_lin(ni, no, bn, p, actn)
+         layers_tmp.append(nn.Sequential(*layers_head))
+
+         super().__init__(*layers_tmp)
+
+     def get_layer_groups(self):
+         return (self[-1],)
+
+     def get_output_layer(self):
+         return self[-1][-1]
+
+     def set_output_layer(self, x):
+         self[-1][-1] = x
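The head sizing in `RNN1d` follows from the pooling: `AdaptiveConcatPoolRNN` concatenates average pool, max pool, and the final hidden state, tripling the channel count (6x the hidden size for bidirectional RNNs), and a single dropout value is expanded to a halved value for every intermediate linear layer. A standalone sketch of that arithmetic (the `rnn_head_dims` helper is illustrative, not part of the repository):

```python
def rnn_head_dims(hidden_dim=256, num_classes=71, bidirectional=False,
                  lin_ftrs_head=None, ps_head=0.5):
    # Mirrors RNN1d's head sizing: concat-pooling triples the width
    # (avg + max + last hidden state), doubled again for bidirectional.
    nf = 3 * hidden_dim if not bidirectional else 6 * hidden_dim
    lin_ftrs = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
    ps = ps_head if isinstance(ps_head, list) else [ps_head]
    if len(ps) == 1:
        # halved dropout on intermediate layers, full dropout on the last
        ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
    return lin_ftrs, ps

print(rnn_head_dims())                                         # ([768, 71], [0.5])
print(rnn_head_dims(bidirectional=True, lin_ftrs_head=[128]))  # ([1536, 128, 71], [0.25, 0.5])
```

71 classes here matches the full PTB-XL statement set, but any `num_classes` works.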
models/wavelet.py ADDED
@@ -0,0 +1,158 @@
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.multiclass import OneVsRestClassifier
+ from models.base_model import ClassificationModel
+ import pickle
+ from tqdm import tqdm
+ import numpy as np
+ from sklearn.ensemble import RandomForestClassifier
+ import pywt
+ import scipy.stats
+ import multiprocessing
+ from collections import Counter
+ from keras.layers import Dropout, Dense, Input
+ from keras.models import Model
+ from keras.models import load_model
+ from keras.callbacks import ModelCheckpoint
+ from sklearn.preprocessing import StandardScaler
+
+
+ def calculate_entropy(list_values):
+     counter_values = Counter(list_values).most_common()
+     probabilities = [elem[1] / len(list_values) for elem in counter_values]
+     entropy = scipy.stats.entropy(probabilities)
+     return entropy
+
+
+ def calculate_statistics(list_values):
+     n5 = np.nanpercentile(list_values, 5)
+     n25 = np.nanpercentile(list_values, 25)
+     n75 = np.nanpercentile(list_values, 75)
+     n95 = np.nanpercentile(list_values, 95)
+     median = np.nanpercentile(list_values, 50)
+     mean = np.nanmean(list_values)
+     std = np.nanstd(list_values)
+     var = np.nanvar(list_values)
+     rms = np.sqrt(np.nanmean(list_values ** 2))  # root mean square
+     return [n5, n25, n75, n95, median, mean, std, var, rms]
+
+
+ def calculate_crossings(list_values):
+     zero_crossing_indices = np.nonzero(np.diff(np.array(list_values) > 0))[0]
+     no_zero_crossings = len(zero_crossing_indices)
+     mean_crossing_indices = np.nonzero(np.diff(np.array(list_values) > np.nanmean(list_values)))[0]
+     no_mean_crossings = len(mean_crossing_indices)
+     return [no_zero_crossings, no_mean_crossings]
+
+
+ def get_features(list_values):
+     entropy = calculate_entropy(list_values)
+     crossings = calculate_crossings(list_values)
+     statistics = calculate_statistics(list_values)
+     return [entropy] + crossings + statistics
+
+
+ def get_single_ecg_features(signal, waveletname='db6'):
+     features = []
+     for channel in signal.T:
+         list_coeff = pywt.wavedec(channel, wavelet=waveletname, level=5)
+         channel_features = []
+         for coeff in list_coeff:
+             channel_features += get_features(coeff)
+         features.append(channel_features)
+     return np.array(features).flatten()
+
+
+ def get_ecg_features(ecg_data, parallel=True):
+     if parallel:
+         with multiprocessing.Pool(18) as pool:
+             return np.array(pool.map(get_single_ecg_features, ecg_data))
+     else:
+         list_features = []
+         for signal in tqdm(ecg_data):
+             features = get_single_ecg_features(signal)
+             list_features.append(features)
+         return np.array(list_features)
+
+
+ class WaveletModel(ClassificationModel):
+     # Disclaimer: this model assumes equal shapes across all samples!
+     def __init__(self, name, n_classes, freq, outputfolder, input_shape, regularizer_C=.001, classifier='RF'):
+         # standard parameters
+         super().__init__()
+         self.name = name
+         self.outputfolder = outputfolder
+         self.n_classes = n_classes
+         self.freq = freq
+         self.regularizer_C = regularizer_C
+         self.classifier = classifier
+         self.dropout = .25
+         self.activation = 'relu'
+         self.final_activation = 'sigmoid'
+         self.n_dense_dim = 128
+         self.epochs = 30
+
+     def fit(self, X_train, y_train, X_val, y_val):
+         XF_train = get_ecg_features(X_train)
+         XF_val = get_ecg_features(X_val)
+
+         if self.classifier == 'LR':
+             if self.n_classes > 1:
+                 clf = OneVsRestClassifier(
+                     LogisticRegression(C=self.regularizer_C, solver='lbfgs', max_iter=1000, n_jobs=-1))
+             else:
+                 clf = LogisticRegression(C=self.regularizer_C, solver='lbfgs', max_iter=1000, n_jobs=-1)
+             clf.fit(XF_train, y_train)
+             pickle.dump(clf, open(self.outputfolder + 'clf.pkl', 'wb'))
+         elif self.classifier == 'RF':
+             clf = RandomForestClassifier(n_estimators=1000, n_jobs=16)
+             clf.fit(XF_train, y_train)
+             pickle.dump(clf, open(self.outputfolder + 'clf.pkl', 'wb'))
+         elif self.classifier == 'NN':
+             # standardize input data
+             ss = StandardScaler()
+             XFT_train = ss.fit_transform(XF_train)
+             XFT_val = ss.transform(XF_val)
+             pickle.dump(ss, open(self.outputfolder + 'ss.pkl', 'wb'))
+             # classification stage
+             input_x = Input(shape=(XFT_train.shape[1],))
+             x = Dense(self.n_dense_dim, activation=self.activation)(input_x)
+             x = Dropout(self.dropout)(x)
+             y = Dense(self.n_classes, activation=self.final_activation)(x)
+             self.model = Model(input_x, y)
+
+             self.model.compile(optimizer='adamax', loss='binary_crossentropy')
+             # monitor validation error
+             mc_loss = ModelCheckpoint(self.outputfolder + 'best_loss_model.h5', monitor='val_loss',
+                                       mode='min', verbose=1, save_best_only=True)
+             self.model.fit(XFT_train, y_train, validation_data=(XFT_val, y_val), epochs=self.epochs,
+                            batch_size=128, callbacks=[mc_loss])
+             self.model.save(self.outputfolder + 'last_model.h5')
+
+     def predict(self, X):
+         XF = get_ecg_features(X)
+         if self.classifier == 'LR':
+             clf = pickle.load(open(self.outputfolder + 'clf.pkl', 'rb'))
+             if self.n_classes > 1:
+                 return clf.predict_proba(XF)
+             else:
+                 return clf.predict_proba(XF)[:, 1][:, np.newaxis]
+         elif self.classifier == 'RF':
+             clf = pickle.load(open(self.outputfolder + 'clf.pkl', 'rb'))
+             y_pred = clf.predict_proba(XF)
+             if self.n_classes > 1:
+                 return np.array([yi[:, 1] for yi in y_pred]).T
+             else:
+                 return y_pred[:, 1][:, np.newaxis]
+         elif self.classifier == 'NN':
+             ss = pickle.load(open(self.outputfolder + 'ss.pkl', 'rb'))
+             XFT = ss.transform(XF)
+             model = load_model(self.outputfolder + 'best_loss_model.h5')
+             return model.predict(XFT)
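The size of the wavelet feature vector follows directly from the code above: `get_features` returns 12 values per coefficient array (1 entropy + 2 crossing counts + 9 statistics), and `pywt.wavedec` with `level=5` yields 6 coefficient arrays per lead (one approximation plus five detail levels). A standalone sketch of that count (the `wavelet_feature_count` helper is illustrative, not part of the repository):

```python
def wavelet_feature_count(n_leads=12, level=5, per_coeff=12):
    # get_features returns 1 entropy + 2 crossing counts + 9 statistics
    # = 12 values per coefficient array; pywt.wavedec with level=5 yields
    # level + 1 = 6 arrays (one approximation + five detail levels) per lead.
    n_coeff_arrays = level + 1
    return n_leads * n_coeff_arrays * per_coeff

print(wavelet_feature_count())  # 864
```

So for 12-lead PTB-XL records each sample is summarized by an 864-dimensional feature vector before being passed to the LR/RF/NN classifier.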
models/xresnet1d.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from enum import Enum
4
+ import re
5
+ # delegates
6
+ import inspect
7
+
8
+ from torch.nn.utils import weight_norm, spectral_norm
9
+
10
+ from models.basicconv1d import create_head1d
11
+
12
+
13
+ def delegates(to=None, keep=False):
14
+ """Decorator: replace `**kwargs` in signature with params from `to`"""
15
+
16
+ def _f(f):
17
+ if to is None:
18
+ to_f, from_f = f.__base__.__init__, f.__init__
19
+ else:
20
+ to_f, from_f = to, f
21
+ sig = inspect.signature(from_f)
22
+ sigd = dict(sig.parameters)
23
+ k = sigd.pop('kwargs')
24
+ s2 = {k: v for k, v in inspect.signature(to_f).parameters.items()
25
+ if v.default != inspect.Parameter.empty and k not in sigd}
26
+ sigd.update(s2)
27
+ if keep: sigd['kwargs'] = k
28
+ from_f.__signature__ = sig.replace(parameters=sigd.values())
29
+ return f
30
+
31
+ return _f
32
+
33
+
34
+ def store_attr(self, nms):
35
+ """Store params named in comma-separated `nms` from calling context into attrs in `self`"""
36
+ mod = inspect.currentframe().f_back.f_locals
37
+ for n in re.split(', *', nms): setattr(self, n, mod[n])
38
+
39
+
40
+ NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')
41
+
42
+
43
+ def _conv_func(ndim=2, transpose=False):
44
+ """Return the proper conv `ndim` function, potentially `transposed`."""
45
+ assert 1 <= ndim <= 3
46
+ return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')
47
+
48
+
49
+ def init_default(m, func=nn.init.kaiming_normal_):
50
+ """Initialize `m` weights with `func` and set `bias` to 0."""
51
+ if func and hasattr(m, 'weight'): func(m.weight)
52
+ with torch.no_grad():
53
+ if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)
54
+ return m
55
+
56
+
57
+ def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
58
+ """Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."""
59
+ assert 1 <= ndim <= 3
60
+ bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
61
+ if bn.affine:
62
+ bn.bias.data.fill_(1e-3)
63
+ bn.weight.data.fill_(0. if zero else 1.)
64
+ return bn
65
+
66
+
67
+ def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
68
+ """BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."""
69
+ return _get_norm('BatchNorm', nf, ndim, zero=norm_type == NormType.BatchZero, **kwargs)
70
+
71
+
72
+ class ConvLayer(nn.Sequential):
73
+ """Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers."""
74
+
75
+ def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
76
+ act_cls=nn.ReLU, transpose=False, init=nn.init.kaiming_normal_, xtra=None, **kwargs):
77
+ if padding is None: padding = ((ks - 1) // 2 if not transpose else 0)
78
+ bn = norm_type in (NormType.Batch, NormType.BatchZero)
79
+ inn = norm_type in (NormType.Instance, NormType.InstanceZero)
80
+ if bias is None: bias = not (bn or inn)
81
+ conv_func = _conv_func(ndim, transpose=transpose)
82
+ conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs),
83
+ init)
84
+ if norm_type == NormType.Weight:
85
+ conv = weight_norm(conv)
86
+ elif norm_type == NormType.Spectral:
87
+ conv = spectral_norm(conv)
88
+ layers = [conv]
89
+ act_bn = []
90
+ if act_cls is not None: act_bn.append(act_cls())
91
+ if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
92
+ if inn: act_bn.append(nn.InstanceNorm2d(nf, norm_type=norm_type, ndim=ndim))
93
+ if bn_1st: act_bn.reverse()
94
+ layers += act_bn
95
+ if xtra: layers.append(xtra)
96
+ super().__init__()
97
+
98
+
99
+ def AdaptiveAvgPool(sz=1, ndim=2):
+     """nn.AdaptiveAvgPool layer for `ndim`"""
+     assert 1 <= ndim <= 3
+     return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz)
+
+
+ def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
+     """nn.MaxPool layer for `ndim`"""
+     assert 1 <= ndim <= 3
+     return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
+
+
+ def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
+     """nn.AvgPool layer for `ndim`"""
+     assert 1 <= ndim <= 3
+     return getattr(nn, f"AvgPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
+
+
+ class ResBlock(nn.Module):
+     """Resnet block from `ni` to `nf` with `stride`"""
+
+     @delegates(ConvLayer.__init__)
+     def __init__(self, expansion, ni, nf, stride=1, kernel_size=3, groups=1, reduction=None, nh1=None, nh2=None,
+                  dw=False, g2=1, sa=False, sym=False, norm_type=NormType.Batch, act_cls=nn.ReLU, ndim=2,
+                  pool=AvgPool, pool_first=True, **kwargs):
+         super().__init__()
+         norm2 = (NormType.BatchZero if norm_type == NormType.Batch else
+                  NormType.InstanceZero if norm_type == NormType.Instance else norm_type)
+         if nh2 is None: nh2 = nf
+         if nh1 is None: nh1 = nh2
+         nf, ni = nf * expansion, ni * expansion
+         k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
+         k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
+         layers = [ConvLayer(ni, nh2, kernel_size, stride=stride, groups=ni if dw else groups, **k0),
+                   ConvLayer(nh2, nf, kernel_size, groups=g2, **k1)
+                   ] if expansion == 1 else [
+                   ConvLayer(ni, nh1, 1, **k0),
+                   ConvLayer(nh1, nh2, kernel_size, stride=stride, groups=nh1 if dw else groups, **k0),
+                   ConvLayer(nh2, nf, 1, groups=g2, **k1)]
+         self.convs = nn.Sequential(*layers)
+         convpath = [self.convs]
+         # fix: SEModule and SimpleSelfAttention are fastai layers, not part of torch.nn
+         if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
+         if sa: convpath.append(SimpleSelfAttention(nf, ks=1, sym=sym))
+         self.convpath = nn.Sequential(*convpath)
+         idpath = []
+         if ni != nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
+         if stride != 1: idpath.insert((1, 0)[pool_first], pool(2, ndim=ndim, ceil_mode=True))
+         self.idpath = nn.Sequential(*idpath)
+         self.act = nn.ReLU(inplace=True) if act_cls is nn.ReLU else act_cls()
+
+     def forward(self, x):
+         return self.act(self.convpath(x) + self.idpath(x))
+
+
+ ######################### adapted from vision.models.xresnet
+ def init_cnn(m):
+     if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
+     if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight)
+     for l in m.children(): init_cnn(l)
+
+
+ class XResNet1d(nn.Sequential):
+     @delegates(ResBlock)
+     def __init__(self, block, expansion, layers, p=0.0, input_channels=3, num_classes=1000, stem_szs=(32, 32, 64),
+                  kernel_size=5, kernel_size_stem=5, widen=1.0, sa=False, act_cls=nn.ReLU, lin_ftrs_head=None,
+                  ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
+         store_attr(self, 'block,expansion,act_cls')
+         stem_szs = [input_channels, *stem_szs]
+         stem = [ConvLayer(stem_szs[i], stem_szs[i + 1], ks=kernel_size_stem, stride=2 if i == 0 else 1,
+                           act_cls=act_cls, ndim=1)
+                 for i in range(3)]
+
+         # block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
+         block_szs = [int(o * widen) for o in [64, 64, 64, 64] + [32] * (len(layers) - 4)]
+         block_szs = [64 // expansion] + block_szs
+         blocks = [self._make_layer(ni=block_szs[i], nf=block_szs[i + 1], blocks=l,
+                                    stride=1 if i == 0 else 2, kernel_size=kernel_size,
+                                    sa=sa and i == len(layers) - 4, ndim=1, **kwargs)
+                   for i, l in enumerate(layers)]
+
+         head = create_head1d(block_szs[-1] * expansion, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
+                              bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
+
+         # fix: assemble stem, pooling, residual stages and head (stem and blocks were built but never used)
+         super().__init__(*stem, nn.MaxPool1d(kernel_size=3, stride=2, padding=1), *blocks, head)
+         init_cnn(self)
+
+     def _make_layer(self, ni, nf, blocks, stride, kernel_size, sa, **kwargs):
+         return nn.Sequential(
+             *[self.block(self.expansion, ni if i == 0 else nf, nf, stride=stride if i == 0 else 1,
+                          kernel_size=kernel_size, sa=sa and i == (blocks - 1), act_cls=self.act_cls, **kwargs)
+               for i in range(blocks)])
+
+     def get_layer_groups(self):
+         return self[3], self[-1]
+
+     def get_output_layer(self):
+         return self[-1][-1]
+
+     def set_output_layer(self, x):
+         self[-1][-1] = x
+
+
+ # xresnets
+ def _xresnet1d(expansion, layers, **kwargs):
+     return XResNet1d(ResBlock, expansion, layers, **kwargs)
+
+
+ def xresnet1d18(**kwargs): return _xresnet1d(1, [2, 2, 2, 2], **kwargs)
+
+
+ def xresnet1d34(**kwargs): return _xresnet1d(1, [3, 4, 6, 3], **kwargs)
+
+
+ def xresnet1d50(**kwargs): return _xresnet1d(4, [3, 4, 6, 3], **kwargs)
+
+
+ def xresnet1d101(**kwargs): return _xresnet1d(4, [3, 4, 23, 3], **kwargs)
+
+
+ def xresnet1d152(**kwargs): return _xresnet1d(4, [3, 8, 36, 3], **kwargs)
+
+
+ def xresnet1d18_deep(**kwargs): return _xresnet1d(1, [2, 2, 2, 2, 1, 1], **kwargs)
+
+
+ def xresnet1d34_deep(**kwargs): return _xresnet1d(1, [3, 4, 6, 3, 1, 1], **kwargs)
+
+
+ def xresnet1d50_deep(**kwargs): return _xresnet1d(4, [3, 4, 6, 3, 1, 1], **kwargs)
+
+
+ def xresnet1d18_deeper(**kwargs): return _xresnet1d(1, [2, 2, 1, 1, 1, 1, 1, 1], **kwargs)
+
+
+ def xresnet1d34_deeper(**kwargs): return _xresnet1d(1, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs)
+
+
+ def xresnet1d50_deeper(**kwargs): return _xresnet1d(4, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs)
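
As a side note (not part of the repository), the depth number in each factory name above follows the classic ResNet counting: sum of the stage sizes times convolutions per block (2 for basic blocks, 3 for bottleneck blocks with expansion 4), plus the input convolution and the classifier. A quick sanity check of that arithmetic, with `resnet_depth` as a purely illustrative helper:

```python
def resnet_depth(layers, expansion):
    """Nominal depth behind names like xresnet1d50: each basic block
    (expansion 1) has 2 convs, each bottleneck block (expansion 4) has 3,
    plus the input conv and the final classifier layer."""
    convs_per_block = 2 if expansion == 1 else 3
    return sum(layers) * convs_per_block + 2
```

For example, `xresnet1d50` uses expansion 4 with stages `[3, 4, 6, 3]`, giving 16 × 3 + 2 = 50.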
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ torch
+ fastai
+ blis==0.7.5
+ cycler==0.11.0
+ cymem==2.0.6
+ fastcore==1.3.27
+ jinja2==3.0.3
+ kiwisolver==1.3.2
+ markupsafe==2.0.1
+ murmurhash==1.0.6
+ pathy==0.6.1
+ pillow==8.4.0
+ preshed==3.0.6
+ pyparsing==3.0.5
+ smart-open==5.2.1
+ srsly==2.4.2
+ thinc==8.0.13
+ torchvision==0.11.1
+ wfdb==3.4.1
+ wget==3.2
+ scikit-image
+ pyWavelets
+ keras
utilities/__pycache__/timeseries_utils.cpython-39.pyc ADDED
Binary file (21 kB).
 
utilities/__pycache__/utils.cpython-39.pyc ADDED
Binary file (16.7 kB).
 
utilities/stratify.py ADDED
@@ -0,0 +1,173 @@
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def stratify_df(df, new_col_name, n_folds=10, nr_clean_folds=0):
+     # compute qualities as described in the PTB-XL report
+     qualities = []
+     for i, row in df.iterrows():
+         q = 0
+         if 'validated_by_human' in df.columns:
+             if row.validated_by_human:
+                 q = 1
+         qualities.append(q)
+     df['quality'] = qualities
+
+     # create stratified folds according to patients
+     pat_ids = np.array(sorted(list(set(df.patient_id.values))))
+     p_labels = []
+     p_qualities = []
+     ecgs_per_patient = []
+
+     for pid in tqdm(pat_ids):
+         sel = df[df.patient_id == pid]
+         l = np.concatenate([list(d.keys()) for d in sel.scp_codes.values])
+         if sel.sex.values[0] == 0:
+             gender = 'male'
+         else:
+             gender = 'female'
+         l = np.concatenate((l, [gender] * len(sel)))
+         for age in sel.age.values:
+             if age < 20:
+                 l = np.concatenate((l, ['<20']))
+             elif 20 <= age < 40:
+                 l = np.concatenate((l, ['20-40']))
+             elif 40 <= age < 60:
+                 l = np.concatenate((l, ['40-60']))
+             elif 60 <= age < 80:
+                 l = np.concatenate((l, ['60-80']))
+             elif age >= 80:
+                 l = np.concatenate((l, ['>=80']))
+         p_labels.append(l)
+         ecgs_per_patient.append(len(sel))
+         p_qualities.append(sel.quality.min())
+     classes = sorted(list(set([item for sublist in p_labels for item in sublist])))
+
+     stratified_data_ids, stratified_data = stratify(p_labels, classes, [1 / n_folds] * n_folds, p_qualities,
+                                                     ecgs_per_patient, nr_clean_folds)
+
+     df[new_col_name] = np.zeros(len(df)).astype(int)
+     for fold_i, fold_ids in tqdm(enumerate(stratified_data_ids)):
+         ipat_ids = [pat_ids[pid] for pid in fold_ids]
+         df.loc[df.patient_id.isin(ipat_ids), new_col_name] = fold_i + 1  # .loc avoids chained assignment
+
+     return df
+
+
+ def stratify(data, classes, ratios, qualities, ecgs_per_patient, nr_clean_folds=1):
+     """Stratifying procedure. Modified from https://vict0rs.ch/2018/05/24/sample-multilabel-dataset/ (based on Sechidis 2011)
+
+     data is a list of lists: a list of labels for each sample (patient)
+
+     classes is the list of classes each label can take
+
+     ratios is a list, summing to 1, of how the dataset should be split
+
+     qualities: quality per entry (only entries with quality > 0 can be assigned to clean folds;
+     quality 4 will always be assigned to the final fold)
+
+     ecgs_per_patient: list with the number of ecgs per sample
+
+     nr_clean_folds: the last nr_clean_folds folds can only take clean entries
+     """
+     np.random.seed(0)  # fix the random seed
+
+     # data is a list of lists; len(data) is the number of patients; data[i] is the list of all labels for
+     # patient i (possibly multiple identical entries)
+
+     # size is the number of ecgs
+     size = np.sum(ecgs_per_patient)
+
+     # Organize data per label: for each label l, per_label_data[l] contains the list of patients
+     # in data which have this label (potentially multiple identical entries)
+     per_label_data = {c: [] for c in classes}
+     for i, d in enumerate(data):
+         for l in d:
+             per_label_data[l].append(i)
+
+     # In order not to compute lengths each time, they are tracked here.
+     subset_sizes = [r * size for r in ratios]  # list of subset sizes in terms of ecgs
+     per_label_subset_sizes = {c: [r * len(per_label_data[c]) for r in ratios] for c in
+                               classes}  # dictionary label: list of subset sizes in terms of patients
+
+     # For each subset we want, the set of sample-ids which should end up in it
+     stratified_data_ids = [set() for _ in range(len(ratios))]  # initialize empty
+
+     # For each sample in the data set
+     print("Assigning patients to folds...")
+     size_prev = size + 1  # just for output
+     while size > 0:
+         if int(size_prev / 1000) > int(size / 1000):
+             print("Remaining patients/ecgs to distribute:", size, "non-empty labels:",
+                   np.sum([1 for l, label_data in per_label_data.items() if len(label_data) > 0]))
+         size_prev = size
+         # Compute |Di|
+         lengths = {
+             l: len(label_data)
+             for l, label_data in per_label_data.items()
+         }  # dictionary label: number of ecgs with this label that have not been assigned to a fold yet
+         try:
+             # Find label of smallest |Di|
+             label = min({k: v for k, v in lengths.items() if v > 0}, key=lengths.get)
+         except ValueError:
+             # If the dictionary in `min` is empty we get a ValueError.
+             # This can happen if there are unlabeled samples.
+             # In this case, `size` would be > 0 but only samples without label would remain.
+             # "No label" could be a class in itself: it's up to you to format your data accordingly.
+             break
+         # For each patient with label `label` get patient ids and corresponding counts
+         unique_samples, unique_counts = np.unique(per_label_data[label], return_counts=True)
+         idxs_sorted = np.argsort(unique_counts, kind='stable')[::-1]
+         unique_samples = unique_samples[idxs_sorted]  # all patient ids with this label, sorted by count descending
+         unique_counts = unique_counts[idxs_sorted]  # the corresponding counts
+
+         # loop through all patient ids with this label
+         for current_id, current_count in zip(unique_samples, unique_counts):
+
+             subset_sizes_for_label = per_label_subset_sizes[label]  # current subset sizes for the chosen label
+
+             # if quality is bad remove clean folds (i.e. the sample cannot be assigned to clean folds)
+             if qualities[current_id] < 1:
+                 subset_sizes_for_label = subset_sizes_for_label[:len(ratios) - nr_clean_folds]
+
+             # Find argmax clj i.e. the subset in greatest need of the current label
+             largest_subsets = np.argwhere(subset_sizes_for_label == np.amax(subset_sizes_for_label)).flatten()
+
+             # if there is a single best choice: assign it
+             if len(largest_subsets) == 1:
+                 subset = largest_subsets[0]
+             # If there is more than one such subset, find the one in greatest need of any label
+             else:
+                 largest_subsets2 = np.argwhere(np.array(subset_sizes)[largest_subsets] == np.amax(
+                     np.array(subset_sizes)[largest_subsets])).flatten()
+                 subset = largest_subsets[np.random.choice(largest_subsets2)]
+
+             # Store the sample's id in the selected subset
+             stratified_data_ids[subset].add(current_id)
+
+             # There are ecgs_per_patient[current_id] fewer samples to distribute
+             size -= ecgs_per_patient[current_id]
+             # The selected subset needs that many fewer samples
+             subset_sizes[subset] -= ecgs_per_patient[current_id]
+
+             # In the selected subset, there is one more example for each label
+             # the current sample has
+             for l in data[current_id]:
+                 per_label_subset_sizes[l][subset] -= 1
+
+             # Remove the sample from the dataset, meaning from all per_label datasets created
+             for x in per_label_data.keys():
+                 per_label_data[x] = [y for y in per_label_data[x] if y != current_id]
+
+     # Create the stratified dataset as a list of subsets, each containing the original labels
+     stratified_data_ids = [sorted(strat) for strat in stratified_data_ids]
+     stratified_data = [
+         [data[i] for i in strat] for strat in stratified_data_ids
+     ]
+
+     # Return both the stratified indexes, to be used to sample the `features` associated with your labels,
+     # and the stratified labels dataset
+     return stratified_data_ids, stratified_data
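
The greedy loop above can be sketched in miniature. The following toy version (illustrative only, not part of the repo; it drops the quality constraint and assumes one ECG per patient) keeps the core Sechidis-style idea: repeatedly take the rarest remaining label and hand its samples to the fold most in need of that label.

```python
def greedy_stratify(labels, n_folds):
    """Toy greedy multilabel stratification: labels is a list of label-lists,
    one per sample; returns n_folds disjoint sets of sample indices."""
    classes = sorted({l for ls in labels for l in ls})
    remaining = set(range(len(labels)))
    # how many samples of each label every fold still "wants"
    fold_need = {c: [sum(c in ls for ls in labels) / n_folds] * n_folds for c in classes}
    folds = [set() for _ in range(n_folds)]
    while remaining:
        live = {c: [i for i in remaining if c in labels[i]] for c in classes}
        live = {c: v for c, v in live.items() if v}
        if not live:
            break  # only unlabeled samples left
        label = min(live, key=lambda c: len(live[c]))  # rarest label first
        for i in live[label]:
            f = max(range(n_folds), key=lambda k: fold_need[label][k])
            folds[f].add(i)
            remaining.discard(i)
            for l in labels[i]:
                fold_need[l][f] -= 1
    return folds
```

The real `stratify` above additionally weights by ECGs per patient and reserves the last folds for clean (human-validated) records.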
utilities/timeseries_utils.py ADDED
@@ -0,0 +1,649 @@
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from torch import nn
+ from pathlib import Path
+ from scipy.stats import iqr
+ import os
+ # Note: due to issues with the numpy rng for multiprocessing
+ # (https://github.com/pytorch/pytorch/issues/5059) that could be
+ # fixed by a custom worker_init_fn, we use random throughout for convenience
+ import random
+ from skimage import transform
+ import warnings
+ warnings.filterwarnings("ignore", category=UserWarning)
+ from scipy.signal import butter, sosfilt, sosfiltfilt, sosfreqz
+ # https://stackoverflow.com/questions/12093594/how-to-implement-band-pass-butterworth-filter-with-scipy-signal-butter
+
+
+ def butter_filter(lowcut=10, highcut=20, fs=50, order=5, btype='band'):
+     """returns a Butterworth filter (second-order sections) with the given specifications"""
+     nyq = 0.5 * fs
+     low = lowcut / nyq
+     high = highcut / nyq
+
+     sos = butter(order, [low, high] if btype == "band" else (low if btype == "low" else high), analog=False,
+                  btype=btype, output='sos')
+     return sos
+
+
+ def butter_filter_frequency_response(filter):
+     """returns the frequency response of a given filter (result of a call of butter_filter)"""
+     w, h = sosfreqz(filter)
+     # gain vs. freq(Hz)
+     # plt.plot((fs * 0.5 / np.pi) * w, abs(h))
+     return w, h
+
+
+ def apply_butter_filter(data, filter, forwardbackward=True):
+     """applies a filter from a call of butter_filter to data (assuming time axis at dimension 0);
+     forwardbackward=True applies zero-phase forward-backward filtering"""
+     if forwardbackward:
+         return sosfiltfilt(filter, data, axis=0)
+     else:
+         return sosfilt(filter, data, axis=0)  # fix: the result was computed but never returned
+
+
+ def dataset_add_chunk_col(df, col="data"):
+     """add a chunk column to the dataset df"""
+     df["chunk"] = df.groupby(col).cumcount()
+
+
+ def dataset_add_length_col(df, col="data", data_folder=None):
+     """add a length column to the dataset df"""
+     df[col + "_length"] = df[col].apply(lambda x: len(np.load(x if data_folder is None else data_folder / x)))
+
+
+ def dataset_add_labels_col(df, col="label", data_folder=None):
+     """add a column with the unique labels in column col"""
+     df[col + "_labels"] = df[col].apply(
+         lambda x: list(np.unique(np.load(x if data_folder is None else data_folder / x))))
+
+
+ def dataset_add_mean_col(df, col="data", axis=0, data_folder=None):
+     """adds a column with the mean"""
+     df[col + "_mean"] = df[col].apply(
+         lambda x: np.mean(np.load(x if data_folder is None else data_folder / x), axis=axis))
+
+
+ def dataset_add_median_col(df, col="data", axis=0, data_folder=None):
+     """adds a column with the median"""
+     df[col + "_median"] = df[col].apply(
+         lambda x: np.median(np.load(x if data_folder is None else data_folder / x), axis=axis))
+
+
+ def dataset_add_std_col(df, col="data", axis=0, data_folder=None):
+     """adds a column with the standard deviation"""
+     df[col + "_std"] = df[col].apply(
+         lambda x: np.std(np.load(x if data_folder is None else data_folder / x), axis=axis))
+
+
+ def dataset_add_iqr_col(df, col="data", axis=0, data_folder=None):
+     """adds a column with the interquartile range"""
+     df[col + "_iqr"] = df[col].apply(lambda x: iqr(np.load(x if data_folder is None else data_folder / x), axis=axis))
+
+
+ def dataset_get_stats(df, col="data", median=False):
+     """creates weighted means and stds from the mean, std and length cols of the df"""
+     mean = np.average(np.stack(df[col + ("_median" if median is True else "_mean")], axis=0), axis=0,
+                       weights=np.array(df[col + "_length"]))
+     std = np.average(np.stack(df[col + ("_iqr" if median is True else "_std")], axis=0), axis=0,
+                      weights=np.array(df[col + "_length"]))
+     return mean, std
+
+
+ def npys_to_memmap(npys, target_filename, delete_npys=False):
+     memmap = None
+     start = []
+     length = []
+     ids = []
+
+     for idx, npy in enumerate(npys):
+         data = np.load(npy)
+         if memmap is None:
+             memmap = np.memmap(target_filename, dtype=data.dtype, mode='w+', shape=data.shape)
+             start.append(0)
+             length.append(data.shape[0])
+         else:
+             start.append(start[-1] + length[-1])
+             length.append(data.shape[0])
+             # reopen the memmap with the grown shape and append at the end
+             memmap = np.memmap(target_filename, dtype=data.dtype, mode='r+',
+                                shape=tuple([start[-1] + length[-1]] + [l for l in data.shape[1:]]))
+
+         ids.append(idx)
+         memmap[start[-1]:start[-1] + length[-1]] = data[:]
+         memmap.flush()
+         if delete_npys is True:
+             npy.unlink()
+     del memmap
+
+     np.savez(target_filename.parent / (target_filename.stem + "_meta.npz"), start=start, length=length,
+              shape=[start[-1] + length[-1]] + [l for l in data.shape[1:]], dtype=data.dtype)
+
+
+ def reformat_as_memmap(df, target_filename, data_folder=None, annotation=False, delete_npys=False):
+     npys_data = []
+     npys_label = []
+
+     for id, row in df.iterrows():
+         npys_data.append(data_folder / row["data"] if data_folder is not None else row["data"])
+         if annotation:
+             npys_label.append(data_folder / row["label"] if data_folder is not None else row["label"])
+
+     npys_to_memmap(npys_data, target_filename, delete_npys=delete_npys)
+     if annotation:
+         npys_to_memmap(npys_label, target_filename.parent / (target_filename.stem + "_label.npy"),
+                        delete_npys=delete_npys)
+
+     # replace data (filename) by an integer index into the memmap
+     df_mapped = df.copy()
+     df_mapped["data_original"] = df_mapped.data
+     df_mapped["data"] = np.arange(len(df_mapped))
+     df_mapped.to_pickle(target_filename.parent / ("df_" + target_filename.stem + ".pkl"))
+     return df_mapped
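
The offset bookkeeping behind `npys_to_memmap` can be sketched standalone. The helper below (`concat_to_memmap` is illustrative, not defined in the repo) writes arrays back-to-back along axis 0 into one memmap while recording per-array start offsets and lengths, which is exactly the content the function above saves to `_meta.npz`:

```python
import os
import tempfile
import numpy as np

def concat_to_memmap(arrays, path):
    """Concatenate arrays along axis 0 into a single memmap at `path`,
    returning the per-array start offsets and lengths."""
    start, length, total = [], [], 0
    for a in arrays:
        start.append(total)
        length.append(a.shape[0])
        total += a.shape[0]
    mm = np.memmap(path, dtype=arrays[0].dtype, mode='w+',
                   shape=(total,) + arrays[0].shape[1:])
    for s, a in zip(start, arrays):
        mm[s:s + a.shape[0]] = a  # copy each source array at its offset
    mm.flush()
    return start, length
```

A reader of such a file then slices `[start[i]:start[i] + length[i]]` to recover record `i`, which is what `TimeseriesDatasetCrops` does with `memmap_start`.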
+
+
+ # TimeseriesDatasetCrops
+
+ class TimeseriesDatasetCrops(torch.utils.data.Dataset):
+     """timeseries dataset with partial crops."""
+
+     def __init__(self, df, output_size, chunk_length, min_chunk_length, memmap_filename=None, npy_data=None,
+                  random_crop=True, data_folder=None, num_classes=2, copies=0, col_lbl="label", stride=None,
+                  start_idx=0, annotation=False, transforms=None):
+         """
+         accepts three kinds of input:
+         1) filenames pointing to aligned numpy arrays [timesteps,channels,...] for data and either integer labels or filenames pointing to numpy arrays [timesteps,...] e.g. for annotations
+         2) memmap_filename to a memmap for data [concatenated,...] and labels - the label column in df corresponds to the index in this memmap
+         3) npy_data [samples,ts,...] (either path or np.array directly - also supporting variable-length input) - the label column in df corresponds to the sample id
+
+         transforms: list of callables (transformations) (applied in the specified order, i.e. leftmost element first)
+         """
+         if transforms is None:
+             transforms = []
+         assert not ((memmap_filename is not None) and (npy_data is not None))
+         # require integer entries if using memmap or npy
+         assert (memmap_filename is None and npy_data is None) or df.data.dtype == np.int64
+
+         self.timeseries_df = df
+         self.output_size = output_size
+         self.data_folder = data_folder
+         self.transforms = transforms
+         self.annotation = annotation
+         self.col_lbl = col_lbl
+
+         self.c = num_classes
+
+         self.mode = "files"
+         self.memmap_filename = memmap_filename
+         if memmap_filename is not None:
+             self.mode = "memmap"
+             memmap_meta = np.load(memmap_filename.parent / (memmap_filename.stem + "_meta.npz"))
+             self.memmap_start = memmap_meta["start"]
+             self.memmap_shape = tuple(memmap_meta["shape"])
+             self.memmap_length = memmap_meta["length"]
+             self.memmap_dtype = np.dtype(str(memmap_meta["dtype"]))
+             self.memmap_file_process_dict = {}
+             if annotation:
+                 memmap_meta_label = np.load(memmap_filename.parent / (memmap_filename.stem + "_label_meta.npz"))
+                 self.memmap_filename_label = memmap_filename.parent / (memmap_filename.stem + "_label.npy")
+                 self.memmap_shape_label = tuple(memmap_meta_label["shape"])
+                 self.memmap_file_process_dict_label = {}
+                 self.memmap_dtype_label = np.dtype(str(memmap_meta_label["dtype"]))
+         elif npy_data is not None:
+             self.mode = "npy"
+             if isinstance(npy_data, np.ndarray) or isinstance(npy_data, list):
+                 self.npy_data = np.array(npy_data)
+                 assert (annotation is False)
+             else:
+                 self.npy_data = np.load(npy_data)
+             if annotation:
+                 self.npy_data_label = np.load(npy_data.parent / (npy_data.stem + "_label.npy"))
+
+         self.random_crop = random_crop
+
+         self.df_idx_mapping = []
+         self.start_idx_mapping = []
+         self.end_idx_mapping = []
+
+         for df_idx, (id, row) in enumerate(df.iterrows()):
+             if self.mode == "files":
+                 data_length = row["data_length"]
+             elif self.mode == "memmap":
+                 data_length = self.memmap_length[row["data"]]
+             else:  # npy
+                 data_length = len(self.npy_data[row["data"]])
+
+             if chunk_length == 0:  # do not split
+                 idx_start = [start_idx]
+                 idx_end = [data_length]
+             else:
+                 idx_start = list(range(start_idx, data_length, chunk_length if stride is None else stride))
+                 idx_end = [min(l + chunk_length, data_length) for l in idx_start]
+
+             # remove final chunk(s) if too short
+             for i in range(len(idx_start)):
+                 if idx_end[i] - idx_start[i] < min_chunk_length:
+                     del idx_start[i:]
+                     del idx_end[i:]
+                     break
+             # append to lists
+             for _ in range(copies + 1):
+                 for i_s, i_e in zip(idx_start, idx_end):
+                     self.df_idx_mapping.append(df_idx)
+                     self.start_idx_mapping.append(i_s)
+                     self.end_idx_mapping.append(i_e)
+
+     def __len__(self):
+         return len(self.df_idx_mapping)
+
+     def __getitem__(self, idx):
+         df_idx = self.df_idx_mapping[idx]
+         start_idx = self.start_idx_mapping[idx]
+         end_idx = self.end_idx_mapping[idx]
+         # determine crop idxs
+         timesteps = end_idx - start_idx
+         assert (timesteps >= self.output_size)
+         if self.random_crop:  # random crop
+             if timesteps == self.output_size:
+                 start_idx_crop = start_idx
+             else:
+                 start_idx_crop = start_idx + random.randint(0, timesteps - self.output_size - 1)
+         else:
+             start_idx_crop = start_idx + (timesteps - self.output_size) // 2
+         end_idx_crop = start_idx_crop + self.output_size
+
+         # load the actual data
+         if self.mode == "files":  # from separate files
+             data_filename = self.timeseries_df.iloc[df_idx]["data"]
+             if self.data_folder is not None:
+                 data_filename = self.data_folder / data_filename
+             data = np.load(data_filename)[start_idx_crop:end_idx_crop]  # data type has to be adjusted when saving to npy
+
+             ID = data_filename.stem
+
+             if self.annotation is True:
+                 label_filename = self.timeseries_df.iloc[df_idx][self.col_lbl]
+                 if self.data_folder is not None:
+                     label_filename = self.data_folder / label_filename
+                 label = np.load(label_filename)[start_idx_crop:end_idx_crop]  # data type has to be adjusted when saving to npy
+             else:
+                 label = self.timeseries_df.iloc[df_idx][self.col_lbl]  # input type has to be adjusted in the dataframe
+         elif self.mode == "memmap":  # from one memmap file
+             ID = self.timeseries_df.iloc[df_idx]["data_original"].stem
+             # grab the actual index (the df used to create the ds might be a subset of the original df used to create the memmap)
+             memmap_idx = self.timeseries_df.iloc[df_idx]["data"]
+             idx_offset = self.memmap_start[memmap_idx]
+
+             pid = os.getpid()
+             mem_file = self.memmap_file_process_dict.get(pid, None)  # each process owns its own file handle
+             if mem_file is None:
+                 mem_file = np.memmap(self.memmap_filename, self.memmap_dtype, mode='r', shape=self.memmap_shape)
+                 self.memmap_file_process_dict[pid] = mem_file
+             data = np.copy(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
+             if self.annotation:
+                 mem_file_label = self.memmap_file_process_dict_label.get(pid, None)  # each process owns its own file handle
+                 if mem_file_label is None:
+                     # fix: open the label memmap with the label dtype, not the data dtype
+                     mem_file_label = np.memmap(self.memmap_filename_label, self.memmap_dtype_label, mode='r',
+                                                shape=self.memmap_shape_label)
+                     self.memmap_file_process_dict_label[pid] = mem_file_label
+                 label = np.copy(mem_file_label[idx_offset + start_idx_crop: idx_offset + end_idx_crop])
+             else:
+                 label = self.timeseries_df.iloc[df_idx][self.col_lbl]
+         else:  # single npy array
+             ID = self.timeseries_df.iloc[df_idx]["data"]
+
+             data = self.npy_data[ID][start_idx_crop:end_idx_crop]
+
+             if self.annotation:
+                 label = self.npy_data_label[ID][start_idx_crop:end_idx_crop]
+             else:
+                 label = self.timeseries_df.iloc[df_idx][self.col_lbl]
+         sample = {'data': data, 'label': label, 'ID': ID}
+
+         for t in self.transforms:
+             sample = t(sample)
+
+         return sample
+
+     def get_sampling_weights(self, class_weight_dict, length_weighting=False, group_by_col=None):
+         assert (self.annotation is False)
+         assert (length_weighting is False or group_by_col is None)
+         weights = np.zeros(len(self.df_idx_mapping), dtype=np.float32)
+         length_per_class = {}
+         length_per_group = {}
+         for iw, (i, s, e) in enumerate(zip(self.df_idx_mapping, self.start_idx_mapping, self.end_idx_mapping)):
+             label = self.timeseries_df.iloc[i][self.col_lbl]
+             weight = class_weight_dict[label]
+             if length_weighting:
+                 if label in length_per_class.keys():
+                     length_per_class[label] += e - s
+                 else:
+                     length_per_class[label] = e - s
+             if group_by_col is not None:
+                 group = self.timeseries_df.iloc[i][group_by_col]
+                 if group in length_per_group.keys():
+                     length_per_group[group] += e - s
+                 else:
+                     length_per_group[group] = e - s
+             weights[iw] = weight
+
+         if length_weighting:  # need a second pass to properly take into account the total length per class
+             for iw, (i, s, e) in enumerate(zip(self.df_idx_mapping, self.start_idx_mapping, self.end_idx_mapping)):
+                 label = self.timeseries_df.iloc[i][self.col_lbl]
+                 weights[iw] = (e - s) / length_per_class[label] * weights[iw]
+         if group_by_col is not None:
+             for iw, (i, s, e) in enumerate(zip(self.df_idx_mapping, self.start_idx_mapping, self.end_idx_mapping)):
+                 group = self.timeseries_df.iloc[i][group_by_col]
+                 weights[iw] = (e - s) / length_per_group[group] * weights[iw]
+
+         weights = weights / np.min(weights)  # normalize smallest weight to 1
+         return weights
+
+     def get_id_mapping(self):
+         return self.df_idx_mapping
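
The chunk bookkeeping in `TimeseriesDatasetCrops.__init__` above can be sketched as a standalone function (`chunk_indices` is an illustrative helper, not defined in the repo):

```python
def chunk_indices(data_length, chunk_length, min_chunk_length, stride=None, start_idx=0):
    """Split a record of data_length timesteps into (start, end) windows of
    chunk_length (advancing by stride if given), dropping trailing windows
    shorter than min_chunk_length."""
    if chunk_length == 0:  # do not split
        return [(start_idx, data_length)]
    step = chunk_length if stride is None else stride
    starts = list(range(start_idx, data_length, step))
    ends = [min(s + chunk_length, data_length) for s in starts]
    chunks = list(zip(starts, ends))
    while chunks and chunks[-1][1] - chunks[-1][0] < min_chunk_length:
        chunks.pop()  # remove final chunk(s) if too short
    return chunks
```

Each surviving `(start, end)` pair then yields one dataset item, from which `__getitem__` takes a random or center crop of `output_size` timesteps.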
+
+
+ class RandomCrop(object):
+     """
+     Randomly crop the time series in a sample (deprecated).
+     """
+
+     def __init__(self, output_size, annotation=False):
+         self.output_size = output_size
+         self.annotation = annotation
+
+     def __call__(self, sample):
+         data, label, ID = sample['data'], sample['label'], sample['ID']
+
+         timesteps = len(data)
+         assert timesteps >= self.output_size
+         if timesteps == self.output_size:
+             start = 0
+         else:
+             start = random.randint(0, timesteps - self.output_size - 1)  # np.random.randint(0, timesteps - self.output_size)
+
+         data = data[start: start + self.output_size]
+         if self.annotation:
+             label = label[start: start + self.output_size]
+
+         return {'data': data, 'label': label, "ID": ID}
+
+
+ class CenterCrop(object):
+     """
+     Center-crop the time series in a sample (deprecated).
+     """
+
+     def __init__(self, output_size, annotation=False):
+         self.output_size = output_size
+         self.annotation = annotation
+
+     def __call__(self, sample):
+         data, label, ID = sample['data'], sample['label'], sample['ID']
+
+         timesteps = len(data)
+
+         start = (timesteps - self.output_size) // 2
+
+         data = data[start: start + self.output_size]
+         if self.annotation:
+             label = label[start: start + self.output_size]
+
+         return {'data': data, 'label': label, "ID": ID}
+
+
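For illustration, the cropping above can be sketched on a dummy record; `center_crop` is a hypothetical standalone re-implementation of the `CenterCrop` logic (the actual classes operate on the `{'data', 'label', 'ID'}` sample dicts):

```python
import numpy as np

# Hypothetical standalone version of the CenterCrop logic above.
def center_crop(data, output_size):
    start = (len(data) - output_size) // 2
    return data[start:start + output_size]

sig = np.arange(1000).reshape(500, 2)  # 500 timesteps, 2 channels
cropped = center_crop(sig, 250)
print(cropped.shape)  # (250, 2)
```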
+ class GaussianNoise(object):
+     """
+     Add gaussian noise to sample.
+     """
+
+     def __init__(self, scale=0.1):
+         self.scale = scale
+
+     def __call__(self, sample):
+         if self.scale == 0:
+             return sample
+         else:
+             data, label, ID = sample['data'], sample['label'], sample['ID']
+             data = data + np.reshape(np.array([random.gauss(0, self.scale) for _ in range(np.prod(data.shape))]),
+                                      data.shape)  # np.random.normal(scale=self.scale, size=data.shape).astype(np.float32)
+             return {'data': data, 'label': label, "ID": ID}
+
+
+ class Rescale(object):
+     """
+     Rescale by factor.
+     """
+
+     def __init__(self, scale=0.5, interpolation_order=3):
+         self.scale = scale
+         self.interpolation_order = interpolation_order
+
+     def __call__(self, sample):
+         if self.scale == 1:
+             return sample
+         else:
+             data, label, ID = sample['data'], sample['label'], sample['ID']
+             timesteps_new = int(self.scale * len(data))
+             data = transform.resize(data, (timesteps_new, data.shape[1]), order=self.interpolation_order).astype(
+                 np.float32)
+             return {'data': data, 'label': label, "ID": ID}
+
+
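The per-element `random.gauss` loop in `GaussianNoise` is equivalent to the vectorized call mentioned in its trailing comment; a minimal sketch:

```python
import numpy as np

# Vectorized equivalent of the per-element random.gauss loop in GaussianNoise.
data = np.zeros((1000, 12), dtype=np.float32)
noisy = data + np.random.normal(scale=0.1, size=data.shape).astype(np.float32)
print(noisy.shape)  # (1000, 12)
```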
+ class ToTensor(object):
+     """Convert ndarrays in sample to Tensors."""
+
+     def __init__(self, transpose_data1d=True):
+         self.transpose_data1d = transpose_data1d
+
+     def __call__(self, sample):
+         def _to_tensor(data, transpose_data1d=False):
+             if len(data.shape) == 2 and transpose_data1d is True:
+                 # swap channel and time axis for direct application of pytorch's 1d convs
+                 data = data.transpose((1, 0))
+             if isinstance(data, np.ndarray):
+                 return torch.from_numpy(data)
+             else:  # default_collate will take care of it
+                 return data
+
+         data, label, ID = sample['data'], sample['label'], sample['ID']
+
+         if not isinstance(data, tuple):
+             data = _to_tensor(data, self.transpose_data1d)
+         else:
+             data = tuple(_to_tensor(x, self.transpose_data1d) for x in data)
+
+         if not isinstance(label, tuple):
+             label = _to_tensor(label)
+         else:
+             label = tuple(_to_tensor(x) for x in label)
+
+         return data, label  # returning as a tuple (potentially of lists)
+
+
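The axis swap in `ToTensor` moves the channel axis first, i.e. `(timesteps, channels)` becomes `(channels, timesteps)`, which is the layout pytorch's `Conv1d` expects; sketched here with plain numpy:

```python
import numpy as np

# (timesteps, channels) -> (channels, timesteps), as done inside ToTensor.
data = np.zeros((1000, 12), dtype=np.float32)
print(data.transpose((1, 0)).shape)  # (12, 1000)
```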
+ class Normalize(object):
+     """
+     Normalize using given stats.
+     """
+
+     def __init__(self, stats_mean, stats_std, input=True, channels=None):
+         if channels is None:
+             channels = []
+         self.stats_mean = np.expand_dims(stats_mean.astype(np.float32), axis=0) if stats_mean is not None else None
+         self.stats_std = np.expand_dims(stats_std.astype(np.float32), axis=0) + 1e-8 if stats_std is not None else None
+         self.input = input
+         if len(channels) > 0:
+             for i in range(len(stats_mean)):
+                 if i not in channels:
+                     self.stats_mean[:, i] = 0
+                     self.stats_std[:, i] = 1
+
+     def __call__(self, sample):
+         if self.input:
+             data = sample['data']
+         else:
+             data = sample['label']
+
+         if self.stats_mean is not None:
+             data = data - self.stats_mean
+         if self.stats_std is not None:
+             data = data / self.stats_std
+
+         if self.input:
+             return {'data': data, 'label': sample['label'], "ID": sample['ID']}
+         else:
+             return {'data': sample['data'], 'label': data, "ID": sample['ID']}
+
+
+ class ButterFilter(object):
+     """
+     Apply a Butterworth filter to the signal.
+     """
+
+     def __init__(self, lowcut=50, highcut=50, fs=100, order=5, btype='band', forwardbackward=True, input=True):
+         self.filter = butter_filter(lowcut, highcut, fs, order, btype)
+         self.input = input
+         self.forwardbackward = forwardbackward
+
+     def __call__(self, sample):
+         if self.input:
+             data = sample['data']
+         else:
+             data = sample['label']
+
+         # check multiple axis
+         if self.forwardbackward:
+             data = sosfiltfilt(self.filter, data, axis=0)
+         else:
+             data = sosfilt(self.filter, data, axis=0)
+
+         if self.input:
+             return {'data': data, 'label': sample['label'], "ID": sample['ID']}
+         else:
+             return {'data': sample['data'], 'label': data, "ID": sample['ID']}
+
+
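A minimal sketch of the filtering performed here, assuming the `butter_filter` helper (defined elsewhere in this file and not shown) wraps `scipy.signal.butter` with `output='sos'`:

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

# Assumed equivalent of butter_filter(0.5, 40, 100, 5, 'band') in SOS form.
sos = butter(5, [0.5, 40.0], btype='band', fs=100, output='sos')
sig = np.random.randn(1000, 12)           # 10 s of 12-lead ECG at 100 Hz
filtered = sosfiltfilt(sos, sig, axis=0)  # zero-phase forward-backward filtering
print(filtered.shape)  # (1000, 12)
```

The forward-backward pass (`sosfiltfilt`) doubles the effective filter order but removes phase distortion, which is why it is the default here.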
+ class ChannelFilter(object):
+     """
+     Select certain channels.
+     """
+
+     def __init__(self, channels=None, input=True):
+         if channels is None:
+             channels = [0]
+         self.channels = channels
+         self.input = input
+
+     def __call__(self, sample):
+         if self.input:
+             return {'data': sample['data'][:, self.channels], 'label': sample['label'], "ID": sample['ID']}
+         else:
+             return {'data': sample['data'], 'label': sample['label'][:, self.channels], "ID": sample['ID']}
+
+
+ class Transform(object):
+     """
+     Transforms the sample using a given function, i.e. data_new = func(data) if input is True, else label_new = func(label).
+     """
+
+     def __init__(self, func, input=False):
+         self.func = func
+         self.input = input
+
+     def __call__(self, sample):
+         if self.input:
+             return {'data': self.func(sample['data']), 'label': sample['label'], "ID": sample['ID']}
+         else:
+             return {'data': sample['data'], 'label': self.func(sample['label']), "ID": sample['ID']}
+
+
+ class TupleTransform(object):
+     """
+     Transforms the sample using a given function operating on both data and label and returning a tuple, i.e. data_new, label_new = func(data_old, label_old).
+     """
+
+     def __init__(self, func, input=False):
+         self.func = func
+
+     def __call__(self, sample):
+         data_new, label_new = self.func(sample['data'], sample['label'])
+         return {'data': data_new, 'label': label_new, "ID": sample['ID']}
+
+
+ # MIL and ensemble models
+
+ def aggregate_predictions(preds, targs=None, idmap=None, aggregate_fn=np.mean, verbose=True):
+     """
+     Aggregates potentially multiple predictions per sample (can also pass targs for convenience).
+     idmap: idmap as returned by TimeSeriesCropsDataset's get_id_mapping
+     preds: ordered predictions as returned by learn.get_preds()
+     aggregate_fn: function that is used to aggregate multiple predictions per sample (most commonly np.amax or np.mean)
+     """
+     if idmap is not None and len(idmap) != len(np.unique(idmap)):
+         if verbose:
+             print("aggregating predictions...")
+         preds_aggregated = []
+         targs_aggregated = []
+         for i in np.unique(idmap):
+             preds_local = preds[np.where(idmap == i)[0]]
+             preds_aggregated.append(aggregate_fn(preds_local, axis=0))
+             if targs is not None:
+                 targs_local = targs[np.where(idmap == i)[0]]
+                 assert np.all(targs_local == targs_local[0])  # all labels have to agree
+                 targs_aggregated.append(targs_local[0])
+         if targs is None:
+             return np.array(preds_aggregated)
+         else:
+             return np.array(preds_aggregated), np.array(targs_aggregated)
+     else:
+         if targs is None:
+             return preds
+         else:
+             return preds, targs
+
+
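The aggregation step reduces crop-level predictions to record-level predictions; a toy sketch with `np.mean` as the aggregate:

```python
import numpy as np

# Three crop-level predictions; the first two crops belong to record 0.
idmap = np.array([0, 0, 1])
preds = np.array([[0.2, 0.8],
                  [0.4, 0.6],
                  [0.9, 0.1]])
agg = np.array([np.mean(preds[idmap == i], axis=0) for i in np.unique(idmap)])
print(agg)  # [[0.3 0.7] [0.9 0.1]]
```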
+ class milwrapper(nn.Module):
+     def __init__(self, model, input_size, n, stride=None, softmax=True):
+         super().__init__()
+         self.n = n
+         self.input_size = input_size
+         self.model = model
+         self.softmax = softmax
+         self.stride = input_size if stride is None else stride
+
+     def forward(self, x):
+         # x: bs, ch, seq
+         for i in range(self.n):
+             pred_single = self.model(x[:, :, i * self.stride:i * self.stride + self.input_size])
+             pred_single = nn.functional.softmax(pred_single, dim=1)
+             if i == 0:
+                 pred = pred_single
+             else:
+                 pred += pred_single
+         return pred / self.n
+
+
+ class ensemblewrapper(nn.Module):
+     def __init__(self, model, checkpts):
+         super().__init__()
+         self.model = model
+         self.checkpts = checkpts
+
+     def forward(self, x):
+         # x: bs, ch, seq
+         for i, c in enumerate(self.checkpts):
+             state = torch.load(Path("./models/") / f'{c}.pth', map_location=x.device)
+             self.model.load_state_dict(state['model'], strict=True)
+
+             pred_single = self.model(x)
+             pred_single = nn.functional.softmax(pred_single, dim=1)
+             if i == 0:
+                 pred = pred_single
+             else:
+                 pred += pred_single
+         return pred / len(self.checkpts)
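The MIL wrapper averages the model's softmax outputs over `n` sliding windows of the input. A numpy sketch of that sliding-window averaging, with a stand-in "model" that simply scores each window by its mean:

```python
import numpy as np

# Stand-in for milwrapper.forward: average window-level scores over n windows.
def mil_predict(x, model, input_size, n, stride):
    preds = [model(x[:, :, i * stride:i * stride + input_size]) for i in range(n)]
    return np.mean(preds, axis=0)

x = np.arange(2 * 3 * 8, dtype=float).reshape(2, 3, 8)  # bs, ch, seq
model = lambda w: w.mean(axis=(1, 2))                   # toy per-window score
out = mil_predict(x, model, input_size=4, n=2, stride=4)
print(out.shape)  # (2,)
```

With equal-sized windows that tile the sequence, the average of window means equals the overall mean, which makes the sketch easy to check.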
utilities/utils.py ADDED
@@ -0,0 +1,509 @@
+ import os
+ import glob
+ import pickle
+ import pandas as pd
+ import numpy as np
+ from tqdm import tqdm
+ import wfdb
+ import ast
+ from sklearn.metrics import roc_auc_score, roc_curve
+ from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer
+
+
+ # EVALUATION STUFF
+ def generate_results(idxs, y_true, y_pred, thresholds):
+     return evaluate_experiment(y_true[idxs], y_pred[idxs], thresholds)
+
+
+ def evaluate_experiment(y_true, y_pred, thresholds=None):
+     results = {}
+
+     if thresholds is not None:
+         # binary predictions
+         y_pred_binary = apply_thresholds(y_pred, thresholds)
+         # PhysioNet/CinC Challenges metrics
+         challenge_scores = challenge_metrics(y_true, y_pred_binary, beta1=2, beta2=2)
+         results['F_beta_macro'] = challenge_scores['F_beta_macro']
+         results['G_beta_macro'] = challenge_scores['G_beta_macro']
+         results['TP'] = challenge_scores['TP']
+         results['TN'] = challenge_scores['TN']
+         results['FP'] = challenge_scores['FP']
+         results['FN'] = challenge_scores['FN']
+         results['Accuracy'] = challenge_scores['Accuracy']
+         results['F1'] = challenge_scores['F1']
+         results['Precision'] = challenge_scores['Precision']
+         results['Recall'] = challenge_scores['Recall']
+
+     # label-based metric
+     results['macro_auc'] = roc_auc_score(y_true, y_pred, average='macro')
+
+     df_result = pd.DataFrame(results, index=[0])
+     return df_result
+
+
+ def challenge_metrics(y_true, y_pred, beta1=2, beta2=2, single=False):
+     f_beta = 0
+     g_beta = 0
+     TP, FP, TN, FN = 0., 0., 0., 0.
+     Accuracy = 0
+     Precision = 0
+     Recall = 0
+     F1 = 0
+
+     if single:  # if evaluating a single class in case of threshold optimization
+         sample_weights = np.ones(y_true.sum(axis=1).shape)
+     else:
+         sample_weights = y_true.sum(axis=1)
+     for classi in range(y_true.shape[1]):
+         y_truei, y_predi = y_true[:, classi], y_pred[:, classi]
+         TP, FP, TN, FN = 0., 0., 0., 0.
+         for i in range(len(y_predi)):
+             sample_weight = sample_weights[i]
+             if y_truei[i] == y_predi[i] == 1:
+                 TP += 1. / sample_weight
+             if (y_predi[i] == 1) and (y_truei[i] != y_predi[i]):
+                 FP += 1. / sample_weight
+             if y_truei[i] == y_predi[i] == 0:
+                 TN += 1. / sample_weight
+             if (y_predi[i] == 0) and (y_truei[i] != y_predi[i]):
+                 FN += 1. / sample_weight
+         f_beta_i = ((1 + beta1 ** 2) * TP) / ((1 + beta1 ** 2) * TP + FP + (beta1 ** 2) * FN)
+         g_beta_i = TP / (TP + FP + beta2 * FN)
+
+         f_beta += f_beta_i
+         g_beta += g_beta_i
+
+     # note: TP/FP/TN/FN and the derived metrics below refer to the last class of the loop above
+     Accuracy = (TP + TN) / (FP + TP + TN + FN)
+     Precision = TP / (TP + FP) if (TP + FP) > 0 else 0
+     Recall = TP / (TP + FN) if (TP + FN) > 0 else 0
+     F1 = 2 * TP / (2 * TP + FP + FN)  # parentheses required: 2*TP/2*TP + FP + FN evaluates to TP + FP + FN
+
+     return {'F_beta_macro': f_beta / y_true.shape[1], 'G_beta_macro': g_beta / y_true.shape[1], 'TP': TP, 'FP': FP,
+             'TN': TN, 'FN': FN, 'Accuracy': Accuracy, 'F1': F1, 'Precision': Precision, 'Recall': Recall}
+
+
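For reference, F1 = 2·TP / (2·TP + FP + FN); the parentheses matter, since `2 * TP / 2 * TP + FP + FN` evaluates to TP + FP + FN. A quick worked check:

```python
# With TP=8, FP=2, FN=4: F1 = 2*8 / (2*8 + 2 + 4) = 16/22
TP, FP, FN = 8.0, 2.0, 4.0
f1 = 2 * TP / (2 * TP + FP + FN)
print(round(f1, 3))  # 0.727
```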
+ def get_appropriate_bootstrap_samples(y_true, n_bootstraping_samples):
+     samples = []
+     while True:
+         ridxs = np.random.randint(0, len(y_true), len(y_true))
+         if y_true[ridxs].sum(axis=0).min() != 0:
+             samples.append(ridxs)
+             if len(samples) == n_bootstraping_samples:
+                 break
+     return samples
+
+
+ def find_optimal_cutoff_threshold(target, predicted):
+     """
+     Find the optimal probability cutoff point for a classification model related to the event rate.
+     """
+     fpr, tpr, threshold = roc_curve(target, predicted)
+     optimal_idx = np.argmax(tpr - fpr)
+     optimal_threshold = threshold[optimal_idx]
+     return optimal_threshold
+
+
+ def find_optimal_cutoff_thresholds(y_true, y_pred):
+     return [find_optimal_cutoff_threshold(y_true[:, i], y_pred[:, i]) for i in range(y_true.shape[1])]
+
+
+ def find_optimal_cutoff_threshold_for_Gbeta(target, predicted, n_thresholds=100):
+     thresholds = np.linspace(0.00, 1, n_thresholds)
+     scores = [challenge_metrics(target, predicted > t, single=True)['G_beta_macro'] for t in thresholds]
+     optimal_idx = np.argmax(scores)
+     return thresholds[optimal_idx]
+
+
+ def find_optimal_cutoff_thresholds_for_Gbeta(y_true, y_pred):
+     print("optimize thresholds with respect to G_beta")
+     return [
+         find_optimal_cutoff_threshold_for_Gbeta(y_true[:, k][:, np.newaxis], y_pred[:, k][:, np.newaxis])
+         for k in tqdm(range(y_true.shape[1]))]
+
+
+ def apply_thresholds(preds, thresholds):
+     """
+     Apply class-wise thresholds to prediction scores in order to get a binary format.
+     BUT: if no score is above threshold, pick the maximum. This is needed due to metric issues.
+     """
+     tmp = []
+     for p in preds:
+         tmp_p = (p > thresholds).astype(int)
+         if np.sum(tmp_p) == 0:
+             tmp_p[np.argmax(p)] = 1
+         tmp.append(tmp_p)
+     tmp = np.array(tmp)
+     return tmp
+
+
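The thresholding logic in action on a toy score matrix; the last row has no score above its threshold, so its arg-max class is forced to 1:

```python
import numpy as np

# Same logic as apply_thresholds above, on three samples and three classes.
preds = np.array([[0.9, 0.2, 0.1],
                  [0.1, 0.7, 0.6],
                  [0.3, 0.2, 0.1]])
thresholds = np.array([0.5, 0.5, 0.5])
binary = (preds > thresholds).astype(int)
for row, p in zip(binary, preds):  # rows are views into binary
    if row.sum() == 0:
        row[np.argmax(p)] = 1
print(binary)
```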
+ # DATA PROCESSING STUFF
+ def load_dataset(path, sampling_rate, release=False):
+     if path.split('/')[-2] == 'ptbxl':
+         # load and convert annotation data
+         Y = pd.read_csv(path + 'ptbxl_database.csv', index_col='ecg_id')
+         Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
+
+         # load raw signal data
+         X = load_raw_data_ptbxl(Y, sampling_rate, path)
+
+     elif path.split('/')[-2] == 'ICBEB':
+         # load and convert annotation data
+         Y = pd.read_csv(path + 'icbeb_database.csv', index_col='ecg_id')
+         Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
+
+         # load raw signal data
+         X = load_raw_data_icbeb(Y, sampling_rate, path)
+
+     return X, Y
+
+
+ def load_raw_data_icbeb(df, sampling_rate, path):
+     if sampling_rate == 100:
+         if os.path.exists(path + 'raw100.npy'):
+             data = np.load(path + 'raw100.npy', allow_pickle=True)
+         else:
+             data = [wfdb.rdsamp(path + 'records100/' + str(f)) for f in tqdm(df.index)]
+             data = np.array([signal for signal, meta in data])
+             pickle.dump(data, open(path + 'raw100.npy', 'wb'), protocol=4)
+     elif sampling_rate == 500:
+         if os.path.exists(path + 'raw500.npy'):
+             data = np.load(path + 'raw500.npy', allow_pickle=True)
+         else:
+             data = [wfdb.rdsamp(path + 'records500/' + str(f)) for f in tqdm(df.index)]
+             data = np.array([signal for signal, meta in data])
+             pickle.dump(data, open(path + 'raw500.npy', 'wb'), protocol=4)
+     return data
+
+
+ def load_raw_data_ptbxl(df, sampling_rate, path):
+     if sampling_rate == 100:
+         if os.path.exists(path + 'raw100.npy'):
+             data = np.load(path + 'raw100.npy', allow_pickle=True)
+         else:
+             data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_lr)]
+             data = np.array([signal for signal, meta in data])
+             pickle.dump(data, open(path + 'raw100.npy', 'wb'), protocol=4)
+     elif sampling_rate == 500:
+         if os.path.exists(path + 'raw500.npy'):
+             data = np.load(path + 'raw500.npy', allow_pickle=True)
+         else:
+             data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_hr)]
+             data = np.array([signal for signal, meta in data])
+             pickle.dump(data, open(path + 'raw500.npy', 'wb'), protocol=4)
+     return data
+
+
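The `scp_codes` column of the PTB-XL database is stored as a string representation of a dict; `ast.literal_eval` converts it back, as done in `load_dataset` above. A minimal sketch (the values shown are a made-up example record):

```python
import ast
import pandas as pd

# scp_codes arrives from the CSV as a string such as "{'NORM': 100.0, 'SR': 0.0}".
df = pd.DataFrame({'scp_codes': ["{'NORM': 100.0, 'SR': 0.0}"]})
df.scp_codes = df.scp_codes.apply(ast.literal_eval)
print(df.scp_codes.iloc[0])  # {'NORM': 100.0, 'SR': 0.0}
```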
+ def compute_label_aggregations(df, folder, ctype):
+     df['scp_codes_len'] = df.scp_codes.apply(lambda x: len(x))
+
+     aggregation_df = pd.read_csv(folder + 'scp_statements.csv', index_col=0)
+
+     if ctype in ['diagnostic', 'subdiagnostic', 'superdiagnostic']:
+
+         def aggregate_all_diagnostic(y_dic):
+             tmp = []
+             for key in y_dic.keys():
+                 if key in diag_agg_df.index:
+                     tmp.append(key)
+             return list(set(tmp))
+
+         def aggregate_subdiagnostic(y_dic):
+             tmp = []
+             for key in y_dic.keys():
+                 if key in diag_agg_df.index:
+                     c = diag_agg_df.loc[key].diagnostic_subclass
+                     if str(c) != 'nan':
+                         tmp.append(c)
+             return list(set(tmp))
+
+         def aggregate_diagnostic(y_dic):
+             tmp = []
+             for key in y_dic.keys():
+                 if key in diag_agg_df.index:
+                     c = diag_agg_df.loc[key].diagnostic_class
+                     if str(c) != 'nan':
+                         tmp.append(c)
+             return list(set(tmp))
+
+         diag_agg_df = aggregation_df[aggregation_df.diagnostic == 1.0]
+         if ctype == 'diagnostic':
+             df['diagnostic'] = df.scp_codes.apply(aggregate_all_diagnostic)
+             df['diagnostic_len'] = df.diagnostic.apply(lambda x: len(x))
+         elif ctype == 'subdiagnostic':
+             df['subdiagnostic'] = df.scp_codes.apply(aggregate_subdiagnostic)
+             df['subdiagnostic_len'] = df.subdiagnostic.apply(lambda x: len(x))
+         elif ctype == 'superdiagnostic':
+             df['superdiagnostic'] = df.scp_codes.apply(aggregate_diagnostic)
+             df['superdiagnostic_len'] = df.superdiagnostic.apply(lambda x: len(x))
+     elif ctype == 'form':
+         form_agg_df = aggregation_df[aggregation_df.form == 1.0]
+
+         def aggregate_form(y_dic):
+             tmp = []
+             for key in y_dic.keys():
+                 if key in form_agg_df.index:
+                     c = key
+                     if str(c) != 'nan':
+                         tmp.append(c)
+             return list(set(tmp))
+
+         df['form'] = df.scp_codes.apply(aggregate_form)
+         df['form_len'] = df.form.apply(lambda x: len(x))
+     elif ctype == 'rhythm':
+         rhythm_agg_df = aggregation_df[aggregation_df.rhythm == 1.0]
+
+         def aggregate_rhythm(y_dic):
+             tmp = []
+             for key in y_dic.keys():
+                 if key in rhythm_agg_df.index:
+                     c = key
+                     if str(c) != 'nan':
+                         tmp.append(c)
+             return list(set(tmp))
+
+         df['rhythm'] = df.scp_codes.apply(aggregate_rhythm)
+         df['rhythm_len'] = df.rhythm.apply(lambda x: len(x))
+     elif ctype == 'all':
+         df['all_scp'] = df.scp_codes.apply(lambda x: list(set(x.keys())))
+
+     return df
+
+
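A toy illustration of the superdiagnostic aggregation above: each SCP code in a record is mapped to its `diagnostic_class` via the `scp_statements.csv` table, and non-diagnostic codes are dropped. The two-row table below is a made-up stand-in for that file:

```python
import pandas as pd

# Hypothetical two-row stand-in for scp_statements.csv (diagnostic rows only).
agg = pd.DataFrame({'diagnostic_class': ['MI', 'NORM']}, index=['IMI', 'NORM'])
scp_codes = {'IMI': 100.0, 'ABQRS': 0.0}  # ABQRS is not in the diagnostic table
labels = sorted({agg.loc[k].diagnostic_class for k in scp_codes if k in agg.index})
print(labels)  # ['MI']
```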
+ def select_data(XX, YY, ctype, min_samples, output_folder):
+     # convert multi-label to multi-hot
+     mlb = MultiLabelBinarizer()
+
+     if ctype == 'diagnostic':
+         X = XX[YY.diagnostic_len > 0]
+         Y = YY[YY.diagnostic_len > 0]
+         mlb.fit(Y.diagnostic.values)
+         y = mlb.transform(Y.diagnostic.values)
+     elif ctype == 'subdiagnostic':
+         counts = pd.Series(np.concatenate(YY.subdiagnostic.values)).value_counts()
+         counts = counts[counts > min_samples]
+         YY.subdiagnostic = YY.subdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
+         YY['subdiagnostic_len'] = YY.subdiagnostic.apply(lambda x: len(x))
+         X = XX[YY.subdiagnostic_len > 0]
+         Y = YY[YY.subdiagnostic_len > 0]
+         mlb.fit(Y.subdiagnostic.values)
+         y = mlb.transform(Y.subdiagnostic.values)
+     elif ctype == 'superdiagnostic':
+         counts = pd.Series(np.concatenate(YY.superdiagnostic.values)).value_counts()
+         counts = counts[counts > min_samples]
+         YY.superdiagnostic = YY.superdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
+         YY['superdiagnostic_len'] = YY.superdiagnostic.apply(lambda x: len(x))
+         X = XX[YY.superdiagnostic_len > 0]
+         Y = YY[YY.superdiagnostic_len > 0]
+         mlb.fit(Y.superdiagnostic.values)
+         y = mlb.transform(Y.superdiagnostic.values)
+     elif ctype == 'form':
+         # filter
+         counts = pd.Series(np.concatenate(YY.form.values)).value_counts()
+         counts = counts[counts > min_samples]
+         YY.form = YY.form.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
+         YY['form_len'] = YY.form.apply(lambda x: len(x))
+         # select
+         X = XX[YY.form_len > 0]
+         Y = YY[YY.form_len > 0]
+         mlb.fit(Y.form.values)
+         y = mlb.transform(Y.form.values)
+     elif ctype == 'rhythm':
+         # filter
+         counts = pd.Series(np.concatenate(YY.rhythm.values)).value_counts()
+         counts = counts[counts > min_samples]
+         YY.rhythm = YY.rhythm.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
+         YY['rhythm_len'] = YY.rhythm.apply(lambda x: len(x))
+         # select
+         X = XX[YY.rhythm_len > 0]
+         Y = YY[YY.rhythm_len > 0]
+         mlb.fit(Y.rhythm.values)
+         y = mlb.transform(Y.rhythm.values)
+     elif ctype == 'all':
+         # filter
+         counts = pd.Series(np.concatenate(YY.all_scp.values)).value_counts()
+         counts = counts[counts > min_samples]
+         YY.all_scp = YY.all_scp.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
+         YY['all_scp_len'] = YY.all_scp.apply(lambda x: len(x))
+         # select
+         X = XX[YY.all_scp_len > 0]
+         Y = YY[YY.all_scp_len > 0]
+         mlb.fit(Y.all_scp.values)
+         y = mlb.transform(Y.all_scp.values)
+     else:
+         raise ValueError(f'Unknown ctype: {ctype}')  # previously a silent pass that led to a NameError below
+
+     # save LabelBinarizer
+     with open(output_folder + 'mlb.pkl', 'wb') as tokenizer:
+         pickle.dump(mlb, tokenizer)
+
+     return X, Y, y, mlb
+
+
+ def preprocess_signals(X_train, X_validation, X_test, outputfolder):
+     # standardize data such that mean is 0 and variance is 1
+     ss = StandardScaler()
+     ss.fit(np.vstack(X_train).flatten()[:, np.newaxis].astype(float))
+
+     # save the StandardScaler
+     with open(outputfolder + 'standard_scaler.pkl', 'wb') as ss_file:
+         pickle.dump(ss, ss_file)
+
+     return apply_standardizer(X_train, ss), apply_standardizer(X_validation, ss), apply_standardizer(X_test, ss)
+
+
+ def apply_standardizer(X, ss):
+     X_tmp = []
+     for x in X:
+         x_shape = x.shape
+         X_tmp.append(ss.transform(x.flatten()[:, np.newaxis]).reshape(x_shape))
+     X_tmp = np.array(X_tmp)
+     return X_tmp
+
+
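The standardization above fits one global scaler on every training value flattened into a single column, then applies it record by record; a self-contained sketch:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Mirror of preprocess_signals / apply_standardizer on synthetic records.
rng = np.random.RandomState(0)
X_train = [rng.randn(100, 12) * 3 + 5 for _ in range(4)]  # 4 records, 12 leads
ss = StandardScaler().fit(np.vstack(X_train).flatten()[:, np.newaxis])
X_std = np.array([ss.transform(x.flatten()[:, np.newaxis]).reshape(x.shape)
                  for x in X_train])
print(abs(X_std.mean()) < 1e-6, abs(X_std.std() - 1) < 1e-6)  # True True
```

Fitting on the flattened training set means all leads share one mean and one variance, rather than being standardized per channel.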
+ # DOCUMENTATION STUFF
+
+ def generate_ptbxl_summary_table(selection=None, folder='/output/'):
+     exps = ['exp0', 'exp1', 'exp1.1', 'exp1.1.1', 'exp2', 'exp3']
+     metrics = ['macro_auc', 'Accuracy', 'TP', 'TN', 'FP', 'FN', 'Precision', 'Recall', 'F1']
+     # indices:      0           1        2     3     4     5        6           7      8
+
+     # get models
+     models = {}
+     for i, exp in enumerate(exps):
+         if selection is None:
+             exp_models = [m.split('/')[-1] for m in glob.glob(folder + str(exp) + '/models/*')]
+         else:
+             exp_models = selection
+         if i == 0:
+             models = set(exp_models)
+         else:
+             models = models.union(set(exp_models))
+
+     # one column per (experiment, metric) pair
+     results_dic = {'Method': []}
+     for metric in ['macro_auc', 'Accuracy', 'F1', 'Precision', 'Recall', 'TP', 'TN', 'FP', 'FN']:
+         for e in exps:
+             results_dic[e + '_' + metric] = []
+
+     for m in models:
+         results_dic['Method'].append(m)
+
+         for e in exps:
+             try:
+                 me_res = pd.read_csv(folder + str(e) + '/models/' + str(m) + '/results/te_results.csv', index_col=0)
+
+                 mean1 = me_res.loc['point'][metrics[0]]
+                 unc1 = max(me_res.loc['upper'][metrics[0]] - me_res.loc['point'][metrics[0]],
+                            me_res.loc['point'][metrics[0]] - me_res.loc['lower'][metrics[0]])
+
+                 acc = me_res.loc['point'][metrics[1]]
+                 f1 = me_res.loc['point'][metrics[8]]
+                 precision = me_res.loc['point'][metrics[6]]
+                 recall = me_res.loc['point'][metrics[7]]
+                 tp = me_res.loc['point'][metrics[2]]
+                 tn = me_res.loc['point'][metrics[3]]
+                 fp = me_res.loc['point'][metrics[4]]
+                 fn = me_res.loc['point'][metrics[5]]
+
+                 results_dic[e + '_macro_auc'].append("%.3f(%.2d)" % (np.round(mean1, 3), int(unc1 * 1000)))
+                 results_dic[e + '_Accuracy'].append("%.3f" % acc)
+                 results_dic[e + '_F1'].append("%.3f" % f1)
+                 results_dic[e + '_Precision'].append("%.3f" % precision)
+                 results_dic[e + '_Recall'].append("%.3f" % recall)
+                 results_dic[e + '_TP'].append("%.3f" % tp)
+                 results_dic[e + '_TN'].append("%.3f" % tn)
+                 results_dic[e + '_FP'].append("%.3f" % fp)
+                 results_dic[e + '_FN'].append("%.3f" % fn)
+
+             except FileNotFoundError:
+                 for metric in ['macro_auc', 'Accuracy', 'F1', 'Precision', 'Recall', 'TP', 'TN', 'FP', 'FN']:
+                     results_dic[e + '_' + metric].append("--")
+
+     df = pd.DataFrame(results_dic)
+     df_index = df[df.Method.isin(['naive', 'ensemble'])]
+     df_rest = df[~df.Method.isin(['naive', 'ensemble'])]
+     df = pd.concat([df_rest, df_index])
+     df.to_csv(folder + 'results_ptbxl.csv')
+
+     titles = [
+         '### 1. PTB-XL: all statements',
+         '### 2. PTB-XL: diagnostic statements',
+         '### 3. PTB-XL: Diagnostic subclasses',
+         '### 4. PTB-XL: Diagnostic superclasses',
+         '### 5. PTB-XL: Form statements',
+         '### 6. PTB-XL: Rhythm statements'
+     ]
+
+     # helper output function for markdown tables
+     our_work = 'https://arxiv.org/abs/2004.13701'
+     our_repo = 'https://github.com/helme/ecg_ptbxl_benchmarking/'
+     md_source = ''
+     for i, e in enumerate(exps):
+         md_source += '\n ' + titles[i] + ' \n \n'
+         md_source += '| Model | AUC |\n'
+         md_source += '|---|---|\n'
+
+         # column keys are e + '_macro_auc' (fixed: previously e + '_AUC', which raised a KeyError)
+         for row in df_rest[['Method', e + '_macro_auc']].sort_values(e + '_macro_auc', ascending=False).values:
+             md_source += '| ' + row[0].replace('fastai_', '') + ' | ' + row[1] + ' |\n'
+     print(md_source)