sohamc10 committed
Commit 17ef2b0 · 1 Parent(s): 06789d6
Files changed (49)
  1. .gitignore +1 -0
  2. LICENSE +21 -0
  3. README.md +322 -0
  4. app.py +73 -0
  5. config.py +28 -0
  6. data_util/audioset_classes.py +1393 -0
  7. data_util/audioset_strong.py +329 -0
  8. data_util/dcase2016task2.py +280 -0
  9. data_util/transforms.py +195 -0
  10. ex_audioset_strong.py +504 -0
  11. ex_dcase2016task2.py +517 -0
  12. helpers/augment.py +225 -0
  13. helpers/decode.py +72 -0
  14. helpers/encode.py +230 -0
  15. helpers/score.py +384 -0
  16. helpers/utils.py +12 -0
  17. images/downstream_task_results.png +0 -0
  18. inference.py +126 -0
  19. models/asit/ASIT_wrapper.py +60 -0
  20. models/asit/data_transformations.py +29 -0
  21. models/asit/utils.py +540 -0
  22. models/asit/vision_transformer.py +316 -0
  23. models/atstframe/ATSTF_wrapper.py +105 -0
  24. models/atstframe/audio_transformer.py +253 -0
  25. models/atstframe/transformer.py +112 -0
  26. models/beats/BEATs.py +183 -0
  27. models/beats/BEATs_wrapper.py +56 -0
  28. models/beats/Tokenizers.py +172 -0
  29. models/beats/backbone.py +783 -0
  30. models/beats/modules.py +218 -0
  31. models/beats/quantizer.py +215 -0
  32. models/frame_mn/Frame_MN_wrapper.py +75 -0
  33. models/frame_mn/block_types.py +189 -0
  34. models/frame_mn/model.py +356 -0
  35. models/frame_mn/utils.py +93 -0
  36. models/frame_passt/fpasst.py +963 -0
  37. models/frame_passt/fpasst_wrapper.py +86 -0
  38. models/frame_passt/preprocess.py +147 -0
  39. models/frame_passt/vit_helpers.py +399 -0
  40. models/m2d/M2D_wrapper.py +52 -0
  41. models/m2d/portable_m2d.py +410 -0
  42. models/prediction_wrapper.py +213 -0
  43. models/seq_models.py +40 -0
  44. models/transformer_wrapper.py +19 -0
  45. requirements.txt +20 -0
  46. resources/README.md +1 -0
  47. resources/best_model_BEATs.pth +3 -0
  48. resources/eval_durations.csv +0 -0
  49. resources/labelvocabulary.csv +201 -0
.gitignore ADDED
@@ -0,0 +1 @@
__pycache__/
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Florian Schmid

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md CHANGED
@@ -11,3 +11,325 @@ license: cc-by-nc-4.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference


# Effective Pre-Training of Audio Transformers for Sound Event Detection

In this repository, we publish pre-trained models and code for the ICASSP'25 paper: [**Effective Pre-Training of Audio Transformers for Sound Event Detection**](https://arxiv.org/abs/2409.09546).

In this paper, we propose a pre-training pipeline for audio spectrogram transformers targeting frame-level sound event detection tasks. On top of common pre-training steps, we add a meticulously designed training routine on AudioSet frame-level annotations. For five transformers, we show that this additional pre-training step leads to substantial performance improvements on frame-level downstream tasks. We release all model checkpoints and hope that they will help researchers improve tasks that require high-quality frame-level representations.

This repository includes:
* All pre-trained checkpoints and model files (see [here](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1))
* A script that demonstrates how the pre-trained checkpoints can be loaded and used for inference (see [here](https://github.com/fschmid56/PretrainedSED/blob/main/inference.py))
* A table outlining the external checkpoints used in this work (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#model-checkpoints))
* The evaluation routine on the AudioSet frame-level annotations (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#run-audioset-strong-evaluation))
* The AudioSet Strong training routine (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#audioset-strong-pre-training))
* The ensemble logits for the AudioSet Strong dataset (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#download-ensemble-pseudo-labels))
* A file demonstrating how the pre-trained transformers can be fine-tuned on a downstream task (see [here](ex_dcase2016task2.py))
* **New:** two low-complexity SED models ('frame_mn10' with 3.83M parameters and 'frame_mn06' with 1.62M parameters)

## Setting up the Environment

1. If needed, create a new environment with Python 3.9 and activate it:

```bash
conda create -n ptsed python=3.9 cython
conda activate ptsed
```

2. Install a PyTorch build that suits your system. For example:

```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# or for CUDA >= 12.1
pip3 install torch torchvision torchaudio
```

3. Install the requirements:

```bash
pip3 install -r requirements.txt
```

4. Install the package for mp3 decoding:

```bash
CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip
```

## Inference

The script [inference.py](inference.py) demonstrates how to load a pre-trained model and run sound event detection on an audio file of arbitrary length.

```bash
python inference.py --cuda --model_name="BEATs" --audio_file="test_files/752547__iscence__milan_metro_coming_in_station.wav"
```

The argument `model_name` specifies the transformer used for inference; the corresponding pre-trained model checkpoint is automatically downloaded and placed in the folder [resources](resources).

The argument `audio_file` specifies the path to a single audio file. There is one [example file](test_files/752547__iscence__milan_metro_coming_in_station.wav) included. More example files can be downloaded from the [GitHub release](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1).

**Low-complexity** inference with the customized MobileNet:

```bash
python inference.py --cuda --model_name="frame_mn06" --audio_file="test_files/752547__iscence__milan_metro_coming_in_station.wav"
```
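
For programmatic use, the wrapper classes can also be loaded directly in Python. The following is a minimal sketch based on the wrapper API visible in [app.py](app.py); whether `PredictionsWrapper` resolves a checkpoint name such as `"BEATs_strong_1"` (a key from [config.py](config.py)) in exactly this way is an assumption, so treat it as a starting point rather than a reference implementation:

```python
import torch
from models.beats.BEATs_wrapper import BEATsWrapper
from models.prediction_wrapper import PredictionsWrapper

# Build the transformer and wrap it for predictions. The keyword values below
# mirror app.py; passing a checkpoint name from config.CHECKPOINT_URLS instead
# of None is an assumption about how PredictionsWrapper loads checkpoints.
beats = BEATsWrapper()
model = PredictionsWrapper(beats, checkpoint="BEATs_strong_1", head_type=None, seq_len=1)
model.eval()

waveform = torch.randn(1, 10 * 16000)  # 10 s of dummy mono audio at 16 kHz
with torch.no_grad():
    mel = model.mel_forward(waveform)  # wrappers expose mel_forward (see app.py)
    predictions = model(mel)           # output shape depends on head_type/seq_len
```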

## Model Checkpoints

The following is a list of checkpoints that we created and worked with in our paper. For external checkpoints, we provide the download link. "Checkpoint Name" refers to the respective names in our [GitHub release](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1). **All model checkpoints** are automatically downloaded by running the code, or can be manually downloaded from the [GitHub release](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1).

| Model | Pre-Training | Checkpoint Name | External Download Link | Reference |
|---|---|---|---|---|
| BEATs | SSL | BEATs_ssl.pt | [here](https://1drv.ms/u/s!AqeByhGUtINrgcpxJUNDxg4eU0r-vA?e=qezPJ5) | [[1]](https://arxiv.org/pdf/2212.09058) |
| BEATs | Weak | BEATs_weak.pt | [here](https://1drv.ms/u/s!AqeByhGUtINrgcpke6_lRSZEKD5j2Q?e=A3FpOf) | [[1]](https://arxiv.org/pdf/2212.09058) |
| BEATs | Strong | BEATs_strong_1.pt | ours | [[1]](https://arxiv.org/pdf/2212.09058) |
| ATST-Frame | SSL | ATST-F_ssl.pt | [here](https://drive.google.com/file/d/1bGJSZWlAIIJ6GL5Id5dW0PTB72DL-QDQ/view?usp=sharing) | [[2]](https://arxiv.org/pdf/2306.04186) |
| ATST-Frame | Weak | ATST-F_weak.pt | [here](https://drive.google.com/file/d/1_xb0_n3UNbUG_pH1vLHTviLfsaSfCzxz/view?usp=drive_link) | [[2]](https://arxiv.org/pdf/2306.04186) |
| ATST-Frame | Strong | ATST-F_strong_1.pt | ours | [[2]](https://arxiv.org/pdf/2306.04186) |
| fPaSST | SSL | fpasst_im.pt | [here](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth) | [[3]](https://arxiv.org/pdf/2110.05069), [[4]](https://arxiv.org/pdf/2407.12997) |
| fPaSST | Weak | fpasst_weak.pt | ours | [[3]](https://arxiv.org/pdf/2110.05069), [[4]](https://arxiv.org/pdf/2407.12997) |
| fPaSST | Strong | fpasst_strong_1.pt | ours | [[3]](https://arxiv.org/pdf/2110.05069), [[4]](https://arxiv.org/pdf/2407.12997) |
| ASiT | SSL | ASIT_ssl.pt | [here](https://drive.google.com/file/d/11eaOU40jonpYZ3u_XI-XUSSWclv8qeR7/view?usp=drive_link) | [[5]](https://arxiv.org/pdf/2211.13189) |
| ASiT | Weak | ASIT_weak.pt | ours | [[5]](https://arxiv.org/pdf/2211.13189) |
| ASiT | Strong | ASIT_strong_1.pt | ours | [[5]](https://arxiv.org/pdf/2211.13189) |
| M2D | SSL | M2D_ssl.pt | [here](https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_clap_vit_base-80x1001p16x16-240128_AS-FT_enconly.zip) | [[6]](https://arxiv.org/pdf/2406.02032) |
| M2D | Weak | M2D_weak.pt | [here](https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_clap_vit_base-80x1001p16x16-240128_AS-FT_enconly.zip) | [[6]](https://arxiv.org/pdf/2406.02032) |
| M2D | Strong | M2D_strong_1.pt | ours | [[6]](https://arxiv.org/pdf/2406.02032) |
| Customized MobileNet | Strong | frame_mn06.pt | ours | **NEW** |
| Customized MobileNet | Strong | frame_mn10.pt | ours | **NEW** |
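
For example, to fetch a single checkpoint manually into the folder the code expects (URLs and folder name as defined in [config.py](config.py)):

```python
import os
import torch
from config import CHECKPOINT_URLS, RESOURCES_FOLDER

# Download one checkpoint from the GitHub release into the resources folder.
name = "BEATs_strong_1"
os.makedirs(RESOURCES_FOLDER, exist_ok=True)
torch.hub.download_url_to_file(CHECKPOINT_URLS[name],
                               os.path.join(RESOURCES_FOLDER, name + ".pt"))
```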

## AudioSet Strong Pre-Training

### Prepare Dataset

1. Follow the steps described [here](https://github.com/kkoutini/PaSST/tree/main/audioset#experiments-on-audioset) to obtain AudioSet, encoded as mp3 files and packed into HDF5 format.

   You will end up with a directory containing three HDF5 files:
   * balanced_train_segments_mp3.hdf
   * unbalanced_train_segments_mp3.hdf
   * eval_segments_mp3.hdf

2. We use the [Huggingface datasets](https://huggingface.co/docs/datasets/index) API for fast and memory-efficient loading of the dataset. The [hf_dataset_gen/audioset_strong.py](hf_dataset_gen/audioset_strong.py) file takes the dataset from Step 1 and converts it into a Huggingface dataset.

   Adapt the paths in [hf_dataset_gen/audioset_strong.py](hf_dataset_gen/audioset_strong.py) marked as TODOs (two of them: the hdf5 path and the target path for the HF dataset).

3. Create the Huggingface dataset:
```
cd hf_dataset_gen
python audioset_strong.py
```

4. The path to the dataset is specified via an environment variable that must be set whenever you access the dataset for training or evaluation. For example, in our case, the Huggingface dataset path is:

   `/share/hel/datasets/HF_datasets/local/audioset_strong`

   and we therefore set the following environment variable:

```
export HF_DATASETS_CACHE=/share/hel/datasets/HF_datasets/cache/
```
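
A minimal sketch of loading the generated dataset afterwards (this assumes `hf_dataset_gen/audioset_strong.py` saves the dataset with `datasets.save_to_disk`; adapt both paths to your setup):

```python
import os

# Must be set before importing `datasets` (same variable as above).
os.environ["HF_DATASETS_CACHE"] = "/share/hel/datasets/HF_datasets/cache/"

from datasets import load_from_disk

# Target path chosen in hf_dataset_gen/audioset_strong.py (see step 2).
ds = load_from_disk("/share/hel/datasets/HF_datasets/local/audioset_strong")
print(ds)
```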

### Download ensemble pseudo labels

If you want to train on AudioSet Strong using Knowledge Distillation as described in the paper, you will have to download the ensemble logits from [Zenodo](https://zenodo.org/records/14626113). The HDF5 file matches filenames (YouTube IDs) with the corresponding ensemble logits; the keys are "filenames" and "strong_logits". The ensemble logits for one file have shape 447 x 250 (number of classes x timeframes at 40 ms resolution) and are stored in float16 to save space.

Check out [this code piece](https://github.com/fschmid56/PretrainedSED/blob/f62e9fb1566254766396cce0343a2de4156d3015/data_util/transforms.py#L37) if you want to see how the pseudo labels are loaded.

For training, the pseudo-label file can simply be set via the command line: `--pseudo_labels_file=<location>`
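
A minimal sketch of inspecting the downloaded file with `h5py`; the key names and the per-file logit shape are taken from above, but the exact dataset layout (one logit matrix per clip, indexable by position) is an assumption — see the linked code piece for the actual loading logic:

```python
import h5py
import numpy as np

# Placeholder path, as in the training command below.
with h5py.File("<path_to_pseudo_label_file_from_Zenodo>", "r") as f:
    filenames = f["filenames"]    # YouTube IDs
    logits = f["strong_logits"]   # float16 ensemble logits
    first = np.asarray(logits[0], dtype=np.float32)  # 447 classes x 250 timeframes (40 ms)
    probs = 1.0 / (1.0 + np.exp(-first))  # sigmoid -> per-frame class probabilities
    print(filenames[0], first.shape)
```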

### Run AudioSet Strong training

Example: train ATST-F, pre-trained on AudioSet Weak, with an RNN on top, using the balanced sampler and a wavmix augmentation probability of 1.0:

```
python ex_audioset_strong.py --model_name=ATST-F --seq_model_type=rnn --use_balanced_sampler --pretrained=weak --wavmix_p=1.0
```

Check out the results: https://api.wandb.ai/links/cp_tobi/tphswm5k

Example: train ATST-F using Knowledge Distillation:

```
python ex_audioset_strong.py --model_name=ATST-F --pretrained=weak --n_epochs=120 --wavmix_p=0.5 --freq_warp_p=0 --filter_augment_p=0 --mixstyle_p=0 --max_lr=1e-4 --distillation_loss_weight=0.9 --pseudo_labels_file=<path_to_pseudo_label_file_from_Zenodo>
```

Check out the results: https://api.wandb.ai/links/cp_tobi/2eh4cz80

### Run AudioSet Strong evaluation

Evaluate the AudioSet Strong pre-trained checkpoint of ATST-F:

```
python ex_audioset_strong.py --model_name=ATST-F --pretrained=strong --evaluate
```

If everything is set up correctly, this should give a `val/psds1_macro_averaged` of around 46.

## Fine-Tuning on Downstream Task

We demonstrate how the pre-trained transformers can be fine-tuned for downstream Sound Event Detection by applying them to [DCASE 2016 Task 2](https://dcase.community/challenge2016/task-sound-event-detection-in-synthetic-audio-results). This task focuses on detecting office sounds and is part of the [HEAR benchmark](https://hearbenchmark.com/hear-tasks.html).

### Obtain DCASE 2016 Task 2 Dataset in HEAR format

Follow the instructions on the [HEAR website](https://hearbenchmark.com/hear-tasks.html) to download the dataset at a 16 kHz sampling rate. After completing the setup, your file tree should look similar to this:
```
hear_datasets/tasks/dcase2016_task2-hear2021-full/
├── 16000
├── 48000
├── labelvocabulary.csv
├── task_metadata.json
├── test.json
├── train.json
└── valid.json
```

The `16000` folder contains audio files sampled at 16 kHz.

### Run Fine-Tuning

The main script for fine-tuning is [ex_dcase2016task2.py](ex_dcase2016task2.py).

To fine-tune the full ATST-F model, pre-trained on AudioSet Strong, with a layer-wise learning rate decay of 0.95, use the following command:

```
python ex_dcase2016task2.py --task_path=hear_datasets/tasks/dcase2016_task2-hear2021-full --model_name=ATST-F --pretrained=strong --lr_decay=0.95
```

To train only the linear prediction head on top of the frozen BEATs transformer, also pre-trained on AudioSet Strong, use this command:

```
python ex_dcase2016task2.py --task_path=hear_datasets/tasks/dcase2016_task2-hear2021-full --model_name=BEATs --pretrained=strong --transformer_frozen --max_lr=2e-1 --mixup_p=0 --wavmix_p=0 --no_adamw --weight_decay=0 --n_epochs=500
```

## Results & Ablation Studies

This section presents the main results reported [in the paper](https://arxiv.org/pdf/2409.09546), along with additional ablation studies, including teacher model performances, comparisons of different sequence models, and evaluations using the DESED baseline system setup. The additional ablation studies were requested by ICASSP'25 reviewers.

* All results are averages over three independent runs.
* For AudioSet Strong, we employ the threshold-independent PSDS1 [7] metric to ensure fine-grained temporal evaluation.


### Student Model Performances on AudioSet Strong (*from paper*)

* For the *Li et al. [2]* row, we reproduced their AudioSet Strong [training pipeline](https://github.com/Audio-WestlakeU/audiossl).
* Alongside the **Proposed Pipeline**, we include ablation studies for three settings: no KD, no RNN in teacher models, and no pre-training on AudioSet Weak (no Step 2).

| | **ATST-F** | **BEATs** | **fPaSST** | **M2D** | **ASiT** |
|---|---|---|---|---|---|
| **Li et al. [2]** | 40.9 | 36.5 | 38.7 | 36.9 | 37.0 |
| **Proposed Pipeline** | **45.8** | **46.5** | **45.4** | **46.3** | **46.2** |
| **-- without KD** | 41.8 | 44.1 | 40.7 | 41.1 | 40.9 |
| **-- without RNN** | 45.7 | 45.8 | 45.3 | 46.0 | 46.1 |
| **-- without Step 2** | 45.7 | 46.3 | 45.2 | 44.9 | **46.2** |

**Conclusions:**
* The significant performance gap to [2] stems mainly from our three design choices (KD, RNNs, Step 2), but also from improvements in the AudioSet Strong training routine, including balanced sampling and aggressive data augmentation.
* Knowledge Distillation (KD) has the most substantial impact, underlining the effectiveness of the ensemble-KD approach.
* RNNs in teacher models and pre-training on AudioSet Weak offer modest improvements but are justified by their low additional cost. Notably, they do not increase student model complexity, and AudioSet Weak checkpoints are publicly available for most transformers.

### Teacher Model Performances on AudioSet Strong (*additional results*)

* The table below shows teacher model results for each transformer.
* Column **Avg. Ind.** represents the average performance across all single models in the row.
* Column **Ensemble** represents the performance of the ensemble consisting of all models in the respective row.

| | **ATST-F** | **BEATs** | **fPaSST** | **M2D** | **ASiT** | **Avg. Ind.** | **Ensemble** |
|---|---|---|---|---|---|---|---|
| **Proposed Teacher Pipeline** | 43.3 | **45.8** | **43.3** | **44.1** | **43.3** | **44.9** | **47.1** |
| **-- without RNN** | 41.8 | 44.1 | 40.7 | 41.1 | 40.9 | 41.7 | 46.2 |
| **-- without Step 2** | **43.5** | 34.4 | 40.9 | 43.8 | 43.2 | 41.2 | 46.5 |

**Conclusions:**
* *Ensemble Performance*: The *Ensemble* column reflects the teacher ensemble performances used for Knowledge Distillation (KD) in the table above.
* *Impact of RNNs and Step 2*: Incorporating RNNs and Step 2 (AudioSet Weak pre-training) notably enhances single-model teacher performance, with the exception of ATST-F without Step 2.
* *Benefits of Ensembling*: While individual model performances vary considerably (Avg. Ind.), ensembling stabilizes and elevates overall performance, as evidenced by the smaller differences in the *Ensemble* column.
* *BEATs-Specific Insights*: BEATs excels in the *Proposed Teacher Pipeline* and *without RNN* settings but underperforms in the *without Step 2* configuration. This discrepancy may be attributed to its unique SSL pre-training routine and longer sequence length (resulting from more tokens being extracted from the input).

### Teacher Model Performances with Different Sequence Models (*additional results*)

* The use of an additional sequence model on top of the AudioSet Weak pre-trained transformers stems from our hypothesis that adding capacity specifically for temporally-strong predictions can enhance performance.
* The table below shows teacher model performances for various sequence models added on top of the transformers before training on AudioSet Strong. The paper uses BiGRUs (RNN) as they deliver the best performance; a generic sketch of such a head is shown after this list.
* We investigated 4 different sequence models:
  * RNN: BiGRUs
  * Attention: Multi-Head Self-Attention with rotary position embeddings
  * Transformer (TF): Transformer Encoder blocks with rotary position embeddings
  * [MAMBA](https://arxiv.org/abs/2312.00752): implementation from [mambapy](https://github.com/alxndrTL/mamba.py)
* We varied the inner dimension (*dim*) and the number of layers (\<Model Type\>:\<#layers\>; e.g., TF:2 means two Transformer layers were added on top of the pre-trained transformer).
* **Setup Notes**:
  * Ablations were performed using **ATST-F** due to its computational efficiency.
  * Performance without a sequence model was **41.8 PSDS1**.
  * Removing the top Transformer layers, which may overfit to AudioSet Weak labels, decreased performance.
  * For MAMBA, only a single layer was feasible due to memory constraints.

| | RNN:1 | RNN:2 | RNN:3 | TF:1 | TF:2 | TF:3 | ATT:1 | ATT:2 | ATT:3 | MAMBA:1 |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| **dim=256** | 8.72 | 3.76 | 3.10 | 34.25 | 34.62 | 34.05 | 40.08 | 39.70 | 39.55 | 40.27 |
| **dim=512** | 40.62 | 7.26 | 0.12 | 40.41 | 41.11 | 40.30 | 41.78 | 41.91 | 41.95 | 41.25 |
| **dim=1024** | 42.74 | 42.75 | 43.00 | 42.69 | 42.22 | 42.20 | 42.44 | **42.45** | 42.08 | **41.97** |
| **dim=2048** | 43.41 | **43.43** | 42.66 | **42.90** | 38.94 | **42.90** | 41.58 | 41.59 | 41.42 | 41.72 |

**Conclusions:**
* *Best model type*: The highest performance was achieved with 2 BiGRU layers, followed by Transformer, Self-Attention, and MAMBA. All sequence models improved performance compared to using no additional sequence model, though MAMBA's gains were marginal.
* *Inner dimension*: Larger inner dimensions consistently led to better performance across all sequence models. Significant improvements required dimensions ≥ 1024, while smaller dimensions (e.g., 256) often degraded performance, with severe failures for BiGRU. We believe that large inner dimensions are essential due to the high number of classes (447) in AudioSet Strong.
* *Number of layers*: Performance was relatively insensitive to the number of layers for most sequence models, with optimal results often achieved with just 1–2 layers.
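
For reference, a BiGRU sequence head of this kind can be sketched in a few lines of PyTorch. This is a generic illustration, not the code from [models/seq_models.py](models/seq_models.py), and it assumes *dim* denotes the per-direction hidden size:

```python
import torch
from torch import nn

class BiGRUHead(nn.Module):
    """Bidirectional GRU on top of frame-level transformer embeddings."""
    def __init__(self, embed_dim: int, dim: int = 2048, num_layers: int = 2, n_classes: int = 447):
        super().__init__()
        self.rnn = nn.GRU(embed_dim, dim, num_layers=num_layers,
                          batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * dim, n_classes)  # 2x: forward + backward states

    def forward(self, x):  # x: (batch, frames, embed_dim)
        y, _ = self.rnn(x)
        return self.classifier(y)  # (batch, frames, n_classes) frame-level logits
```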


### Downstream Task Performances (*from paper*)

* Three frame-level downstream tasks:
  * DCASE 2023 Task 4: Domestic Environment Sound Event Detection (*DESED*), metric: PSDS1
  * DCASE 2016 Task 2 (*DC16-T2*), metric: onset F-measure
  * MAESTRO 5hr (*MAESTRO*), metric: onset F-measure
* For DESED, we followed a simplified setup in line with [2], excluding unsupervised data (no mean teacher approach) and the additional CRNN component from the [DCASE 2023 Task 4 baseline system](https://github.com/DCASE-REPO/DESED_task/tree/master/recipes/dcase2023_task4_baseline). While state-of-the-art approaches such as [4] and [8] leverage advanced techniques (e.g., multi-stage/multi-iteration training, sophisticated data augmentation, and interpolation consistency training), we deliberately avoided these complexities, as the focus is on a precise evaluation of pre-training quality.

![Downstream Task Results](images/downstream_task_results.png)

**Conclusions:**
* *In-Domain Tasks*: The pipeline demonstrates strong, consistent improvements for all transformers on *DESED* and *DC16-T2*, showcasing its effectiveness for in-domain tasks.
* *Out-of-Domain Task*: Results on *MAESTRO* (piano pitch prediction) are inconclusive. This limitation suggests that the proposed pre-training strategy yields substantial gains only when audio and labels are similar to the AudioSet ontology.
* *Simplified DESED Setup*: Despite the simplified setup (no CRNN, no unsupervised data), performance remains comparable to the [DCASE 2023 Task 4 baseline system](https://github.com/DCASE-REPO/DESED_task/tree/master/recipes/dcase2023_task4_baseline).

#### DESED Baseline Setup (*additional results*)

To complement the simplified DESED setup presented earlier, we provide results for the [DCASE 2023 Task 4 baseline system](https://github.com/DCASE-REPO/DESED_task/tree/master/recipes/dcase2023_task4_baseline) setup for ATST-F and BEATs in the table below. Note that hyperparameters were not extensively tuned, and the data setup may differ slightly from the original baseline.

| **Model** | **Checkpoint** | **Notes** | **Performance** |
|---|---|---|---|
| ATST-F | Step 1 (SSL) | | 42.7 |
| ATST-F | Step 2 (AS weak) | | 47.1 |
| ATST-F | Full Pipeline | | 50.4 |
| ATST-F | Full Pipeline | dropped 2 TF layers | **51.1** |
| BEATs | Step 1 (SSL) | | 39.7 |
| BEATs | Step 2 (AS weak) | | 48.1 |
| BEATs | Full Pipeline | | 48.6 |
| BEATs | Full Pipeline | dropped 2 TF layers | **51.1** |

**Conclusions**:
* The *Full Pipeline* substantially improves performance over *Step 1 (SSL)* and *Step 2 (AS weak)* for both ATST-F and BEATs.
* Dropping the last two Transformer layers notably enhances results, suggesting that the final layers focus on AudioSet Strong label-specific features, while earlier layers provide more general, transferable embeddings that benefit the DESED task. We will conduct further experiments to determine whether dropping Transformer layers generalizes to other tasks or is specific to DESED.

## References

[1] S. Chen, Y. Wu, C. Wang, S. Liu, D. Tompkins, Z. Chen, W. Che, X. Yu, and F. Wei, "BEATs: Audio pre-training with acoustic tokenizers," in Proceedings of the International Conference on Machine Learning (ICML), 2023.

[2] X. Li, N. Shao, and X. Li, "Self-supervised audio teacher-student transformer for both clip-level and frame-level tasks," IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 32, pp. 1336–1351, 2024.

[3] K. Koutini, J. Schlüter, H. Eghbal-zadeh, and G. Widmer, "Efficient training of audio transformers with patchout," in Proceedings of the Interspeech Conference, 2022.

[4] F. Schmid, P. Primus, T. Morocutti, J. Greif, and G. Widmer, "Multi-iteration multi-stage fine-tuning of transformers for sound event detection with heterogeneous datasets," in Workshop on Detection and Classification of Acoustic Scenes and Events (DCASE), 2024.

[5] S. Atito, M. Awais, W. Wang, M. D. Plumbley, and J. Kittler, "ASiT: Local-global audio spectrogram vision transformer for event classification," IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 32, pp. 3684–3693, 2024.

[6] D. Niizumi, D. Takeuchi, Y. Ohishi, N. Harada, M. Yasuda, S. Tsubaki, and K. Imoto, "M2D-CLAP: Masked modeling duo meets CLAP for learning general-purpose audio-language representation," in Proceedings of the Interspeech Conference, 2024.

[7] J. Ebbers, R. Haeb-Umbach, and R. Serizel, "Threshold independent evaluation of sound event detection scores," in Proceedings of the International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2022.

[8] N. Shao, X. Li, and X. Li, "Fine-tune the pretrained ATST model for sound event detection," in Proceedings of the International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2024.
app.py ADDED
@@ -0,0 +1,73 @@
import numpy as np
import gradio as gr
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.prediction_wrapper import PredictionsWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
import torch
from torch import nn
import pandas as pd


class TransformerClassifier(nn.Module):
    def __init__(self, model, n_classes):
        super(TransformerClassifier, self).__init__()
        self.model = model
        self.linear = nn.Linear(model.embed_dim, n_classes)

    def forward(self, x):
        mel = self.model.mel_forward(x)
        features = self.model(mel).squeeze(1)
        return self.linear(features)


def get_model(model_name):
    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint=None, head_type=None, seq_len=1,
                                   embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint=None, head_type=None, seq_len=1)
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=None, head_type=None, seq_len=1, embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")
    main_model = TransformerClassifier(model, n_classes=200)
    # main_model.compile()
    main_model.load_state_dict(torch.load(f"resources/best_model_{model_name}.pth", map_location='cpu'))
    print(main_model)
    main_model.eval()
    return main_model


model = get_model("BEATs")
label_mapping = pd.read_csv("resources/labelvocabulary.csv", header=None, index_col=0).to_dict()[1]
threshold = 0.4


def predict(input_audio):
    # Run sound event detection on the recorded/uploaded audio clip.
    with torch.no_grad():
        waveform = torch.from_numpy(input_audio[1]).float()  # gradio passes (sample_rate, np.ndarray)
        output = model(waveform.unsqueeze(0)).squeeze(0)  # add batch dimension, then remove it again
        output = output.sigmoid()
        num_labels = torch.where(output >= threshold)[0].tolist()
        labels = [label_mapping[str(i)] for i in num_labels]
    return ", ".join(labels) if labels else "No sound event detected"


demo = gr.Interface(predict, gr.Audio(max_length=30), "text", title="Freesound Sound Event Detection")
demo.launch()
config.py ADDED
@@ -0,0 +1,28 @@
RESOURCES_FOLDER = "resources"
GITHUB_RELEASE_URL = "https://github.com/fschmid56/PretrainedSED/releases/download/v0.0.1/"

# checkpoints
CHECKPOINT_URLS = {}

# strong
CHECKPOINT_URLS['BEATs_strong_1'] = GITHUB_RELEASE_URL + "BEATs_strong_1.pt"
CHECKPOINT_URLS['ATST-F_strong_1'] = GITHUB_RELEASE_URL + "ATST-F_strong_1.pt"
CHECKPOINT_URLS['ASIT_strong_1'] = GITHUB_RELEASE_URL + "ASIT_strong_1.pt"
CHECKPOINT_URLS['fpasst_strong_1'] = GITHUB_RELEASE_URL + "fpasst_strong_1.pt"
CHECKPOINT_URLS['M2D_strong_1'] = GITHUB_RELEASE_URL + "M2D_strong_1.pt"
for width in ['06', '10']:
    CHECKPOINT_URLS[f'frame_mn{width}_strong_1'] = GITHUB_RELEASE_URL + f'frame_mn{width}_strong_1.pt'

# weak
CHECKPOINT_URLS['BEATs_weak'] = GITHUB_RELEASE_URL + "BEATs_weak.pt"
CHECKPOINT_URLS['ATST-F_weak'] = GITHUB_RELEASE_URL + "ATST-F_weak.pt"
CHECKPOINT_URLS['ASIT_weak'] = GITHUB_RELEASE_URL + "ASIT_weak.pt"
CHECKPOINT_URLS['fpasst_weak'] = GITHUB_RELEASE_URL + "fpasst_weak.pt"
CHECKPOINT_URLS['M2D_weak'] = GITHUB_RELEASE_URL + "M2D_weak.pt"

# ssl
CHECKPOINT_URLS['BEATs_ssl'] = GITHUB_RELEASE_URL + "BEATs_ssl.pt"
CHECKPOINT_URLS['ATST-F_ssl'] = GITHUB_RELEASE_URL + "ATST-F_ssl.pt"
CHECKPOINT_URLS['ASIT_ssl'] = GITHUB_RELEASE_URL + "ASIT_ssl.pt"
CHECKPOINT_URLS['fpasst_ssl'] = GITHUB_RELEASE_URL + "fpasst_ssl.pt"
CHECKPOINT_URLS['M2D_ssl'] = GITHUB_RELEASE_URL + "M2D_ssl.pt"
data_util/audioset_classes.py ADDED
@@ -0,0 +1,1393 @@
as_strong_train_classes = [
    'Accelerating, revving, vroom',
    'Air brake',
    'Air conditioning',
    'Air horn, truck horn',
    'Aircraft',
    'Aircraft engine',
    'Alarm',
    'Alarm clock',
    'Alert',
    'Ambulance (siren)',
    'Animal',
    'Applause',
    'Arrow',
    'Artillery fire',
    'Audio logo',
    'Babbling',
    'Baby cry, infant cry',
    'Baby laughter',
    'Background noise',
    'Bang',
    'Bark',
    'Basketball bounce',
    'Bathroom sounds',
    'Bathtub (filling or washing)',
    'Battle cry',
    'Bee, wasp, etc.',
    'Beep, bleep',
    'Bell',
    'Bellow',
    'Belly laugh',
    'Bicycle bell',
    'Bicycle, tricycle',
    'Bird',
    'Bird flight, flapping wings',
    'Bird vocalization, bird call, bird song',
    'Biting',
    'Bleat',
    'Blender, food processor',
    'Boat, Water vehicle',
    'Boiling',
    'Boing',
    'Booing',
    'Boom',
    'Bouncing',
    'Bow-wow',
    'Breaking',
    'Breathing',
    'Brief tone',
    'Burping, eructation',
    'Burst, pop',
    'Bus',
    'Busy signal',
    'Buzz',
    'Buzzer',
    'Cacophony',
    'Camera',
    'Canidae, wild dogs, wolves',
    'Cap gun',
    'Car',
    'Car alarm',
    'Car passing by',
    'Carbon monoxide detector, CO detector',
    'Cart',
    'Cash register',
    'Cat',
    'Caterwaul',
    'Cattle, bovinae',
    'Caw',
    'Cellphone buzz, vibrating alert',
    'Chain',
    'Chainsaw',
    'Change ringing (campanology)',
    'Channel, environment and background',
    'Chant',
    'Cheering',
    'Chewing, mastication',
    'Chicken, rooster',
    'Child singing',
    'Child speech, kid speaking',
    'Children playing',
    'Children shouting',
    'Chime',
    'Chipmunk',
    'Chirp tone',
    'Chirp, tweet',
    'Choir',
    'Chop',
    'Chopping (food)',
    'Chorus effect',
    'Chuckle, chortle',
    'Church bell',
    'Civil defense siren',
    'Clang',
    'Clapping',
    'Clatter',
    'Clickety-clack',
    'Clicking',
    'Clip-clop',
    'Clock',
    'Cluck',
    'Clunk',
    'Coin (dropping)',
    'Computer keyboard',
    'Conversation',
    'Coo',
    'Cough',
    'Cowbell',
    'Crack',
    'Crackle',
    'Creak',
    'Cricket',
    'Croak',
    'Crockery breaking and smashing',
    'Crow',
    'Crowd',
    'Crowing, cock-a-doodle-doo',
    'Crumpling, crinkling',
    'Crunch',
    'Crushing',
    'Crying, sobbing',
    'Cupboard open or close',
    'Cutlery, silverware',
    'Deformable shell',
    "Dental drill, dentist's drill",
    'Dial tone',
    'Digestive',
    'Ding',
    'Ding-dong',
    'Dishes, pots, and pans',
    'Distortion',
    'Dog',
    'Domestic animals, pets',
    'Dong, bong',
    'Donkey, ass',
    'Door',
    'Doorbell',
    'Drawer open or close',
    'Drill',
    'Drip',
    'Duck call (hunting tool)',
    'Ducks, geese, waterfowl',
    'Echo',
    'Effects unit',
    'Electric rotor drone, quadcopter',
    'Electric shaver, electric razor',
    'Electric toothbrush',
    'Electronic tuner',
    'Emergency vehicle',
    'Engine',
    'Engine knocking',
    'Engine starting',
    'Environmental noise',
    'Error signal',
    'Eruption',
    'Explosion',
    'Fart',
    'Female singing',
    'Female speech, woman speaking',
    'Filing (rasp)',
    'Fill (with liquid)',
    'Finger snapping',
    'Fire',
    'Fire alarm',
    'Fire engine, fire truck (siren)',
    'Firecracker',
    'Fireworks',
    'Fixed-wing aircraft, airplane',
    'Fizz',
    'Flap',
    'Fly, housefly',
    'Foghorn',
    'Fowl',
    'Frog',
    'Frying (food)',
    'Fusillade',
    'Gargling',
    'Gasp',
    'Gears',
    'Generic impact sounds',
    'Giggle',
    'Glass',
    'Glass chink, clink',
    'Glass shatter',
    'Goat',
    'Gobble',
    'Grind',
    'Groan',
    'Growling',
    'Grunt',
    'Gull, seagull',
    'Gunshot, gunfire',
    'Gurgling, bubbling',
    'Gush',
    'Hair dryer',
    'Hammer',
    'Hands',
    'Heart sounds, heartbeat',
    'Heavy engine (low frequency)',
    'Helicopter',
    'Hiccup',
    'Hiss',
    'Honk',
    'Hoot',
    'Horse',
    'Howl',
    'Howl (wind)',
    'Hubbub, speech noise, speech babble',
    'Hum',
    'Human group actions',
    'Human locomotion',
    'Human sounds',
    'Human voice',
    'Humming',
    'Ice cream truck, ice cream van',
    'Idling',
    'Insect',
    'Inside, large room or hall',
    'Inside, public space',
    'Inside, small room',
    'Jackhammer',
    'Jet engine',
    'Jingle bell',
    'Jingle, tinkle',
    'Kettle whistle',
    'Keypress tone',
    'Keys jangling',
    'Kitchen and dining room sounds',
    'Knife',
    'Knock',
    'Laughter',
    'Lawn mower',
    'Light engine (high frequency)',
    'Liquid',
    'Livestock, farm animals, working animals',
    'Lock',
    'Machine gun',
    'Mains hum',
    'Male singing',
    'Male speech, man speaking',
    'Mantra',
    'Mechanical bell',
    'Mechanical fan',
    'Mechanisms',
    'Medium engine (mid frequency)',
    'Meow',
    'Microphone',
    'Microwave oven',
    'Moo',
    'Mosquito',
    'Motor vehicle (road)',
    'Motorboat, speedboat',
    'Motorcycle',
    'Mouse',
    'Music',
    'Narration, monologue',
    'Neigh, whinny',
    'Noise',
    'Non-motorized land vehicle',
    'Ocean',
    'Oink',
    'Other sourceless',
    'Outside, urban or manmade',
    'Owl',
    'Packing tape, duct tape',
    'Pant',
    'Pant (dog)',
    'Paper rustling',
    'Patter',
    'Pig',
    'Pigeon, dove',
    'Ping',
    'Plop',
    'Police car (siren)',
    'Pour',
    'Power saw, circular saw, table saw',
    'Power tool',
    'Power windows, electric windows',
    'Printer',
    'Propeller, airscrew',
    'Puff',
    'Pulleys',
    'Pulse',
    'Pump (liquid)',
    'Purr',
    'Quack',
    'Race car, auto racing',
    'Radio',
    'Rail transport',
    'Railroad car, train wagon',
    'Rain',
    'Rain on surface',
    'Raindrop',
    'Rapping',
    'Ratchet, pawl',
    'Rattle',
    'Refrigerator',
    'Respiratory sounds',
    'Reverberation',
    'Reversing beeps',
    'Ringing tone, ringback tone',
    'Ringtone',
    'Roar',
    'Roaring cats (lions, tigers)',
    'Rodents, rats, mice',
    'Roll',
    'Rowboat, canoe, kayak',
    'Rub',
    'Rumble',
    'Run',
    'Rustle',
    'Sailboat, sailing ship',
    'Sanding',
    'Sawing',
    'Scissors',
    'Scrape',
    'Scratch',
    'Screaming',
    'Screech',
    'Sewing machine',
    'Sheep',
    'Ship',
    'Shout',
    'Shower',
    'Shuffle',
    'Shuffling cards',
    'Sigh',
    'Sine wave',
    'Singing',
    'Single-lens reflex camera',
    'Sink (filling or washing)',
    'Siren',
    'Sizzle',
    'Skateboard',
    'Slam',
    'Slap, smack',
    'Sliding door',
    'Slosh',
    'Slurp, drinking straw',
    'Smash, crash',
    'Smoke detector, smoke alarm',
    'Snake',
    'Snap',
    'Sneeze',
    'Snicker',
    'Sniff',
    'Snoring',
    'Snort',
    'Snort (horse)',
    'Sonar',
    'Sonic boom',
    'Sound effect',
    'Sound equipment',
    'Sound reproduction',
    'Speech',
    'Speech synthesizer',
    'Splash, splatter',
    'Splinter',
    'Spray',
    'Squawk',
    'Squeak',
    'Squeal',
    'Squish',
    'Stairs',
    'Static',
    'Steam',
    'Steam whistle',
    'Stir',
    'Stomach rumble',
    'Stomp, stamp',
    'Stream, river',
    'Subway, metro, underground',
    'Surface contact',
    'Sweeping',
    'Synthetic singing',
    'Tap',
    'Tap dance',
    'Tape hiss',
    'Tearing',
    'Telephone',
    'Telephone bell ringing',
    'Telephone dialing, DTMF',
    'Television',
    'Throat clearing',
    'Thump, thud',
    'Thunder',
    'Thunderstorm',
    'Thunk',
    'Tick',
    'Tick-tock',
    'Tire squeal, skidding',
    'Toilet flush',
    'Tools',
    'Toothbrush',
    'Traffic noise, roadway noise',
    'Train',
    'Train horn',
    'Train wheels squealing',
    'Train whistle',
    'Trickle, dribble',
    'Truck',
    'Tuning fork',
    'Turkey',
    'Typewriter',
    'Typing',
    'Unknown sound',
    'Vacuum cleaner',
    'Vehicle',
    'Vehicle horn, car horn, honking, toot',
    'Velcro, hook and loop fastener',
    'Video game sound',
    'Wail, moan',
    'Walk, footsteps',
    'Washing machine',
    'Water',
    'Water tap, faucet',
    'Waterfall',
    'Waves, surf',
    'Whack, thwack',
    'Whale vocalization',
    'Wheeze',
    'Whimper',
    'Whimper (dog)',
    'Whip',
    'Whir',
    'Whispering',
    'Whistle',
    'Whistling',
    'White noise, pink noise',
    'Whoop',
    'Whoosh, swoosh, swish',
    'Wild animals',
    'Wildfire',
    'Wind',
    'Wind chime',
    'Wind noise (microphone)',
    'Windscreen wiper, windshield wiper',
    'Wobble',
    'Wolf-whistling',
    'Wood',
    'Writing',
    'Yak',
    'Yawn',
    'Yell',
    'Yip',
    'Yodeling',
    'Zing',
    'Zipper (clothing)',
]

as_strong_eval_classes = [
    'Accelerating, revving, vroom',
    'Air brake',
    'Air conditioning',
    'Air horn, truck horn',
    'Aircraft',
    'Aircraft engine',
    'Alarm',
    'Alarm clock',
    'Ambulance (siren)',
    'Animal',
    'Applause',
    'Arrow',
    'Artillery fire',
    'Audio logo',
    'Babbling',
    'Baby cry, infant cry',
    'Baby laughter',
    'Background noise',
    'Bang',
    'Bark',
    'Basketball bounce',
    'Bathtub (filling or washing)',
    'Battle cry',
    'Bee, wasp, etc.',
    'Beep, bleep',
    'Bell',
    'Bellow',
    'Belly laugh',
    'Bicycle bell',
    'Bicycle, tricycle',
    'Bird',
    'Bird flight, flapping wings',
    'Bird vocalization, bird call, bird song',
    'Biting',
    'Bleat',
    'Blender, food processor',
    'Boat, Water vehicle',
    'Boiling',
    'Boing',
    'Boom',
    'Bouncing',
    'Bow-wow',
    'Breaking',
    'Breathing',
    'Brief tone',
    'Burping, eructation',
    'Burst, pop',
    'Bus',
    'Busy signal',
    'Buzz',
    'Buzzer',
    'Cacophony',
    'Camera',
    'Canidae, wild dogs, wolves',
    'Cap gun',
    'Car',
    'Car alarm',
    'Car passing by',
    'Cart',
    'Cash register',
    'Cat',
    'Caterwaul',
    'Cattle, bovinae',
    'Caw',
    'Cellphone buzz, vibrating alert',
    'Chainsaw',
    'Change ringing (campanology)',
    'Chant',
    'Cheering',
    'Chewing, mastication',
    'Chicken, rooster',
    'Child singing',
    'Child speech, kid speaking',
    'Children playing',
    'Children shouting',
    'Chime',
    'Chipmunk',
    'Chirp tone',
    'Chirp, tweet',
    'Choir',
    'Chop',
    'Chopping (food)',
    'Chorus effect',
    'Chuckle, chortle',
    'Church bell',
    'Civil defense siren',
    'Clang',
    'Clapping',
    'Clatter',
    'Clickety-clack',
    'Clicking',
    'Clip-clop',
    'Clock',
    'Cluck',
    'Coin (dropping)',
    'Computer keyboard',
    'Conversation',
    'Coo',
    'Cough',
    'Cowbell',
    'Crack',
    'Crackle',
    'Creak',
    'Cricket',
    'Croak',
    'Crockery breaking and smashing',
    'Crow',
    'Crowd',
    'Crowing, cock-a-doodle-doo',
    'Crumpling, crinkling',
    'Crunch',
    'Crushing',
    'Crying, sobbing',
    'Cupboard open or close',
    'Cutlery, silverware',
    "Dental drill, dentist's drill",
    'Dial tone',
    'Ding',
    'Ding-dong',
    'Dishes, pots, and pans',
    'Distortion',
    'Dog',
    'Domestic animals, pets',
    'Door',
    'Doorbell',
    'Drawer open or close',
    'Drill',
    'Drip',
    'Ducks, geese, waterfowl',
    'Echo',
    'Effects unit',
    'Electric rotor drone, quadcopter',
    'Electric shaver, electric razor',
    'Electric toothbrush',
    'Electronic tuner',
    'Emergency vehicle',
    'Engine',
    'Engine knocking',
    'Engine starting',
    'Environmental noise',
    'Eruption',
    'Explosion',
    'Fart',
    'Female singing',
    'Female speech, woman speaking',
    'Filing (rasp)',
    'Fill (with liquid)',
    'Finger snapping',
    'Fire',
    'Fire alarm',
    'Fire engine, fire truck (siren)',
    'Firecracker',
    'Fireworks',
    'Fixed-wing aircraft, airplane',
    'Flap',
    'Fly, housefly',
    'Foghorn',
    'Fowl',
    'Frog',
    'Frying (food)',
    'Fusillade',
    'Gargling',
    'Gasp',
    'Gears',
    'Generic impact sounds',
    'Giggle',
    'Glass',
    'Glass chink, clink',
    'Glass shatter',
    'Goat',
    'Gobble',
    'Groan',
    'Growling',
    'Grunt',
    'Gunshot, gunfire',
    'Gurgling, bubbling',
    'Gush',
    'Hair dryer',
    'Hammer',
    'Hands',
    'Heart murmur',
    'Heart sounds, heartbeat',
    'Heavy engine (low frequency)',
    'Helicopter',
    'Hiccup',
    'Hiss',
    'Honk',
    'Hoot',
    'Horse',
    'Howl',
    'Howl (wind)',
    'Hubbub, speech noise, speech babble',
    'Hum',
    'Human sounds',
    'Human voice',
    'Humming',
    'Ice cream truck, ice cream van',
    'Idling',
    'Insect',
    'Inside, large room or hall',
    'Inside, public space',
    'Inside, small room',
    'Jackhammer',
    'Jet engine',
    'Jingle bell',
    'Jingle, tinkle',
    'Keys jangling',
    'Kitchen and dining room sounds',
    'Knock',
    'Laughter',
    'Lawn mower',
    'Light engine (high frequency)',
    'Liquid',
    'Livestock, farm animals, working animals',
    'Machine gun',
    'Mains hum',
    'Male singing',
    'Male speech, man speaking',
    'Mantra',
    'Mechanical fan',
    'Mechanisms',
    'Medium engine (mid frequency)',
    'Meow',
    'Microwave oven',
    'Moo',
    'Mosquito',
    'Motor vehicle (road)',
    'Motorboat, speedboat',
    'Motorcycle',
    'Mouse',
    'Music',
    'Narration, monologue',
    'Neigh, whinny',
    'Noise',
    'Non-motorized land vehicle',
    'Ocean',
    'Oink',
    'Outside, rural or natural',
    'Outside, urban or manmade',
    'Owl',
    'Packing tape, duct tape',
    'Pant',
    'Pant (dog)',
    'Paper rustling',
    'Patter',
    'Pig',
    'Pigeon, dove',
    'Ping',
    'Plop',
    'Police car (siren)',
    'Pour',
    'Power saw, circular saw, table saw',
    'Power tool',
    'Power windows, electric windows',
    'Printer',
    'Propeller, airscrew',
    'Pulleys',
    'Pulse',
    'Pump (liquid)',
    'Purr',
    'Quack',
    'Race car, auto racing',
    'Radio',
    'Rail transport',
    'Railroad car, train wagon',
    'Rain',
    'Rain on surface',
    'Raindrop',
    'Rapping',
    'Ratchet, pawl',
    'Rattle',
    'Respiratory sounds',
    'Reverberation',
    'Reversing beeps',
    'Ringing tone, ringback tone',
    'Ringtone',
    'Roar',
    'Roaring cats (lions, tigers)',
    'Rodents, rats, mice',
    'Roll',
    'Rowboat, canoe, kayak',
    'Rub',
    'Rumble',
    'Run',
    'Rustle',
    'Sailboat, sailing ship',
    'Sanding',
    'Sawing',
    'Scissors',
    'Scrape',
    'Scratch',
    'Screaming',
    'Sewing machine',
    'Sheep',
    'Ship',
    'Shout',
    'Shower',
    'Shuffle',
    'Shuffling cards',
    'Sigh',
    'Silence',
    'Sine wave',
    'Singing',
    'Single-lens reflex camera',
    'Sink (filling or washing)',
    'Siren',
    'Sizzle',
    'Skateboard',
    'Slam',
    'Slap, smack',
    'Sliding door',
    'Slosh',
    'Smash, crash',
    'Smoke detector, smoke alarm',
    'Snake',
    'Sneeze',
    'Snicker',
    'Sniff',
    'Snoring',
    'Snort',
    'Snort (horse)',
    'Sonar',
    'Sound effect',
    'Sound equipment',
    'Source-ambiguous sounds',
    'Specific impact sounds',
    'Speech',
    'Speech synthesizer',
    'Splash, splatter',
    'Splinter',
    'Spray',
    'Squawk',
    'Squeak',
    'Squeal',
    'Squish',
    'Stairs',
    'Static',
    'Steam',
    'Steam whistle',
    'Stir',
    'Stomach rumble',
    'Stomp, stamp',
    'Stream, river',
    'Studio recording',
    'Subway, metro, underground',
    'Surface contact',
    'Synthetic singing',
    'Tap',
    'Tap dance',
    'Tearing',
    'Telephone',
    'Telephone bell ringing',
    'Telephone dialing, DTMF',
    'Television',
    'Throat clearing',
    'Throbbing',
    'Thump, thud',
    'Thunder',
    'Thunderstorm',
    'Thunk',
    'Tick',
    'Tick-tock',
    'Tire squeal, skidding',
    'Toilet flush',
    'Tools',
    'Toothbrush',
    'Traffic noise, roadway noise',
    'Train',
    'Train horn',
    'Train wheels squealing',
    'Train whistle',
    'Trickle, dribble',
    'Truck',
    'Tuning fork',
    'Turkey',
    'Typewriter',
    'Typing',
    'Unknown sound',
    'Unmodified field recording',
    'Vacuum cleaner',
    'Vehicle',
    'Vehicle horn, car horn, honking, toot',
    'Velcro, hook and loop fastener',
    'Vibration',
    'Video game sound',
    'Wail, moan',
    'Walk, footsteps',
    'Washing machine',
    'Water',
    'Water tap, faucet',
    'Waterfall',
    'Waves, surf',
    'Whack, thwack',
    'Whale vocalization',
    'Wheeze',
    'Whimper',
    'Whimper (dog)',
    'Whip',
    'Whir',
    'Whispering',
    'Whistle',
    'Whistling',
    'White noise, pink noise',
    'Whoop',
    'Whoosh, swoosh, swish',
    'Wild animals',
    'Wind',
    'Wind chime',
    'Wind noise (microphone)',
    'Wood',
    'Writing',
    'Yawn',
    'Yell',
    'Yip',
    'Yodeling',
    'Zipper (clothing)',
]

as_weak_classes = [
    'A capella',
    'Accelerating, revving, vroom',
    'Accordion',
    'Acoustic guitar',
    'Afrobeat',
    'Air brake',
    'Air conditioning',
    'Air horn, truck horn',
    'Aircraft',
    'Aircraft engine',
    'Alarm',
    'Alarm clock',
    'Ambient music',
    'Ambulance (siren)',
    'Angry music',
    'Animal',
    'Applause',
    'Arrow',
    'Artillery fire',
    'Babbling',
    'Baby cry, infant cry',
    'Baby laughter',
    'Background music',
    'Bagpipes',
    'Bang',
    'Banjo',
    'Bark',
    'Basketball bounce',
    'Bass drum',
    'Bass guitar',
    'Bathtub (filling or washing)',
    'Battle cry',
    'Beatboxing',
    'Bee, wasp, etc.',
    'Beep, bleep',
    'Bell',
    'Bellow',
    'Belly laugh',
    'Bicycle',
    'Bicycle bell',
    'Bird',
    'Bird flight, flapping wings',
    'Bird vocalization, bird call, bird song',
    'Biting',
    'Bleat',
    'Blender',
    'Bluegrass',
    'Blues',
    'Boat, Water vehicle',
    'Boiling',
    'Boing',
    'Boom',
    'Bouncing',
    'Bow-wow',
    'Bowed string instrument',
    'Brass instrument',
    'Breaking',
    'Breathing',
    'Burping, eructation',
    'Burst, pop',
    'Bus',
    'Busy signal',
    'Buzz',
    'Buzzer',
    'Cacophony',
    'Camera',
    'Canidae, dogs, wolves',
    'Cap gun',
    'Car',
    'Car alarm',
    'Car passing by',
    'Carnatic music',
    'Cash register',
    'Cat',
    'Caterwaul',
    'Cattle, bovinae',
    'Caw',
    'Cello',
    'Chainsaw',
    'Change ringing (campanology)',
    'Chant',
    'Chatter',
    'Cheering',
    'Chewing, mastication',
    'Chicken, rooster',
    'Child singing',
    'Child speech, kid speaking',
    'Children playing',
    'Children shouting',
    'Chime',
    'Chink, clink',
    'Chirp tone',
    'Chirp, tweet',
    'Choir',
    'Chop',
    'Chopping (food)',
    'Chorus effect',
    'Christian music',
    'Christmas music',
    'Chuckle, chortle',
    'Church bell',
    'Civil defense siren',
    'Clang',
    'Clapping',
    'Clarinet',
    'Classical music',
    'Clatter',
    'Clickety-clack',
    'Clicking',
    'Clip-clop',
    'Clock',
    'Cluck',
    'Coin (dropping)',
    'Computer keyboard',
    'Conversation',
    'Coo',
    'Cough',
    'Country',
    'Cowbell',
    'Crack',
    'Crackle',
    'Creak',
    'Cricket',
    'Croak',
    'Crow',
    'Crowd',
    'Crowing, cock-a-doodle-doo',
    'Crumpling, crinkling',
    'Crunch',
    'Crushing',
    'Crying, sobbing',
    'Cupboard open or close',
    'Cutlery, silverware',
    'Cymbal',
    'Dance music',
    "Dental drill, dentist's drill",
    'Dial tone',
    'Didgeridoo',
    'Ding',
    'Ding-dong',
    'Disco',
    'Dishes, pots, and pans',
    'Distortion',
    'Dog',
    'Domestic animals, pets',
    'Door',
    'Doorbell',
    'Double bass',
    'Drawer open or close',
    'Drill',
    'Drip',
    'Drum',
    'Drum and bass',
    'Drum kit',
    'Drum machine',
    'Drum roll',
    'Dubstep',
    'Duck',
    'Echo',
    'Effects unit',
    'Electric guitar',
    'Electric piano',
    'Electric shaver, electric razor',
    'Electric toothbrush',
    'Electronic dance music',
    'Electronic music',
    'Electronic organ',
    'Electronic tuner',
    'Electronica',
    'Emergency vehicle',
    'Engine',
    'Engine knocking',
    'Engine starting',
    'Environmental noise',
    'Eruption',
    'Exciting music',
    'Explosion',
    'Fart',
    'Female singing',
    'Female speech, woman speaking',
    'Field recording',
    'Filing (rasp)',
    'Fill (with liquid)',
    'Finger snapping',
    'Fire',
    'Fire alarm',
    'Fire engine, fire truck (siren)',
    'Firecracker',
    'Fireworks',
    'Fixed-wing aircraft, airplane',
    'Flamenco',
    'Flap',
    'Flute',
    'Fly, housefly',
    'Foghorn',
    'Folk music',
    'Fowl',
    'French horn',
    'Frog',
    'Frying (food)',
    'Funk',
    'Funny music',
    'Fusillade',
    'Gargling',
    'Gasp',
    'Gears',
    'Giggle',
    'Glass',
    'Glockenspiel',
    'Goat',
    'Gobble',
    'Gong',
    'Goose',
    'Gospel music',
    'Groan',
    'Growling',
    'Grunge',
    'Grunt',
    'Guitar',
    'Gunshot, gunfire',
    'Gurgling',
    'Gush',
    'Hair dryer',
    'Hammer',
    'Hammond organ',
    'Hands',
    'Happy music',
    'Harmonic',
    'Harmonica',
+ 'Harp',
1096
+ 'Harpsichord',
1097
+ 'Heart murmur',
1098
+ 'Heart sounds, heartbeat',
1099
+ 'Heavy engine (low frequency)',
1100
+ 'Heavy metal',
1101
+ 'Helicopter',
1102
+ 'Hi-hat',
1103
+ 'Hiccup',
1104
+ 'Hip hop music',
1105
+ 'Hiss',
1106
+ 'Honk',
1107
+ 'Hoot',
1108
+ 'Horse',
1109
+ 'House music',
1110
+ 'Howl',
1111
+ 'Hubbub, speech noise, speech babble',
1112
+ 'Hum',
1113
+ 'Humming',
1114
+ 'Ice cream truck, ice cream van',
1115
+ 'Idling',
1116
+ 'Independent music',
1117
+ 'Insect',
1118
+ 'Inside, large room or hall',
1119
+ 'Inside, public space',
1120
+ 'Inside, small room',
1121
+ 'Jackhammer',
1122
+ 'Jazz',
1123
+ 'Jet engine',
1124
+ 'Jingle (music)',
1125
+ 'Jingle bell',
1126
+ 'Jingle, tinkle',
1127
+ 'Keyboard (musical)',
1128
+ 'Keys jangling',
1129
+ 'Knock',
1130
+ 'Laughter',
1131
+ 'Lawn mower',
1132
+ 'Light engine (high frequency)',
1133
+ 'Liquid',
1134
+ 'Livestock, farm animals, working animals',
1135
+ 'Lullaby',
1136
+ 'Machine gun',
1137
+ 'Mains hum',
1138
+ 'Male singing',
1139
+ 'Male speech, man speaking',
1140
+ 'Mallet percussion',
1141
+ 'Mandolin',
1142
+ 'Mantra',
1143
+ 'Maraca',
1144
+ 'Marimba, xylophone',
1145
+ 'Mechanical fan',
1146
+ 'Mechanisms',
1147
+ 'Medium engine (mid frequency)',
1148
+ 'Meow',
1149
+ 'Microwave oven',
1150
+ 'Middle Eastern music',
1151
+ 'Moo',
1152
+ 'Mosquito',
1153
+ 'Motor vehicle (road)',
1154
+ 'Motorboat, speedboat',
1155
+ 'Motorcycle',
1156
+ 'Mouse',
1157
+ 'Music',
1158
+ 'Music for children',
1159
+ 'Music of Africa',
1160
+ 'Music of Asia',
1161
+ 'Music of Bollywood',
1162
+ 'Music of Latin America',
1163
+ 'Musical instrument',
1164
+ 'Narration, monologue',
1165
+ 'Neigh, whinny',
1166
+ 'New-age music',
1167
+ 'Noise',
1168
+ 'Ocean',
1169
+ 'Oink',
1170
+ 'Opera',
1171
+ 'Orchestra',
1172
+ 'Organ',
1173
+ 'Outside, rural or natural',
1174
+ 'Outside, urban or manmade',
1175
+ 'Owl',
1176
+ 'Pant',
1177
+ 'Patter',
1178
+ 'Percussion',
1179
+ 'Piano',
1180
+ 'Pig',
1181
+ 'Pigeon, dove',
1182
+ 'Ping',
1183
+ 'Pink noise',
1184
+ 'Pizzicato',
1185
+ 'Plop',
1186
+ 'Plucked string instrument',
1187
+ 'Police car (siren)',
1188
+ 'Pop music',
1189
+ 'Pour',
1190
+ 'Power tool',
1191
+ 'Power windows, electric windows',
1192
+ 'Printer',
1193
+ 'Progressive rock',
1194
+ 'Propeller, airscrew',
1195
+ 'Psychedelic rock',
1196
+ 'Pulleys',
1197
+ 'Pulse',
1198
+ 'Pump (liquid)',
1199
+ 'Punk rock',
1200
+ 'Purr',
1201
+ 'Quack',
1202
+ 'Race car, auto racing',
1203
+ 'Radio',
1204
+ 'Rail transport',
1205
+ 'Railroad car, train wagon',
1206
+ 'Rain',
1207
+ 'Rain on surface',
1208
+ 'Raindrop',
1209
+ 'Rapping',
1210
+ 'Ratchet, pawl',
1211
+ 'Rattle',
1212
+ 'Rattle (instrument)',
1213
+ 'Reggae',
1214
+ 'Reverberation',
1215
+ 'Reversing beeps',
1216
+ 'Rhythm and blues',
1217
+ 'Rimshot',
1218
+ 'Ringtone',
1219
+ 'Roar',
1220
+ 'Roaring cats (lions, tigers)',
1221
+ 'Rock and roll',
1222
+ 'Rock music',
1223
+ 'Rodents, rats, mice',
1224
+ 'Roll',
1225
+ 'Rowboat, canoe, kayak',
1226
+ 'Rub',
1227
+ 'Rumble',
1228
+ 'Run',
1229
+ 'Rustle',
1230
+ 'Rustling leaves',
1231
+ 'Sad music',
1232
+ 'Sailboat, sailing ship',
1233
+ 'Salsa music',
1234
+ 'Sampler',
1235
+ 'Sanding',
1236
+ 'Sawing',
1237
+ 'Saxophone',
1238
+ 'Scary music',
1239
+ 'Scissors',
1240
+ 'Scrape',
1241
+ 'Scratch',
1242
+ 'Scratching (performance technique)',
1243
+ 'Screaming',
1244
+ 'Sewing machine',
1245
+ 'Shatter',
1246
+ 'Sheep',
1247
+ 'Ship',
1248
+ 'Shofar',
1249
+ 'Shout',
1250
+ 'Shuffle',
1251
+ 'Shuffling cards',
1252
+ 'Sidetone',
1253
+ 'Sigh',
1254
+ 'Silence',
1255
+ 'Sine wave',
1256
+ 'Singing',
1257
+ 'Singing bowl',
1258
+ 'Single-lens reflex camera',
1259
+ 'Sink (filling or washing)',
1260
+ 'Siren',
1261
+ 'Sitar',
1262
+ 'Sizzle',
1263
+ 'Ska',
1264
+ 'Skateboard',
1265
+ 'Skidding',
1266
+ 'Slam',
1267
+ 'Slap, smack',
1268
+ 'Sliding door',
1269
+ 'Slosh',
1270
+ 'Smash, crash',
1271
+ 'Smoke detector, smoke alarm',
1272
+ 'Snake',
1273
+ 'Snare drum',
1274
+ 'Sneeze',
1275
+ 'Snicker',
1276
+ 'Sniff',
1277
+ 'Snoring',
1278
+ 'Snort',
1279
+ 'Sonar',
1280
+ 'Song',
1281
+ 'Soul music',
1282
+ 'Sound effect',
1283
+ 'Soundtrack music',
1284
+ 'Speech',
1285
+ 'Speech synthesizer',
1286
+ 'Splash, splatter',
1287
+ 'Splinter',
1288
+ 'Spray',
1289
+ 'Squawk',
1290
+ 'Squeak',
1291
+ 'Squeal',
1292
+ 'Squish',
1293
+ 'Static',
1294
+ 'Steam',
1295
+ 'Steam whistle',
1296
+ 'Steel guitar, slide guitar',
1297
+ 'Steelpan',
1298
+ 'Stir',
1299
+ 'Stomach rumble',
1300
+ 'Stream',
1301
+ 'String section',
1302
+ 'Strum',
1303
+ 'Subway, metro, underground',
1304
+ 'Swing music',
1305
+ 'Synthesizer',
1306
+ 'Synthetic singing',
1307
+ 'Tabla',
1308
+ 'Tambourine',
1309
+ 'Tap',
1310
+ 'Tapping (guitar technique)',
1311
+ 'Tearing',
1312
+ 'Techno',
1313
+ 'Telephone',
1314
+ 'Telephone bell ringing',
1315
+ 'Telephone dialing, DTMF',
1316
+ 'Television',
1317
+ 'Tender music',
1318
+ 'Theme music',
1319
+ 'Theremin',
1320
+ 'Throat clearing',
1321
+ 'Throbbing',
1322
+ 'Thump, thud',
1323
+ 'Thunder',
1324
+ 'Thunderstorm',
1325
+ 'Thunk',
1326
+ 'Tick',
1327
+ 'Tick-tock',
1328
+ 'Timpani',
1329
+ 'Tire squeal',
1330
+ 'Toilet flush',
1331
+ 'Tools',
1332
+ 'Toot',
1333
+ 'Toothbrush',
1334
+ 'Traditional music',
1335
+ 'Traffic noise, roadway noise',
1336
+ 'Train',
1337
+ 'Train horn',
1338
+ 'Train wheels squealing',
1339
+ 'Train whistle',
1340
+ 'Trance music',
1341
+ 'Trickle, dribble',
1342
+ 'Trombone',
1343
+ 'Truck',
1344
+ 'Trumpet',
1345
+ 'Tubular bells',
1346
+ 'Tuning fork',
1347
+ 'Turkey',
1348
+ 'Typewriter',
1349
+ 'Typing',
1350
+ 'Ukulele',
1351
+ 'Vacuum cleaner',
1352
+ 'Vehicle',
1353
+ 'Vehicle horn, car horn, honking',
1354
+ 'Vibraphone',
1355
+ 'Vibration',
1356
+ 'Video game music',
1357
+ 'Violin, fiddle',
1358
+ 'Vocal music',
1359
+ 'Wail, moan',
1360
+ 'Walk, footsteps',
1361
+ 'Water',
1362
+ 'Water tap, faucet',
1363
+ 'Waterfall',
1364
+ 'Waves, surf',
1365
+ 'Wedding music',
1366
+ 'Whack, thwack',
1367
+ 'Whale vocalization',
1368
+ 'Wheeze',
1369
+ 'Whimper',
1370
+ 'Whimper (dog)',
1371
+ 'Whip',
1372
+ 'Whir',
1373
+ 'Whispering',
1374
+ 'Whistle',
1375
+ 'Whistling',
1376
+ 'White noise',
1377
+ 'Whoop',
1378
+ 'Whoosh, swoosh, swish',
1379
+ 'Wild animals',
1380
+ 'Wind',
1381
+ 'Wind chime',
1382
+ 'Wind instrument, woodwind instrument',
1383
+ 'Wind noise (microphone)',
1384
+ 'Wood',
1385
+ 'Wood block',
1386
+ 'Writing',
1387
+ 'Yell',
1388
+ 'Yip',
1389
+ 'Yodeling',
1390
+ 'Zing',
1391
+ 'Zipper (clothing)',
1392
+ 'Zither'
1393
+ ]
data_util/audioset_strong.py ADDED
@@ -0,0 +1,329 @@
1
+ import os
2
+ from time import perf_counter
3
+ import datasets
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.data import (
8
+ Dataset as TorchDataset,
9
+ DistributedSampler,
10
+ WeightedRandomSampler,
11
+ )
12
+
13
+ from data_util.audioset_classes import as_strong_train_classes
14
+ from data_util.transforms import (
15
+ Mp3DecodeTransform,
16
+ SequentialTransform,
17
+ AddPseudoLabelsTransform,
18
+ strong_label_transform,
19
+ target_transform
20
+ )
21
+
22
+ logger = datasets.logging.get_logger(__name__)
23
+
24
+
25
+ def init_hf_config(max_shard_size="2GB", verbose=True, in_mem_max=None):
26
+ datasets.config.MAX_SHARD_SIZE = max_shard_size
27
+ if verbose:
28
+ datasets.logging.set_verbosity_info()
29
+ if in_mem_max is not None:
30
+ datasets.config.IN_MEMORY_MAX_SIZE = in_mem_max
31
+
32
+
33
+ def get_hf_local_path(path, local_datasets_path=None):
34
+ if local_datasets_path is None:
35
+ local_datasets_path = os.environ.get(
36
+ "HF_DATASETS_LOCAL",
37
+ os.path.join(os.environ.get("HF_DATASETS_CACHE"), "../local"),
38
+ )
39
+ path = os.path.join(local_datasets_path, path)
40
+ return path
41
+
42
+
43
+ class catchtime:
44
+ # context to measure loading time: https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
45
+ def __init__(self, debug_print="Time", logger=logger):
46
+ self.debug_print = debug_print
47
+ self.logger = logger
48
+
49
+ def __enter__(self):
50
+ self.start = perf_counter()
51
+ return self
52
+
53
+ def __exit__(self, type, value, traceback):
54
+ self.time = perf_counter() - self.start
55
+ readout = f"{self.debug_print}: {self.time:.3f} seconds"
56
+ self.logger.info(readout)
57
+
58
+
59
+ def merge_overlapping_events(sample):
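+ # merge overlapping (or directly adjacent) events of the same class into
+ # single events spanning their union; expects a batched sample of size one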
60
+ events = pd.DataFrame(sample['events'][0])
61
+ events = events.sort_values(by='onset')
62
+ sample['events'] = [None]
63
+
64
+ for l in events['event_label'].unique():
65
+ rows = []
66
+ for i, r in events.loc[events['event_label'] == l].iterrows():
67
+ if len(rows) == 0 or rows[-1]['offset'] < r['onset']:
68
+ rows.append(r)
69
+ else:
70
+ onset = min(rows[-1]['onset'], r['onset'])
71
+ offset = max(rows[-1]['offset'], r['offset'])
72
+ rows[-1]['onset'] = onset
73
+ rows[-1]['offset'] = offset
74
+ if sample["events"][0] is None:
75
+ sample['events'][0] = pd.DataFrame(rows)
76
+ else:
77
+ sample["events"][0] = pd.concat([sample['events'][0], pd.DataFrame(rows)])
78
+ return sample
79
+
80
+
81
+ def get_training_dataset(
82
+ label_encoder,
83
+ audio_length=10.0,
84
+ sample_rate=16000,
85
+ wavmix_p=0.0,
86
+ pseudo_labels_file=None,
87
+ ):
88
+ init_hf_config()
89
+
90
+ decode_transform = Mp3DecodeTransform(
91
+ sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
92
+ )
93
+
94
+ ds_list = []
95
+
96
+ with catchtime("Loading audioset_strong"):
97
+ as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
98
+
99
+ # label encode transformation
100
+ if label_encoder is not None:
101
+ # set list of label names to be encoded
102
+ label_encoder.labels = as_strong_train_classes
103
+ encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
104
+ else:
105
+ encode_label_fun = lambda x: x
106
+
107
+ as_transforms = [
108
+ decode_transform,
109
+ merge_overlapping_events,
110
+ encode_label_fun,
111
+ target_transform,
112
+ ]
113
+
114
+ if pseudo_labels_file:
115
+ as_transforms.append(AddPseudoLabelsTransform(pseudo_labels_file=pseudo_labels_file).add_pseudo_label_transform)
116
+
117
+ as_ds.set_transform(SequentialTransform(as_transforms))
118
+
119
+ ds_list.append(as_ds["balanced_train"])
120
+ ds_list.append(as_ds["unbalanced_train"])
121
+ dataset = torch.utils.data.ConcatDataset(ds_list)
122
+
123
+ if wavmix_p > 0:
124
+ print("Using Wavmix!")
125
+ dataset = MixupDataset(dataset, rate=wavmix_p)
126
+ return dataset
127
+
128
+
129
+ def get_eval_dataset(
130
+ label_encoder,
131
+ audio_length=10.0,
132
+ sample_rate=16000
133
+ ):
134
+ init_hf_config()
135
+ ds_list = []
136
+
137
+ decode_transform = Mp3DecodeTransform(
138
+ sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
139
+ )
140
+
141
+ with catchtime(f"Loading audioset:"):
142
+ as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
143
+
144
+ # label encode transformation
145
+ if label_encoder is not None:
146
+ label_encoder.labels = as_strong_train_classes
147
+ encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
148
+ else:
149
+ encode_label_fun = lambda x: x
150
+
151
+ as_transforms = [
152
+ decode_transform,
153
+ merge_overlapping_events,
154
+ encode_label_fun,
155
+ target_transform
156
+ ]
157
+ as_ds.set_transform(SequentialTransform(as_transforms))
158
+ as_ds_eval = (
159
+ as_ds["eval"]
160
+ )
161
+ ds_list.append(as_ds_eval)
162
+ dataset = torch.utils.data.ConcatDataset(ds_list)
163
+ return dataset
164
+
165
+
166
+ def get_full_dataset(label_encoder, audio_length=10.0, sample_rate=16000):
167
+ init_hf_config()
168
+
169
+ decode_transform = Mp3DecodeTransform(
170
+ sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
171
+ )
172
+
173
+ with catchtime(f"Loading audioset:"):
174
+ as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
175
+
176
+ # label encode transformation
177
+ if label_encoder is not None:
178
+ label_encoder.labels = as_strong_train_classes
179
+ encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
180
+ else:
181
+ encode_label_fun = lambda x: x
182
+
183
+ as_transforms = [
184
+ decode_transform,
185
+ merge_overlapping_events,
186
+ encode_label_fun,
187
+ ]
188
+
189
+ as_ds.set_transform(SequentialTransform(as_transforms))
190
+ ds_list = []
191
+ ds_list.append(as_ds["balanced_train"])
192
+ ds_list.append(as_ds["unbalanced_train"])
193
+ ds_list.append(as_ds["eval"])
194
+
195
+ dataset = torch.utils.data.ConcatDataset(ds_list)
196
+ return dataset
197
+
198
+
199
+ def get_uniform_sample_weights(dataset):
200
+ """
201
+ :return: float tensor of shape len(dataset) representing the weight of each sample (all ones).
202
+ """
203
+ return torch.ones(len(dataset)).float()
204
+
205
+
206
+ def get_temporal_count_balanced_sample_weights(dataset, sample_weight_offset=30,
207
+ save_folder="/share/rk8/shared/as_strong"):
208
+ """
209
+ :return: float tensor of shape len(dataset) with one weight per sample: the sum over classes of the sample's event counts times the inverse (offset-smoothed) class frequencies.
210
+ """
211
+ # the order of the splits (balanced_train, unbalanced_train) is important;
212
+ # it must match the concatenation order in get_training_dataset
213
+ os.makedirs(save_folder, exist_ok=True)
214
+ save_file = os.path.join(save_folder, f"weights_temporal_count_offset_{sample_weight_offset}.pt")
215
+ if os.path.exists(save_file):
216
+ return torch.load(save_file)
217
+
218
+ from tqdm import tqdm
219
+
220
+ all_y = []
221
+ for sample in tqdm(dataset, desc="Calculating sample weights."):
222
+ all_y.append(sample["event_count"])
223
+ all_y = torch.from_numpy(np.stack(all_y, axis=0))
224
+ per_class = all_y.long().sum(0).float().reshape(1, -1) # frequencies per class
225
+
226
+ per_class = sample_weight_offset + per_class # offset low freq classes
227
+ if sample_weight_offset > 0:
228
+ print(f"Warning: sample_weight_offset={sample_weight_offset} minnow={per_class.min()}")
229
+ per_class_weights = 1000. / per_class
230
+ all_weight = all_y * per_class_weights
231
+ all_weight = all_weight.sum(dim=1)
232
+
233
+ torch.save(all_weight, save_file)
234
+ return all_weight
235
+
236
+
237
+ class MixupDataset(TorchDataset):
238
+ """ Mixing Up wave forms
239
+ """
240
+
241
+ def __init__(self, dataset, beta=2, rate=0.5):
242
+ self.beta = beta
243
+ self.rate = rate
244
+ self.dataset = dataset
245
+ print(f"Mixing up waveforms from dataset of len {len(dataset)}")
246
+
247
+ def __getitem__(self, index):
248
+ if torch.rand(1) < self.rate:
249
+ batch1 = self.dataset[index]
250
+ idx2 = torch.randint(len(self.dataset), (1,)).item()
251
+ batch2 = self.dataset[idx2]
252
+ x1, x2 = batch1['audio'], batch2['audio']
253
+ y1, y2 = batch1['strong'], batch2['strong']
254
+ if 'pseudo_strong' in batch1:
255
+ p1, p2 = batch1['pseudo_strong'], batch2['pseudo_strong']
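+ # mixing coefficient from a symmetric Beta distribution; taking
+ # max(l, 1 - l) keeps the indexed clip the dominant component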
256
+ l = np.random.beta(self.beta, self.beta)
257
+ l = max(l, 1. - l)
258
+ x1 = x1 - x1.mean()
259
+ x2 = x2 - x2.mean()
260
+ x = (x1 * l + x2 * (1. - l))
261
+ x = x - x.mean()
262
+ batch1['audio'] = x
263
+ batch1['strong'] = (y1 * l + y2 * (1. - l))
264
+ if 'pseudo_strong' in batch1:
265
+ batch1['pseudo_strong'] = (p1 * l + p2 * (1. - l))
266
+ return batch1
267
+ return self.dataset[index]
268
+
269
+ def __len__(self):
270
+ return len(self.dataset)
271
+
272
+
273
+ class DistributedSamplerWrapper(DistributedSampler):
274
+ def __init__(
275
+ self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True
276
+ ):
277
+ super(DistributedSamplerWrapper, self).__init__(
278
+ dataset, num_replicas, rank, shuffle
279
+ )
280
+ # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
281
+ self.sampler = sampler
282
+
283
+ def __iter__(self):
284
+ if self.sampler.generator is None:
285
+ self.sampler.generator = torch.Generator()
286
+ self.sampler.generator.manual_seed(self.seed + self.epoch)
287
+ indices = list(self.sampler)
288
+ if self.epoch < 2:
289
+ logger.info(
290
+ f"\n DistributedSamplerWrapper (rank {self.rank}) : {indices[:3]} \n\n"
291
+ )
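+ # each rank keeps every num_replicas-th index, starting at its own rank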
292
+ indices = indices[self.rank : self.total_size : self.num_replicas]
293
+ return iter(indices)
294
+
295
+
296
+ def get_weighted_sampler(
297
+ samples_weights,
298
+ epoch_len=100_000,
299
+ sampler_replace=False,
300
+ ):
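+ # infer world size and rank from environment variables (defaults to a single process)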
301
+ num_nodes = int(os.environ.get("WORLD_SIZE", 1))
302
+ ddp = int(os.environ.get("DDP", 1))
303
+ num_nodes = max(ddp, num_nodes)
304
+ rank = int(os.environ.get("NODE_RANK", 0))
305
+ return DistributedSamplerWrapper(
306
+ sampler=WeightedRandomSampler(
307
+ samples_weights, num_samples=epoch_len, replacement=sampler_replace
308
+ ),
309
+ dataset=range(epoch_len),
310
+ num_replicas=num_nodes,
311
+ rank=rank,
312
+ )
313
+
314
+
315
+ if __name__ == "__main__":
316
+ from helpers.encode import ManyHotEncoder
317
+
318
+ encoder = ManyHotEncoder([], 10., 160, net_pooling=4, fs=16_000)
319
+
320
+ train_ds = get_training_dataset(
321
+ encoder, audio_length=10.0, sample_rate=16_000
322
+ )
323
+
324
+ valid_ds = get_eval_dataset(
325
+ encoder, audio_length=10.0, sample_rate=16_000
326
+ )
327
+
328
+ print("Len train dataset: ", len(train_ds))
329
+ print("Len valid dataset: ", len(valid_ds))
data_util/dcase2016task2.py ADDED
@@ -0,0 +1,280 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, List, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import soundfile as sf
8
+ import torch
9
+ from intervaltree import IntervalTree
10
+ from torch.utils.data import Dataset
11
+
12
+
13
+ class FixCropDataset(Dataset):
14
+ """
15
+ Return fixed 10-second crops of the 120-second clips: (audio, labels, filename, timestamps).
16
+ """
17
+
18
+ def __init__(self, data: Dict,
19
+ audio_dir: Path,
20
+ sample_rate: int,
21
+ label_fps: int,
22
+ label_to_idx: Dict,
23
+ nlabels: int):
24
+ self.clip_len = 120
25
+ self.target_len = 10
26
+ self.pieces_per_clip = self.clip_len // self.target_len
27
+ self.filenames = list(data.keys())
28
+ self.audio_dir = audio_dir
29
+ assert self.audio_dir.is_dir(), f"{audio_dir} is not a directory"
30
+ self.sample_rate = sample_rate
31
+ # all files are 120 seconds long, split them into 12 x 10 second pieces
32
+ self.pieces = []
33
+ self.labels = []
34
+ self.timestamps = []
35
+ for filename in self.filenames:
36
+ self.pieces += [(filename, i) for i in range(self.pieces_per_clip)]
37
+ labels = data[filename]
38
+ frame_len = 1000 / label_fps
39
+ timestamps = np.arange(label_fps * self.clip_len) * frame_len + 0.5 * frame_len
40
+ timestamp_labels = get_labels_for_timestamps(labels, timestamps)
41
+ ys = []
42
+ for timestamp_label in timestamp_labels:
43
+ timestamp_label_idxs = [label_to_idx[str(event)] for event in timestamp_label]
44
+ y_timestamp = label_to_binary_vector(timestamp_label_idxs, nlabels)
45
+ ys.append(y_timestamp)
46
+ ys = torch.stack(ys)
47
+ frames_per_clip = ys.size(0) // self.pieces_per_clip
48
+ self.labels += [ys[frames_per_clip * i: frames_per_clip * (i + 1)] for i in range(self.pieces_per_clip)]
49
+ self.timestamps += [timestamps[frames_per_clip * i: frames_per_clip * (i + 1)] for i in
50
+ range(self.pieces_per_clip)]
51
+
52
+ assert len(self.labels) == len(self.pieces) == len(self.filenames) * self.pieces_per_clip
53
+
54
+ def __len__(self):
55
+ return len(self.pieces)
56
+
57
+ def __getitem__(self, idx):
58
+ filename = self.pieces[idx][0]
59
+ piece = self.pieces[idx][1]
60
+ audio_path = self.audio_dir.joinpath(filename)
61
+ audio, sr = sf.read(str(audio_path), dtype=np.float32)
62
+ assert sr == self.sample_rate
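+ # sample index at which this 10 s piece starts within the 120 s clip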
63
+ start = self.sample_rate * piece * self.target_len
64
+ end = start + self.sample_rate * self.target_len
65
+ audio = audio[start:end]
66
+ return audio, self.labels[idx].transpose(0, 1), filename, self.timestamps[idx]
67
+
68
+
69
+ class RandomCropDataset(Dataset):
70
+ """
71
+ Return randomly cropped 10-second snippets of the 120-second clips: (audio, labels, filename, timestamps).
72
+ """
73
+
74
+ def __init__(self, data: Dict,
75
+ audio_dir: Path,
76
+ sample_rate: int,
77
+ label_fps: int,
78
+ label_to_idx: Dict,
79
+ nlabels: int):
80
+ self.clip_len = 120
81
+ self.target_len = 10
82
+ self.pieces_per_clip = self.clip_len // self.target_len
83
+ self.filenames = list(data.keys())
84
+ self.audio_dir = audio_dir
85
+ assert self.audio_dir.is_dir(), f"{audio_dir} is not a directory"
86
+ self.sample_rate = sample_rate
87
+ self.label_fps = label_fps
88
+ # all files are 120 seconds long, randomly crop 10 seconds snippets
89
+ self.labels = []
90
+ self.timestamps = []
91
+ for filename in self.filenames:
92
+ labels = data[filename]
93
+ frame_len = 1000 / label_fps
94
+ timestamps = np.arange(label_fps * self.clip_len) * frame_len + 0.5 * frame_len
95
+ timestamp_labels = get_labels_for_timestamps(labels, timestamps)
96
+ ys = []
97
+ for timestamp_label in timestamp_labels:
98
+ timestamp_label_idxs = [label_to_idx[str(event)] for event in timestamp_label]
99
+ y_timestamp = label_to_binary_vector(timestamp_label_idxs, nlabels)
100
+ ys.append(y_timestamp)
101
+ ys = torch.stack(ys)
102
+ self.labels.append(ys)
103
+ self.timestamps.append(timestamps)
104
+
105
+ assert len(self.labels) == len(self.filenames)
106
+
107
+ def __len__(self):
108
+ return len(self.filenames) * self.clip_len // self.target_len
109
+
110
+ def __getitem__(self, idx):
111
+ idx = idx % len(self.filenames)
112
+ filename = self.filenames[idx]
113
+ audio_path = self.audio_dir.joinpath(filename)
114
+ audio, sr = sf.read(str(audio_path), dtype=np.float32)
115
+ assert sr == self.sample_rate
116
+
117
+ # crop random 10 seconds piece
118
+ labels_to_pick = self.target_len * self.label_fps
119
+ max_offset = len(self.labels[idx]) - labels_to_pick + 1
120
+ offset = torch.randint(max_offset, (1,)).item()
121
+ labels = self.labels[idx][offset:offset + labels_to_pick]
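+ # audio samples per label frame; keeps the waveform crop aligned with the label crop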
122
+ scale = self.sample_rate // self.label_fps
123
+ audio = audio[offset * scale:offset * scale + labels_to_pick * scale]
124
+ timestamps = self.timestamps[idx][offset:offset + labels_to_pick]
125
+ return audio, labels.transpose(0, 1), filename, timestamps
126
+
127
+
128
+ def get_training_dataset(
129
+ task_path,
130
+ sample_rate=16000,
131
+ label_fps=25,
132
+ wavmix_p=0.0,
133
+ random_crop=True
134
+ ):
135
+ task_path = Path(task_path)
136
+
137
+ label_vocab, nlabels = label_vocab_nlabels(task_path)
138
+ label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
139
+
140
+ train_fold = task_path.joinpath("train.json")
141
+ audio_dir = task_path.joinpath(str(sample_rate), "train")
142
+ train_fold_data = json.load(train_fold.open())
143
+ if random_crop:
144
+ dataset = RandomCropDataset(train_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
145
+ else:
146
+ dataset = FixCropDataset(train_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
147
+ if wavmix_p > 0:
148
+ dataset = MixupDataset(dataset, rate=wavmix_p)
149
+ return dataset
150
+
151
+
152
+ def get_validation_dataset(
153
+ task_path,
154
+ sample_rate=16000,
155
+ label_fps=25,
156
+ ):
157
+ task_path = Path(task_path)
158
+
159
+ label_vocab, nlabels = label_vocab_nlabels(task_path)
160
+ label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
161
+
162
+ valid_fold = task_path.joinpath("valid.json")
163
+ audio_dir = task_path.joinpath(str(sample_rate), "valid")
164
+ valid_fold_data = json.load(valid_fold.open())
165
+ dataset = FixCropDataset(valid_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
166
+ return dataset
167
+
168
+
169
+ def get_test_dataset(
170
+ task_path,
171
+ sample_rate=16000,
172
+ label_fps=25,
173
+ ):
174
+ task_path = Path(task_path)
175
+
176
+ label_vocab, nlabels = label_vocab_nlabels(task_path)
177
+ label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
178
+
179
+ test_fold = task_path.joinpath("test.json")
180
+ audio_dir = task_path.joinpath(str(sample_rate), "test")
181
+ test_fold_data = json.load(test_fold.open())
182
+ dataset = FixCropDataset(test_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
183
+ return dataset
184
+
185
+
186
+ def get_labels_for_timestamps(labels: List, timestamps: np.ndarray) -> List:
187
+ # A list of labels present at each timestamp
188
+ tree = IntervalTree()
189
+ # Add all events to the label tree
190
+ for event in labels:
191
+ # We add 0.0001 so that the end also includes the event
192
+ tree.addi(event["start"], event["end"] + 0.0001, event["label"])
193
+
194
+ timestamp_labels = []
195
+ # Update the binary vector of labels with intervals for each timestamp
196
+ for j, t in enumerate(timestamps):
197
+ interval_labels: List[str] = [interval.data for interval in tree[t]]
198
+ timestamp_labels.append(interval_labels)
199
+ # If we want to store the timestamp too
200
+ # labels_for_sound.append([float(t), interval_labels])
201
+
202
+ assert len(timestamp_labels) == len(timestamps)
203
+ return timestamp_labels
204
+
205
+
206
+ def label_vocab_nlabels(task_path: Path) -> Tuple[pd.DataFrame, int]:
207
+ label_vocab = pd.read_csv(task_path.joinpath("labelvocabulary.csv"))
208
+
209
+ nlabels = len(label_vocab)
210
+ assert nlabels == label_vocab["idx"].max() + 1
211
+ return (label_vocab, nlabels)
212
+
213
+
214
+ def label_vocab_as_dict(df: pd.DataFrame, key: str, value: str) -> Dict:
215
+ """
216
+ Returns a dictionary mapping between the "label" and "idx" columns of the
217
+ label vocabulary. `key` selects which column becomes the dict key; the other
218
+ column becomes the value (the `value` argument is overwritten accordingly).
219
+ """
220
+ if key == "label":
221
+ # Make sure the key is a string
222
+ df["label"] = df["label"].astype(str)
223
+ value = "idx"
224
+ else:
225
+ assert key == "idx", "key argument must be either 'label' or 'idx'"
226
+ value = "label"
227
+ return df.set_index(key).to_dict()[value]
228
+
229
+
230
+ def label_to_binary_vector(label: List, num_labels: int) -> torch.Tensor:
231
+ """
232
+ Converts a list of labels into a binary vector
233
+ Args:
234
+ label: list of integer labels
235
+ num_labels: total number of labels
236
+
237
+ Returns:
238
+ A float Tensor that is multi-hot binary vector
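+
+ Example: label_to_binary_vector([0, 2], 4) -> tensor([1., 0., 1., 0.])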
239
+ """
240
+ # Lame special case for multilabel with no labels
241
+ if len(label) == 0:
242
+ # BCEWithLogitsLoss wants float not long targets
243
+ binary_labels = torch.zeros((num_labels,), dtype=torch.float)
244
+ else:
245
+ binary_labels = torch.zeros((num_labels,)).scatter(0, torch.tensor(label), 1.0)
246
+
247
+ # Validate the binary vector we just created
248
+ assert set(torch.where(binary_labels == 1.0)[0].numpy()) == set(label)
249
+ return binary_labels
250
+
251
+
252
+ class MixupDataset(Dataset):
253
+ """ Mixing Up wave forms
254
+ """
255
+
256
+ def __init__(self, dataset, beta=0.2, rate=0.5):
257
+ self.beta = beta
258
+ self.rate = rate
259
+ self.dataset = dataset
260
+ print(f"Mixing up waveforms from dataset of len {len(dataset)}")
261
+
262
+ def __getitem__(self, index):
263
+ if torch.rand(1) < self.rate:
264
+ batch1 = self.dataset[index]
265
+ idx2 = torch.randint(len(self.dataset), (1,)).item()
266
+ batch2 = self.dataset[idx2]
267
+ x1, x2 = batch1[0], batch2[0]
268
+ y1, y2 = batch1[1], batch2[1]
269
+ l = np.random.beta(self.beta, self.beta)
270
+ l = max(l, 1. - l)
271
+ x1 = x1 - x1.mean()
272
+ x2 = x2 - x2.mean()
273
+ x = (x1 * l + x2 * (1. - l))
274
+ x = x - x.mean()
275
+ y = (y1 * l + y2 * (1. - l))
276
+ return x, y, batch1[2], batch1[3]
277
+ return self.dataset[index]
278
+
279
+ def __len__(self):
280
+ return len(self.dataset)
data_util/transforms.py ADDED
@@ -0,0 +1,195 @@
1
+ import os
2
+
3
+ import datasets
4
+ import h5py
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+
10
+ from data_util.audioset_classes import as_strong_train_classes
11
+
12
+ ## Transforms with a similar style to https://github.com/descriptinc/audiotools/blob/master/audiotools/data/transforms.py
13
+ logger = datasets.logging.get_logger(__name__)
14
+
15
+
16
+ def target_transform(sample):
17
+ del sample["labels"]
18
+ del sample["label_ids"]
19
+ return sample
20
+
21
+
22
+ def strong_label_transform(sample, strong_label_encoder=None):
23
+ assert strong_label_encoder is not None
24
+ events = pd.DataFrame(sample['events'][0])
25
+ events = events[events['event_label'].isin(set(as_strong_train_classes))]
26
+ strong = strong_label_encoder.encode_strong_df(events).T
27
+ sample["strong"] = [strong]
28
+ sample["event_count"] = [strong.sum(1)]
29
+ # encode ground truth events as string - we will use this for evaluation
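+ # e.g. "0.0;;1.5;;Speech++3.2;;4.0;;Dog" (events joined by "++", fields by ";;")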
30
+ sample["gt_string"] = ["++".join([";;".join([str(e[0]), str(e[1]), e[2]]) for e in
31
+ zip(sample['events'][0]['onset'], sample['events'][0]['offset'],
32
+ sample['events'][0]['event_label'])])]
33
+ del sample['events']
34
+ return sample
35
+
36
+
37
+ class AddPseudoLabelsTransform:
38
+ def __init__(self, pseudo_labels_file):
39
+ self.pseudo_labels_file = pseudo_labels_file
40
+
41
+ if self.pseudo_labels_file is not None:
42
+ # fetch dict of positions for each example
43
+ self.ex2pseudo_idx = {}
44
+ f = h5py.File(self.pseudo_labels_file, "r")
45
+ for i, fname in enumerate(f["filenames"]):
46
+ self.ex2pseudo_idx[fname.decode("UTF-8")] = i
47
+ self._opened_pseudo_hdf5 = None
48
+
49
+ @property
50
+ def pseudo_hdf5_file(self):
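+ # open the HDF5 file lazily, e.g. so that each DataLoader worker ends up with its own handle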
51
+ if self._opened_pseudo_hdf5 is None:
52
+ self._opened_pseudo_hdf5 = h5py.File(self.pseudo_labels_file, "r")
53
+ return self._opened_pseudo_hdf5
54
+
55
+ def add_pseudo_label_transform(self, sample):
56
+ # strip the ".mp3" suffix to match the keys of the pseudo-label file
+ indices = [self.ex2pseudo_idx[fn[:-len(".mp3")] if fn.endswith(".mp3") else fn] for fn in sample['filename']]
57
+ pseudo_strong = [torch.from_numpy(np.stack(self.pseudo_hdf5_file["strong_logits"][index])).float()
58
+ for index in indices]
59
+ pseudo_strong = [torch.sigmoid(pseudo_strong[i]) for i in range(len(pseudo_strong))]
60
+ sample['pseudo_strong'] = pseudo_strong
61
+ return sample
62
+
63
+
64
+ class SequentialTransform:
65
+ """Apply a sequence of transforms to a batch."""
66
+
67
+ def __init__(self, transforms):
68
+ """
69
+ Args:
70
+ transforms: list of transforms to apply
71
+ """
72
+ self.transforms = transforms
73
+
74
+ def append(self, transform):
75
+ self.transforms.append(transform)
76
+
77
+ def __call__(self, batch):
78
+ for t in self.transforms:
79
+ batch = t(batch)
80
+ return batch
81
+
82
+
83
+ class Mp3DecodeTransform:
84
+ def __init__(
85
+ self,
86
+ mp3_bytes_key="mp3_bytes",
87
+ audio_key="audio",
88
+ sample_rate=32000,
89
+ max_length=10.0,
90
+ min_length=None,
91
+ random_sample_crop=True,
92
+ allow_resample=True,
93
+ resampling_method="sinc_interp_kaiser",
94
+ keep_mp3_bytes=False,
95
+ debug_info_key=None,
96
+ ):
97
+ """Decode mp3 bytes to audio waveform
98
+
99
+ Args:
100
+ mp3_bytes_key (str, optional): The key to mp3 bytes in the input batch. Defaults to "mp3_bytes".
101
+ audio_key (str, optional): The key to save the decoded audio in the output batch. Defaults to "audio".
102
+ sample_rate (int, optional): The expected output sampling rate. Defaults to 32000.
103
+ max_length (int, float, optional): the maximum output audio length in seconds if float, otherwise in samples. Defaults to 10.
104
+ min_length (int, optional): the minimum output audio length in seconds. Defaults to max_length.
105
+ random_sample_crop (bool, optional): Randomly crop the audio to max_length if it is longer; otherwise return the first crop. Defaults to True.
106
+ allow_resample (bool, optional): Resample the signal if the sampling rates don't match. Defaults to True.
107
+ resampling_method (str, optional): resampling method for torchaudio.transforms.Resample. Defaults to "sinc_interp_kaiser".
108
+ keep_mp3_bytes (bool, optional): keep the original bytes in the output dict. Defaults to False.
109
+
110
+ Raises:
111
+ Exception: if minimp3py is not installed
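+
+ A minimal usage sketch (assuming `mp3_data` holds the raw bytes of one mp3 file):
+ out = Mp3DecodeTransform(sample_rate=16000, max_length=10.0)({"mp3_bytes": [mp3_data]})
+ audio = out["audio"][0] # a 1-D float torch tensor, 10 s at 16 kHz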
112
+ """
113
+ self.mp3_bytes_key = mp3_bytes_key
114
+ self.audio_key = audio_key
115
+ self.sample_rate = sample_rate
116
+ self.max_length = max_length
117
+ if min_length is None:
118
+ min_length = max_length
119
+ self.min_length = min_length
120
+ self.random_sample_crop = random_sample_crop
121
+ self.allow_resample = allow_resample
122
+ self.resampling_method = resampling_method
123
+ self.keep_mp3_bytes = keep_mp3_bytes
124
+ self.debug_info_key = debug_info_key
125
+ self.resamplers_cache = {}
126
+ try:
127
+ import minimp3py # noqa: F401
128
+ except ImportError:
129
+ raise Exception(
130
+ "minimp3py is not installed, please install it using: `CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip`"
131
+ )
132
+
133
+ def __call__(self, batch):
134
+ import minimp3py
135
+
136
+ data_list = batch[self.mp3_bytes_key]
137
+ if self.debug_info_key is not None:
138
+ file_name_list = batch[self.debug_info_key]
139
+ else:
140
+ file_name_list = range(len(data_list))
141
+ audio_list = []
142
+ for data, file_name in zip(data_list, file_name_list):
143
+ try:
144
+ duration, ch, sr = minimp3py.probe(data)
145
+ if isinstance(self.max_length, float):
146
+ max_length = int(self.max_length * sr)
147
+ else:
148
+ max_length = int(self.max_length * sr // self.sample_rate)
149
+ offset = 0
150
+ if self.random_sample_crop and duration > max_length:
151
+ max_offset = max(int(duration - max_length), 0) + 1
152
+ offset = torch.randint(max_offset, (1,)).item()
153
+ waveform, _ = minimp3py.read(data, start=offset, length=max_length)
154
+ waveform = waveform[:, 0] # 0 for the first channel only
155
+ if waveform.dtype != "float32":
156
+ raise RuntimeError("Unexpected wave type")
157
+
158
+ waveform = torch.from_numpy(waveform)
159
+ if len(waveform) == 0:
160
+ logger.warning(
161
+ f"Empty waveform for {file_name}, duration {duration}, offset {offset}, max_length {max_length}, sr {sr}, ch {ch}"
162
+ )
163
+ elif sr != self.sample_rate:
164
+ assert self.allow_resample, f"Unexpected sample rate {sr} instead of {self.sample_rate} at {file_name}"
165
+ if self.resamplers_cache.get(sr) is None:
166
+ self.resamplers_cache[sr] = torchaudio.transforms.Resample(
167
+ sr,
168
+ self.sample_rate,
169
+ resampling_method=self.resampling_method,
170
+ )
171
+ waveform = self.resamplers_cache[sr](waveform)
172
+ min_length = self.min_length
173
+ if isinstance(self.min_length, float):
174
+ min_length = int(self.min_length * self.sample_rate)
175
+ if min_length is not None and len(waveform) < min_length:
176
+ waveform = torch.concatenate(
177
+ (
178
+ waveform,
179
+ torch.zeros(
180
+ min_length - len(waveform),
181
+ dtype=waveform.dtype,
182
+ device=waveform.device,
183
+ ),
184
+ ),
185
+ dim=0,
186
+ )
187
+ audio_list.append(waveform)
188
+ except Exception as e:
189
+ print(f"Error decoding {file_name}: {e}")
190
+ raise e
191
+ batch[self.audio_key] = audio_list
192
+ batch["sampling_rate"] = [self.sample_rate] * len(audio_list)
193
+ if not self.keep_mp3_bytes:
194
+ del batch[self.mp3_bytes_key]
195
+ return batch
ex_audioset_strong.py ADDED
@@ -0,0 +1,504 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import argparse
6
+ import torch.nn as nn
7
+ import wandb
8
+ import transformers
9
+ import random
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.loggers import WandbLogger
12
+ import sed_scores_eval
13
+
14
+ from helpers.decode import batched_decode_preds
15
+ from helpers.encode import ManyHotEncoder
16
+ from models.atstframe.ATSTF_wrapper import ATSTWrapper
17
+ from models.beats.BEATs_wrapper import BEATsWrapper
18
+ from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
19
+ from models.m2d.M2D_wrapper import M2DWrapper
20
+ from models.asit.ASIT_wrapper import ASiTWrapper
21
+ from models.prediction_wrapper import PredictionsWrapper
22
+ from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
23
+ from helpers.utils import worker_init_fn
24
+ from data_util.audioset_strong import get_training_dataset, get_eval_dataset
25
+ from data_util.audioset_strong import get_temporal_count_balanced_sample_weights, get_uniform_sample_weights, \
26
+ get_weighted_sampler
27
+ from data_util.audioset_classes import as_strong_train_classes, as_strong_eval_classes
28
+ from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
29
+ from models.frame_mn.utils import NAME_TO_WIDTH
30
+
31
+
32
+ class PLModule(pl.LightningModule):
33
+ def __init__(self, config, encoder):
34
+ super().__init__()
35
+ self.config = config
36
+ self.encoder = encoder
37
+
38
+ if config.pretrained == "scratch":
39
+ checkpoint = None
40
+ elif config.pretrained == "ssl":
41
+ checkpoint = "ssl"
42
+ elif config.pretrained == "weak":
43
+ checkpoint = "weak"
44
+ elif config.pretrained == "strong":
45
+ checkpoint = "strong_1"
46
+ else:
47
+ raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")
48
+
49
+ # load transformer model
50
+ if config.model_name == "BEATs":
51
+ beats = BEATsWrapper()
52
+ model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
53
+ seq_model_type=config.seq_model_type)
54
+ elif config.model_name == "ATST-F":
55
+ atst = ATSTWrapper()
56
+ model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
57
+ seq_model_type=config.seq_model_type)
58
+ elif config.model_name == "fpasst":
59
+ fpasst = FPaSSTWrapper()
60
+ model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
61
+ seq_model_type=config.seq_model_type)
62
+ elif config.model_name == "M2D":
63
+ m2d = M2DWrapper()
64
+ model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
65
+ seq_model_type=config.seq_model_type,
66
+ embed_dim=m2d.m2d.cfg.feature_d)
67
+ elif config.model_name == "ASIT":
68
+ asit = ASiTWrapper()
69
+ model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
70
+ seq_model_type=config.seq_model_type)
71
+ elif config.model_name.startswith("frame_mn"):
72
+ width = NAME_TO_WIDTH(config.model_name)
73
+ frame_mn = FrameMNWrapper(width)
74
+ embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
75
+ model = PredictionsWrapper(frame_mn, checkpoint=f"{config.model_name}_strong_1", embed_dim=embed_dim)
76
+ else:
77
+ raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
78
+
79
+ self.model = model
80
+
81
+ # prepare ingredients for knowledge distillation
82
+ assert 0 <= config.distillation_loss_weight <= 1, "Lambda for Knowledge Distillation must be between 0 and 1."
83
+ self.strong_loss = nn.BCEWithLogitsLoss()
84
+
85
+ self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))
86
+
87
+ self.val_durations_df = pd.read_csv(f"resources/eval_durations.csv",
88
+ sep=",", header=None, names=["filename", "duration"])
89
+ self.val_predictions_strong = {}
90
+ self.val_ground_truth = {}
91
+ self.val_duration = {}
92
+ self.val_loss = []
93
+
94
+ def forward(self, batch):
95
+ x = batch["audio"]
96
+ mel = self.model.mel_forward(x)
97
+ y_strong, _ = self.model(mel)
98
+ return y_strong
99
+
100
+ def get_optimizer(
101
+ self, lr, adamw=False, weight_decay=0.01, betas=(0.9, 0.999)
102
+ ):
103
+ # we split the parameters into two groups: <=1 dimensional and >=2 dimensional ones,
104
+ # so that weight decay is applied only to the >=2 dimensional parameters,
105
+ # thus excluding biases and norm parameters (an idea from NanoGPT)
106
+ params_leq1D = []
107
+ params_geq2D = []
108
+
109
+ for name, param in self.model.named_parameters():
110
+ if param.requires_grad:
111
+ if param.ndimension() >= 2:
112
+ params_geq2D.append(param)
113
+ else:
114
+ params_leq1D.append(param)
115
+
116
+ param_groups = [
117
+ {'params': params_leq1D, 'lr': lr},
118
+ {'params': params_geq2D, 'lr': lr, 'weight_decay': weight_decay},
119
+ ]
120
+
121
+ if weight_decay > 0:
122
+ assert adamw
123
+ assert len(param_groups) > 0
124
+ if adamw:
125
+ print(f"\nUsing adamw weight_decay={weight_decay}!\n")
126
+ return torch.optim.AdamW(param_groups, lr=lr, betas=betas)
127
+ return torch.optim.Adam(param_groups, lr=lr, betas=betas)
128
+
129
+ def get_lr_scheduler(
130
+ self,
131
+ optimizer,
132
+ num_training_steps,
133
+ schedule_mode="cos",
134
+ gamma: float = 0.999996,
135
+ num_warmup_steps=20000,
136
+ lr_end=2e-7,
137
+ ):
138
+ if schedule_mode in {"exp"}:
139
+ return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
140
+ if schedule_mode in {"cosine", "cos"}:
141
+ return transformers.get_cosine_schedule_with_warmup(
142
+ optimizer,
143
+ num_warmup_steps=num_warmup_steps,
144
+ num_training_steps=num_training_steps,
145
+ )
146
+ if schedule_mode in {"linear"}:
147
+ print("Linear schedule!")
148
+ return transformers.get_polynomial_decay_schedule_with_warmup(
149
+ optimizer,
150
+ num_warmup_steps=num_warmup_steps,
151
+ num_training_steps=num_training_steps,
152
+ power=1.0,
153
+ lr_end=lr_end,
154
+ )
155
+ raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
156
+
157
+ def configure_optimizers(self):
158
+ """
159
+ This is the way PyTorch Lightning requires optimizers and learning rate schedulers to be defined.
160
+ The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
161
+ :return: dict containing optimizer and learning rate scheduler
162
+ """
163
+ optimizer = self.get_optimizer(self.config.max_lr, adamw=self.config.adamw,
164
+ weight_decay=self.config.weight_decay)
165
+
166
+ num_training_steps = self.trainer.estimated_stepping_batches
167
+
168
+ scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
169
+ schedule_mode=self.config.schedule_mode,
170
+ lr_end=self.config.lr_end)
171
+ lr_scheduler_config = {
172
+ "scheduler": scheduler,
173
+ "interval": "step",
174
+ "frequency": 1
175
+ }
176
+ return [optimizer], [lr_scheduler_config]
177
+
178
+ def training_step(self, train_batch, batch_idx):
179
+ """
180
+ :param train_batch: contains one batch from train dataloader
181
+ :param batch_idx: index of the current batch within the epoch
182
+ :return: a dict containing at least the loss that is used to update model parameters; it can also contain
183
+ other items that can be processed in 'training_epoch_end' to log metrics other than the loss
184
+ """
185
+
186
+ x = train_batch["audio"]
187
+ labels = train_batch['strong']
188
+ if 'pseudo_strong' in train_batch:
189
+ pseudo_labels = train_batch['pseudo_strong']
190
+ else:
191
+ # create dummy pseudo labels
192
+ pseudo_labels = torch.zeros_like(labels)
193
+ assert self.config.distillation_loss_weight == 0
194
+
195
+ mel = self.model.mel_forward(x)
196
+
197
+ # time rolling
198
+ if self.config.frame_shift_range > 0:
199
+ mel, labels, pseudo_labels = frame_shift(
200
+ mel,
201
+ labels,
202
+ pseudo_labels=pseudo_labels,
203
+ net_pooling=self.encoder.net_pooling,
204
+ shift_range=self.config.frame_shift_range
205
+ )
206
+
207
+ # mixup
208
+ if self.config.mixup_p > random.random():
209
+ mel, labels, pseudo_labels = mixup(
210
+ mel,
211
+ targets=labels,
212
+ pseudo_strong=pseudo_labels
213
+ )
214
+
215
+ # mixstyle
216
+ if self.config.mixstyle_p > random.random():
217
+ mel = mixstyle(
218
+ mel
219
+ )
220
+
221
+ # time masking
222
+ if self.config.max_time_mask_size > 0:
223
+ mel, labels, pseudo_labels = time_mask(
224
+ mel,
225
+ labels,
226
+ pseudo_labels=pseudo_labels,
227
+ net_pooling=self.encoder.net_pooling,
228
+ max_mask_ratio=self.config.max_time_mask_size
229
+ )
230
+
231
+ # frequency masking
232
+ if self.config.filter_augment_p > random.random():
233
+ mel, _ = filter_augmentation(
234
+ mel
235
+ )
236
+
237
+ # frequency warping
238
+ if self.config.freq_warp_p > random.random():
239
+ mel = mel.squeeze(1)
240
+ mel = self.freq_warp(mel)
241
+ mel = mel.unsqueeze(1)
242
+
243
+ # forward through network; use strong head
244
+ y_hat_strong, _ = self.model(mel)
245
+
246
+ strong_supervised_loss = self.strong_loss(y_hat_strong, labels)
247
+
248
+ if self.config.distillation_loss_weight > 0:
249
+ strong_distillation_loss = self.strong_loss(y_hat_strong, pseudo_labels)
250
+ else:
251
+ strong_distillation_loss = torch.tensor(0., device=y_hat_strong.device, dtype=y_hat_strong.dtype)
252
+
253
+ loss = self.config.distillation_loss_weight * strong_distillation_loss \
254
+ + (1 - self.config.distillation_loss_weight) * strong_supervised_loss
255
+
256
+ # logging
257
+ self.log('epoch', self.current_epoch)
258
+ for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
259
+ self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
260
+ self.log("train/loss", loss.detach().cpu(), prog_bar=True)
261
+ self.log("train/strong_supervised_loss", strong_supervised_loss.detach().cpu())
262
+ self.log("train/strong_distillation_loss", strong_distillation_loss.detach().cpu())
263
+
264
+ return loss
265
+
266
+ def validation_step(self, val_batch, batch_idx):
267
+ # bring ground truth into shape needed for evaluation
268
+ for f, gt_string in zip(val_batch["filename"], val_batch["gt_string"]):
269
+ f = f[:-len(".mp3")]
270
+ events = [e.split(";;") for e in gt_string.split("++")]
271
+ self.val_ground_truth[f] = [(float(e[0]), float(e[1]), e[2]) for e in events]
272
+ self.val_duration[f] = self.val_durations_df[self.val_durations_df["filename"] == f]["duration"].values[0]
273
+
274
+ y_hat_strong = self(val_batch)
275
+ y_strong = val_batch["strong"]
276
+
277
+ loss = self.strong_loss(y_hat_strong, y_strong)
278
+ self.val_loss.append(loss.cpu())
279
+
280
+ scores_raw, scores_postprocessed, prediction_dfs = batched_decode_preds(
281
+ y_hat_strong.float(),
282
+ val_batch['filename'],
283
+ self.encoder,
284
+ median_filter=self.config.median_window
285
+ )
286
+
287
+ self.val_predictions_strong.update(
288
+ scores_postprocessed
289
+ )
290
+
291
+ def on_validation_epoch_end(self):
292
+ gt_unique_events = set([e[2] for f, events in self.val_ground_truth.items() for e in events])
293
+ train_unique_events = set(self.encoder.labels)
294
+ # evaluate on all classes that are in both train and test sets (407 classes)
295
+ class_intersection = gt_unique_events.intersection(train_unique_events)
296
+
297
+ assert len(class_intersection) == len(set(as_strong_train_classes).intersection(as_strong_eval_classes)) == 407, \
298
+ f"Intersection unique events. Expected: {len(set(as_strong_train_classes).intersection(as_strong_eval_classes))}," \
299
+ f" Actual: {len(class_intersection)}"
300
+
301
+ # filter ground truth according to class_intersection
302
+ val_ground_truth = {fid: [event for event in self.val_ground_truth[fid] if event[2] in class_intersection]
303
+ for fid in self.val_ground_truth}
304
+ # drop audios without events - aligned with DESED evaluation procedure
305
+ val_ground_truth = {fid: events for fid, events in val_ground_truth.items() if len(events) > 0}
306
+ # keep only corresponding audio durations
307
+ audio_durations = {
308
+ fid: self.val_duration[fid] for fid in val_ground_truth.keys()
309
+ }
310
+
311
+ # filter files in predictions
312
+ as_strong_preds = {
313
+ fid: self.val_predictions_strong[fid] for fid in val_ground_truth.keys()
314
+ }
315
+ # filter classes in predictions
316
+ unused_classes = list(set(self.encoder.labels).difference(class_intersection))
317
+ for f, df in as_strong_preds.items():
318
+ df.drop(columns=list(unused_classes), inplace=True)
319
+
320
+ segment_based_pauroc = sed_scores_eval.segment_based.auroc(
321
+ as_strong_preds,
322
+ val_ground_truth,
323
+ audio_durations,
324
+ max_fpr=0.1,
325
+ segment_length=1.0,
326
+ num_jobs=1
327
+ )
328
+
329
+ psds1 = sed_scores_eval.intersection_based.psds(
330
+ as_strong_preds,
331
+ val_ground_truth,
332
+ audio_durations,
333
+ dtc_threshold=0.7,
334
+ gtc_threshold=0.7,
335
+ cttc_threshold=None,
336
+ alpha_ct=0,
337
+ alpha_st=1,
338
+ num_jobs=1
339
+ )
340
+
341
+ # "val/psds1_macro_averaged" is psds1 without penalization for performance
342
+ # variations across classes
343
+ logs = {"val/loss": torch.as_tensor(self.val_loss).mean().cuda(),
344
+ "val/psds1": psds1[0],
345
+ "val/psds1_macro_averaged": np.array([v for k, v in psds1[1].items()]).mean(),
346
+ "val/pauroc": segment_based_pauroc[0]['mean'],
347
+ }
348
+
349
+ self.log_dict(logs, sync_dist=False)
350
+ self.val_predictions_strong = {}
351
+ self.val_ground_truth = {}
352
+ self.val_duration = {}
353
+ self.val_loss = []
354
+
355
+
356
+ def train(config):
357
+ # Train Models on temporally-strong portion of AudioSet.
358
+
359
+ # logging is done using wandb
360
+ wandb_logger = WandbLogger(
361
+ project="PTSED",
362
+ notes="Pre-Training Transformers for Sound Event Detection on AudioSet Strong.",
363
+ tags=["AudioSet Strong", "Sound Event Detection", "Pseudo Labels", "Knowledge Disitillation"],
364
+ config=config,
365
+ name=config.experiment_name
366
+ )
367
+
368
+ # encoder manages encoding and decoding of model predictions
369
+ encoder = ManyHotEncoder(as_strong_train_classes)
370
+
371
+ train_set = get_training_dataset(encoder, wavmix_p=config.wavmix_p,
372
+ pseudo_labels_file=config.pseudo_labels_file)
373
+ eval_set = get_eval_dataset(encoder)
374
+
375
+ if config.use_balanced_sampler:
376
+ sample_weights = get_temporal_count_balanced_sample_weights(train_set, save_folder="resources")
377
+ else:
378
+ sample_weights = get_uniform_sample_weights(train_set)
379
+
380
+ train_sampler = get_weighted_sampler(sample_weights, epoch_len=config.epoch_len)
381
+
382
+ # train dataloader
383
+ train_dl = DataLoader(dataset=train_set,
384
+ sampler=train_sampler,
385
+ worker_init_fn=worker_init_fn,
386
+ num_workers=config.num_workers,
387
+ batch_size=config.batch_size,
388
+ shuffle=False)
389
+
390
+ # eval dataloader
391
+ eval_dl = DataLoader(dataset=eval_set,
392
+ worker_init_fn=worker_init_fn,
393
+ num_workers=config.num_workers,
394
+ batch_size=config.batch_size)
395
+
396
+ # create the PyTorch Lightning module
397
+ pl_module = PLModule(config, encoder)
398
+
399
+ # create the PyTorch Lightning trainer by specifying the number of epochs to train, the logger,
400
+ # on which kind of device(s) to train and possible callbacks
401
+ trainer = pl.Trainer(max_epochs=config.n_epochs,
402
+ logger=wandb_logger,
403
+ accelerator='auto',
404
+ devices=config.num_devices,
405
+ precision=config.precision,
406
+ num_sanity_val_steps=0,
407
+ check_val_every_n_epoch=config.check_val_every_n_epoch
408
+ )
409
+
410
+ # start training and validation for the specified number of epochs
411
+ trainer.fit(pl_module, train_dl, eval_dl)
412
+
413
+ wandb.finish()
414
+
415
+
416
+ def evaluate(config):
417
+ # only evaluation of pre-trained models
418
+ # encoder manages encoding and decoding of model predictions
419
+ encoder = ManyHotEncoder(as_strong_train_classes)
420
+ eval_set = get_eval_dataset(encoder)
421
+
422
+ # eval dataloader
423
+ eval_dl = DataLoader(dataset=eval_set,
424
+ worker_init_fn=worker_init_fn,
425
+ num_workers=config.num_workers,
426
+ batch_size=config.batch_size)
427
+
428
+ # create PyTorch Lightning module
429
+ pl_module = PLModule(config, encoder)
430
+
431
+ # create the PyTorch Lightning trainer by specifying the number of epochs,
432
+ # on which kind of device(s) to train and possible callbacks
433
+ trainer = pl.Trainer(max_epochs=config.n_epochs,
434
+ accelerator='auto',
435
+ devices=config.num_devices,
436
+ precision=config.precision,
437
+ num_sanity_val_steps=0,
438
+ check_val_every_n_epoch=config.check_val_every_n_epoch)
439
+
440
+ # start evaluation
441
+ trainer.validate(pl_module, eval_dl)
442
+
443
+
444
+ if __name__ == '__main__':
445
+ parser = argparse.ArgumentParser(description='Configuration Parser. ')
446
+
447
+ # general
448
+ parser.add_argument('--experiment_name', type=str, default="AudioSet_Strong")
449
+ parser.add_argument('--batch_size', type=int, default=256)
450
+ parser.add_argument('--num_workers', type=int, default=16)
451
+ parser.add_argument('--num_devices', type=int, default=1)
452
+ parser.add_argument('--precision', type=int, default=16)
453
+ parser.add_argument('--evaluate', action='store_true', default=False)
454
+ parser.add_argument('--check_val_every_n_epoch', type=int, default=5)
455
+
456
+ # model
457
+ parser.add_argument('--model_name', type=str,
458
+ choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"] + \
459
+ [f"frame_mn{width}" for width in ["06", "10"]],
460
+ default="ATST-F") # used also for training
461
+ # "scratch" = no pretraining
462
+ # "ssl" = SSL pre-trained
463
+ # "weak" = AudioSet Weak pre-trained
464
+ # "strong" = AudioSet Strong pre-trained
465
+ parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
466
+ default="weak")
467
+ parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
468
+ default=None)
469
+
470
+ # training
471
+ parser.add_argument('--n_epochs', type=int, default=30)
472
+ parser.add_argument('--use_balanced_sampler', action='store_true', default=False)
473
+ parser.add_argument('--distillation_loss_weight', type=float, default=0.0)
474
+ parser.add_argument('--epoch_len', type=int, default=100000)
475
+ parser.add_argument('--median_window', type=int, default=9)
476
+
477
+ # augmentation
478
+ parser.add_argument('--wavmix_p', type=float, default=0.8)
479
+ parser.add_argument('--freq_warp_p', type=float, default=0.8)
480
+ parser.add_argument('--filter_augment_p', type=float, default=0.8)
481
+ parser.add_argument('--frame_shift_range', type=float, default=0.125) # in seconds
482
+ parser.add_argument('--mixup_p', type=float, default=0.3)
483
+ parser.add_argument('--mixstyle_p', type=float, default=0.3)
484
+ parser.add_argument('--max_time_mask_size', type=float, default=0.0)
485
+
486
+ # optimizer
487
+ parser.add_argument('--adamw', action='store_true', default=False)
488
+ parser.add_argument('--weight_decay', type=float, default=0.0)
489
+
490
+ # lr schedule
491
+ parser.add_argument('--schedule_mode', type=str, default="cos")
492
+ parser.add_argument('--max_lr', type=float, default=7e-5)
493
+ parser.add_argument('--lr_end', type=float, default=2e-7)
494
+ parser.add_argument('--warmup_steps', type=int, default=5000)
495
+
496
+ # knowledge distillation
497
+ parser.add_argument('--pseudo_labels_file', type=str,
498
+ default=None)
499
+
500
+ args = parser.parse_args()
501
+ if args.evaluate:
502
+ evaluate(args)
503
+ else:
504
+ train(args)
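For orientation, a hedged sketch of how ex_audioset_strong.py might be invoked; the flag values below are illustrative, not recommendations:

# Train on AudioSet Strong (values illustrative):
#   python ex_audioset_strong.py --model_name ATST-F --pretrained weak --batch_size 256
# Evaluate a pre-trained model without training:
#   python ex_audioset_strong.py --evaluate --model_name BEATs --pretrained strong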
ex_dcase2016task2.py ADDED
@@ -0,0 +1,517 @@
1
+ import argparse
2
+ import random
3
+ from pathlib import Path
4
+ from typing import Dict
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.nn as nn
9
+ import transformers
10
+ from einops import rearrange
11
+ from pytorch_lightning.loggers import WandbLogger
12
+ from torch.utils.data import DataLoader
13
+
14
+ import wandb
15
+ from data_util.dcase2016task2 import (get_training_dataset, get_validation_dataset, get_test_dataset,
16
+ label_vocab_nlabels, label_vocab_as_dict)
17
+ from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
18
+ from helpers.score import get_events_for_all_files, combine_target_events, EventBasedScore, SegmentBasedScore
19
+ from helpers.utils import worker_init_fn
20
+ from models.asit.ASIT_wrapper import ASiTWrapper
21
+ from models.atstframe.ATSTF_wrapper import ATSTWrapper
22
+ from models.beats.BEATs_wrapper import BEATsWrapper
23
+ from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
24
+ from models.m2d.M2D_wrapper import M2DWrapper
25
+ from models.prediction_wrapper import PredictionsWrapper
26
+
27
+
28
+ class PLModule(pl.LightningModule):
29
+ def __init__(self, config):
30
+ super().__init__()
31
+ self.config = config
32
+
33
+ if config.pretrained == "scratch":
34
+ checkpoint = None
35
+ elif config.pretrained == "ssl":
36
+ checkpoint = "ssl"
37
+ elif config.pretrained == "weak":
38
+ checkpoint = "weak"
39
+ elif config.pretrained == "strong":
40
+ checkpoint = "strong_1"
41
+ else:
42
+ raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")
43
+
44
+ # load transformer model
45
+ if config.model_name == "BEATs":
46
+ beats = BEATsWrapper()
47
+ model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
48
+ seq_model_type=config.seq_model_type,
49
+ n_classes_strong=self.config.n_classes)
50
+ elif config.model_name == "ATST-F":
51
+ atst = ATSTWrapper()
52
+ model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
53
+ seq_model_type=config.seq_model_type,
54
+ n_classes_strong=self.config.n_classes)
55
+ elif config.model_name == "fpasst":
56
+ fpasst = FPaSSTWrapper()
57
+ model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
58
+ seq_model_type=config.seq_model_type,
59
+ n_classes_strong=self.config.n_classes)
60
+ elif config.model_name == "M2D":
61
+ m2d = M2DWrapper()
62
+ model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
63
+ seq_model_type=config.seq_model_type,
64
+ n_classes_strong=self.config.n_classes,
65
+ embed_dim=m2d.m2d.cfg.feature_d)
66
+ elif config.model_name == "ASIT":
67
+ asit = ASiTWrapper()
68
+ model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
69
+ seq_model_type=config.seq_model_type,
70
+ n_classes_strong=self.config.n_classes)
71
+ else:
72
+ raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
73
+
74
+ self.model = model
75
+ self.strong_loss = nn.BCEWithLogitsLoss()
76
+
77
+ self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))
78
+
79
+ task_path = Path(self.config.task_path)
80
+ label_vocab, nlabels = label_vocab_nlabels(task_path)
81
+ self.label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
82
+
83
+ self.idx_to_label: Dict[int, str] = {
84
+ idx: label for (label, idx) in self.label_to_idx.items()
85
+ }
86
+
87
+ self.event_onset_200ms_fms = EventBasedScore(
88
+ label_to_idx=self.label_to_idx,
89
+ name="event_onset_200ms_fms",
90
+ scores=("f_measure", "precision", "recall"),
91
+ params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.2}
92
+ )
93
+
94
+ self.event_onset_50ms_fms = EventBasedScore(
95
+ label_to_idx=self.label_to_idx,
96
+ name="event_onset_50ms_fms",
97
+ scores=("f_measure", "precision", "recall"),
98
+ params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.05}
99
+ )
100
+
101
+ self.segment_1s_er = SegmentBasedScore(
102
+ label_to_idx=self.label_to_idx,
103
+ name="segment_1s_er",
104
+ scores=("error_rate",),
105
+ params={"time_resolution": 1.0},
106
+ maximize=False,
107
+ )
108
+
109
+ self.postprocessing_grid = {
110
+ "median_filter_ms": [
111
+ 250
112
+ ],
113
+ "min_duration": [
114
+ 125
115
+ ]
116
+ }
117
+
118
+ self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []
119
+
120
+ def forward(self, audio):
121
+ mel = self.model.mel_forward(audio)
122
+ y_strong, _ = self.model(mel)
123
+ return y_strong
124
+
125
+ def separate_params(self):
126
+ pt_params = []
127
+ seq_params = []
128
+ head_params = []
129
+
130
+ for name, p in self.named_parameters():
131
+ name = name[len("model."):]
132
+ if name.startswith('model'):
133
+ # the transformer
134
+ pt_params.append(p)
135
+ elif name.startswith('seq_model'):
136
+ # the optional sequence model
137
+ seq_params.append(p)
138
+ elif name.startswith('strong_head') or name.startswith('weak_head'):
139
+ # the prediction head
140
+ head_params.append(p)
141
+ else:
142
+ raise ValueError(f"Unexpected key in model: {name}")
143
+
144
+ if self.model.has_separate_params():
145
+ # split parameters into groups according to their depth in the network
146
+ # based on this, we can apply layer-wise learning rate decay
147
+ pt_params = self.model.separate_params()
148
+ else:
149
+ if self.config.lr_decay != 1.0:
150
+ raise ValueError(f"Model has no separate_params function. Can't apply layer-wise lr decay, but "
151
+ f"learning rate decay is set to {self.config.lr_decay}.")
152
+
153
+ return pt_params, seq_params, head_params
154
+
155
+ def get_optimizer(
156
+ self,
157
+ lr,
158
+ lr_decay=1.0,
159
+ transformer_lr=None,
160
+ transformer_frozen=False,
161
+ adamw=False,
162
+ weight_decay=0.01,
163
+ betas=(0.9, 0.999)
164
+ ):
165
+ pt_params, seq_params, head_params = self.separate_params()
166
+
167
+ param_groups = [
168
+ {'params': head_params, 'lr': lr}, # model head (besides base model and seq model)
169
+ ]
170
+
171
+ if transformer_frozen:
172
+ for p in pt_params + seq_params:
173
+ if isinstance(p, list):
174
+ for p_i in p:
175
+ p_i.detach_()
176
+ else:
177
+ p.detach_()
178
+ else:
179
+ if transformer_lr is None:
180
+ transformer_lr = lr
181
+ if isinstance(pt_params, list) and isinstance(pt_params[0], list):
182
+ # apply lr decay
183
+ scale_lrs = [transformer_lr * (lr_decay ** i) for i in range(1, len(pt_params) + 1)]
184
+ param_groups = param_groups + [{"params": pt_params[i], "lr": scale_lrs[i]} for i in
185
+ range(len(pt_params))]
186
+ else:
187
+ param_groups.append(
188
+ {'params': pt_params, 'lr': transformer_lr}, # pretrained model
189
+ )
190
+ param_groups.append(
191
+ {'params': seq_params, 'lr': lr}, # optional sequence model
192
+ )
193
+
194
+ # do not apply weight decay to biases and batch norms
195
+ param_groups_split = []
196
+ for param_group in param_groups:
197
+ params_1D, params_2D = [], []
198
+ lr = param_group['lr']
199
+ for param in param_group['params']:
200
+ if param.ndimension() >= 2:
201
+ params_2D.append(param)
202
+ elif param.ndimension() <= 1:
203
+ params_1D.append(param)
204
+ param_groups_split += [{'params': params_2D, 'lr': lr, 'weight_decay': weight_decay},
205
+ {'params': params_1D, 'lr': lr}]
206
+ if weight_decay > 0:
207
+ assert adamw
208
+ if adamw:
209
+ print(f"\nUsing adamw weight_decay={weight_decay}!\n")
210
+ return torch.optim.AdamW(param_groups_split, lr=lr, weight_decay=weight_decay, betas=betas)
211
+ return torch.optim.Adam(param_groups_split, lr=lr, betas=betas)
212
+
213
+ def get_lr_scheduler(
214
+ self,
215
+ optimizer,
216
+ num_training_steps,
217
+ schedule_mode="cos",
218
+ gamma: float = 0.999996,
219
+ num_warmup_steps=4000,
220
+ lr_end=1e-7,
221
+ ):
222
+ if schedule_mode in {"exp"}:
223
+ return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
224
+ if schedule_mode in {"cosine", "cos"}:
225
+ return transformers.get_cosine_schedule_with_warmup(
226
+ optimizer,
227
+ num_warmup_steps=num_warmup_steps,
228
+ num_training_steps=num_training_steps,
229
+ )
230
+ if schedule_mode in {"linear"}:
231
+ print("Linear schedule!")
232
+ return transformers.get_polynomial_decay_schedule_with_warmup(
233
+ optimizer,
234
+ num_warmup_steps=num_warmup_steps,
235
+ num_training_steps=num_training_steps,
236
+ power=1.0,
237
+ lr_end=lr_end,
238
+ )
239
+ raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
240
+
241
+ def configure_optimizers(self):
242
+ """
243
+ This is the way PyTorch Lightning requires optimizers and learning rate schedulers to be defined.
244
+ The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
245
+ :return: dict containing optimizer and learning rate scheduler
246
+ """
247
+ optimizer = self.get_optimizer(self.config.max_lr,
248
+ lr_decay=self.config.lr_decay,
249
+ transformer_lr=self.config.transformer_lr,
250
+ transformer_frozen=self.config.transformer_frozen,
251
+ adamw=not self.config.no_adamw,
252
+ weight_decay=self.config.weight_decay)
253
+
254
+ num_training_steps = self.trainer.estimated_stepping_batches
255
+
256
+ scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
257
+ schedule_mode=self.config.schedule_mode,
258
+ lr_end=self.config.lr_end)
259
+ lr_scheduler_config = {
260
+ "scheduler": scheduler,
261
+ "interval": "step",
262
+ "frequency": 1
263
+ }
264
+ return [optimizer], [lr_scheduler_config]
265
+
266
+ def training_step(self, train_batch, batch_idx):
267
+ """
268
+ :param train_batch: contains one batch from train dataloader
269
+ :param batch_idx: index of the current batch within the epoch
270
+ :return: a dict containing at least loss that is used to update model parameters, can also contain
271
+ other items that can be processed in 'training_epoch_end' to log other metrics than loss
272
+ """
273
+
274
+ audios, labels, fnames, timestamps = train_batch
275
+
276
+ if self.config.transformer_frozen:
277
+ self.model.model.eval()
278
+ self.model.seq_model.eval()
279
+ mel = self.model.mel_forward(audios)
280
+
281
+ # time rolling
282
+ if self.config.frame_shift_range > 0:
283
+ mel, labels = frame_shift(
284
+ mel,
285
+ labels,
286
+ shift_range=self.config.frame_shift_range
287
+ )
288
+
289
+ # mixup
290
+ if self.config.mixup_p > random.random():
291
+ mel, labels = mixup(
292
+ mel,
293
+ targets=labels
294
+ )
295
+
296
+ # mixstyle
297
+ if self.config.mixstyle_p > random.random():
298
+ mel = mixstyle(
299
+ mel
300
+ )
301
+
302
+ # time masking
303
+ if self.config.max_time_mask_size > 0:
304
+ mel, labels = time_mask(
305
+ mel,
306
+ labels,
307
+ max_mask_ratio=self.config.max_time_mask_size
308
+ )
309
+
310
+ # frequency masking
311
+ if self.config.filter_augment_p > random.random():
312
+ mel, _ = filter_augmentation(
313
+ mel
314
+ )
315
+
316
+ # frequency warping
317
+ if self.config.freq_warp_p > random.random():
318
+ mel = mel.squeeze(1)
319
+ mel = self.freq_warp(mel)
320
+ mel = mel.unsqueeze(1)
321
+
322
+ # forward through network; use strong head
323
+ y_hat_strong, _ = self.model(mel)
324
+
325
+ loss = self.strong_loss(y_hat_strong, labels)
326
+
327
+ # logging
328
+ self.log('epoch', self.current_epoch)
329
+ for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
330
+ self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
331
+ self.log("train/loss", loss.detach().cpu(), prog_bar=True)
332
+
333
+ return loss
334
+
335
+ def _score_step(self, batch):
336
+ audios, labels, fnames, timestamps = batch
337
+
338
+ strong_preds = self.forward(audios)
339
+
340
+ self.preds.append(strong_preds)
341
+ self.tgts.append(labels)
342
+ self.fnames.append(fnames)
343
+ self.timestamps.append(timestamps)
344
+
345
+ def _score_epoch_end(self, name="val"):
346
+ preds = torch.cat(self.preds)
347
+ tgts = torch.cat(self.tgts)
348
+ fnames = [item for sublist in self.fnames for item in sublist]
349
+ timestamps = torch.cat(self.timestamps)
350
+ val_loss = self.strong_loss(preds, tgts)
351
+ self.log(f"{name}/loss", val_loss, prog_bar=True)
352
+
353
+ # the following function expects one prediction per timestamp (sequence dimension must be flattened)
354
+ seq_len = preds.size(-1)
355
+ preds = rearrange(preds, 'bs c t -> (bs t) c').float()
356
+ timestamps = rearrange(timestamps, 'bs t -> (bs t)').float()
357
+ fnames = [fname for fname in fnames for _ in range(seq_len)]
358
+
359
+ predicted_events_by_postprocessing = get_events_for_all_files(
360
+ preds,
361
+ fnames,
362
+ timestamps,
363
+ self.idx_to_label,
364
+ self.postprocessing_grid
365
+ )
366
+
367
+ # we only have one postprocessing configuration (aligned with the HEAR challenge)
368
+ key = list(predicted_events_by_postprocessing.keys())[0]
369
+ predicted_events = predicted_events_by_postprocessing[key]
370
+
371
+ # load ground truth for the evaluated split
372
+ task_path = Path(self.config.task_path)
373
+ test_target_events = combine_target_events(["valid" if name == "val" else "test"], task_path)
374
+ onset_fms = self.event_onset_200ms_fms(predicted_events, test_target_events)
375
+ onset_fms_50 = self.event_onset_50ms_fms(predicted_events, test_target_events)
376
+ segment_1s_er = self.segment_1s_er(predicted_events, test_target_events)
377
+
378
+ self.log(f"{name}/onset_fms", onset_fms[0][1])
379
+ self.log(f"{name}/onset_fms_50", onset_fms_50[0][1])
380
+ self.log(f"{name}/segment_1s_er", segment_1s_er[0][1])
381
+
382
+ # free buffers
383
+ self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []
384
+
385
+ def validation_step(self, batch, batch_idx):
386
+ self._score_step(batch)
387
+
388
+ def on_validation_epoch_end(self):
389
+ self._score_epoch_end(name="val")
390
+
391
+ def test_step(self, batch, batch_idx):
392
+ self._score_step(batch)
393
+
394
+ def on_test_epoch_end(self):
395
+ self._score_epoch_end(name="test")
396
+
397
+
398
+ def train(config):
399
+ # Example for fine-tuning pre-trained transformers on a downstream task.
400
+
401
+ # logging is done using wandb
402
+ wandb_logger = WandbLogger(
403
+ project="PTSED",
404
+ notes="Downstream Training on office sound event detection.",
405
+ tags=["DCASE 2016 Task 2", "Sound Event Detection"],
406
+ config=config,
407
+ name=config.experiment_name
408
+ )
409
+
410
+ train_set = get_training_dataset(config.task_path, wavmix_p=config.wavmix_p)
411
+ val_ds = get_validation_dataset(config.task_path)
412
+ test_ds = get_test_dataset(config.task_path)
413
+
414
+ # train dataloader
415
+ train_dl = DataLoader(dataset=train_set,
416
+ worker_init_fn=worker_init_fn,
417
+ num_workers=config.num_workers,
418
+ batch_size=config.batch_size,
419
+ shuffle=True)
420
+
421
+ # validation dataloader
422
+ valid_dl = DataLoader(dataset=val_ds,
423
+ worker_init_fn=worker_init_fn,
424
+ num_workers=config.num_workers,
425
+ batch_size=config.batch_size,
426
+ shuffle=False,
427
+ drop_last=False)
428
+
429
+ # test dataloader
430
+ test_dl = DataLoader(dataset=test_ds,
431
+ worker_init_fn=worker_init_fn,
432
+ num_workers=config.num_workers,
433
+ batch_size=config.batch_size,
434
+ shuffle=False,
435
+ drop_last=False)
436
+
437
+ # create PyTorch Lightning module
438
+ pl_module = PLModule(config)
439
+
440
+ # create the PyTorch Lightning trainer by specifying the number of epochs to train, the logger,
441
+ # on which kind of device(s) to train and possible callbacks
442
+ trainer = pl.Trainer(max_epochs=config.n_epochs,
443
+ logger=wandb_logger,
444
+ accelerator='auto',
445
+ devices=config.num_devices,
446
+ precision=config.precision,
447
+ num_sanity_val_steps=0,
448
+ check_val_every_n_epoch=config.check_val_every_n_epoch
449
+ )
450
+
451
+ # start training and validation for the specified number of epochs
452
+ trainer.fit(
453
+ pl_module,
454
+ train_dataloaders=train_dl,
455
+ val_dataloaders=valid_dl,
456
+ )
457
+
458
+ test_results = trainer.test(pl_module, dataloaders=test_dl)
459
+ print(test_results)
460
+ wandb.finish()
461
+
462
+
463
+ if __name__ == '__main__':
464
+ parser = argparse.ArgumentParser(description='Configuration Parser. ')
465
+
466
+ # general
467
+ parser.add_argument('--task_path', type=str, required=True)
468
+ parser.add_argument('--experiment_name', type=str, default="DCASE2016Task2")
469
+ parser.add_argument('--batch_size', type=int, default=256)
470
+ parser.add_argument('--num_workers', type=int, default=16)
471
+ parser.add_argument('--num_devices', type=int, default=1)
472
+ parser.add_argument('--precision', type=int, default=16)
473
+ parser.add_argument('--check_val_every_n_epoch', type=int, default=10)
474
+
475
+ # model
476
+ parser.add_argument('--model_name', type=str,
477
+ choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"],
478
+ default="ATST-F") # used also for training
479
+ # "scratch" = no pretraining
480
+ # "ssl" = SSL pre-trained
481
+ # "weak" = AudioSet Weak pre-trained
482
+ # "strong" = AudioSet Strong pre-trained
483
+ parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
484
+ default="strong")
485
+ parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
486
+ default=None)
487
+ parser.add_argument('--n_classes', type=int, default=11)
488
+
489
+ # training
490
+ parser.add_argument('--n_epochs', type=int, default=300)
491
+
492
+ # augmentation
493
+ parser.add_argument('--wavmix_p', type=float, default=0.5)
494
+ parser.add_argument('--freq_warp_p', type=float, default=0.0)
495
+ parser.add_argument('--filter_augment_p', type=float, default=0.0)
496
+ parser.add_argument('--frame_shift_range', type=float, default=0.0) # in seconds
497
+ parser.add_argument('--mixup_p', type=float, default=0.5)
498
+ parser.add_argument('--mixstyle_p', type=float, default=0.0)
499
+ parser.add_argument('--max_time_mask_size', type=float, default=0.0)
500
+
501
+ # optimizer
502
+ parser.add_argument('--no_adamw', action='store_true', default=False)
503
+ parser.add_argument('--weight_decay', type=float, default=0.001)
504
+ parser.add_argument('--transformer_frozen', action='store_true', dest='transformer_frozen',
505
+ default=False,
506
+ help='Disable training for the transformer.')
507
+
508
+ # lr schedule
509
+ parser.add_argument('--schedule_mode', type=str, default="cos")
510
+ parser.add_argument('--max_lr', type=float, default=1.06e-4)
511
+ parser.add_argument('--transformer_lr', type=float, default=None)
512
+ parser.add_argument('--lr_decay', type=float, default=1.0)
513
+ parser.add_argument('--lr_end', type=float, default=1e-7)
514
+ parser.add_argument('--warmup_steps', type=int, default=100)
515
+
516
+ args = parser.parse_args()
517
+ train(args)
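Analogously, a hedged sketch of a possible invocation of ex_dcase2016task2.py; --task_path is required and the path below is hypothetical:

# python ex_dcase2016task2.py --task_path /data/hear/dcase2016_task2 \
#     --model_name ATST-F --pretrained strong --batch_size 32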
helpers/augment.py ADDED
@@ -0,0 +1,225 @@
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.distributions.beta import Beta
7
+
8
+ def frame_shift(mels, labels, embeddings=None, pseudo_labels=None,
9
+ net_pooling=4, shift_range=0.125):
10
+ bsz, channels, n_bands, frames = mels.shape
11
+ abs_shift_mel = int(frames * shift_range)
12
+
13
+ if embeddings is not None:
14
+ embed_frames = embeddings.shape[-1]
15
+ embed_pool_fact = frames / embed_frames
16
+
17
+ for bindx in range(bsz):
18
+ shift = int(random.gauss(0, abs_shift_mel))
19
+ mels[bindx] = torch.roll(mels[bindx], shift, dims=-1)
20
+ label_shift = shift / net_pooling
21
+ label_shift = round(label_shift)
22
+ labels[bindx] = torch.roll(labels[bindx], label_shift, dims=-1)
23
+
24
+ if pseudo_labels is not None:
25
+ pseudo_labels[bindx] = torch.roll(pseudo_labels[bindx], label_shift, dims=-1)
26
+
27
+ if embeddings is not None:
28
+ embed_shift = shift / embed_pool_fact
29
+ embed_shift = round(embed_shift)
30
+ embeddings[bindx] = torch.roll(embeddings[bindx], embed_shift, dims=-1)
31
+
32
+ out_args = [mels]
33
+ if embeddings is not None:
34
+ out_args.append(embeddings)
35
+ out_args.append(labels)
36
+ if pseudo_labels is not None:
37
+ out_args.append(pseudo_labels)
38
+ return tuple(out_args)
39
+
40
+
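A minimal, runnable sketch of frame_shift with illustrative shapes (1000 mel frames pooled by a factor of 4 to 250 label frames):

import torch

mels = torch.randn(4, 1, 128, 1000)   # (batch, channels, mel bins, frames)
labels = torch.zeros(4, 10, 250)      # (batch, classes, pooled label frames)
mels, labels = frame_shift(mels, labels, net_pooling=4, shift_range=0.125)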
41
+ def time_mask(features, labels, embeddings=None, pseudo_labels=None, net_pooling=4,
42
+ min_mask_ratio=0.05, max_mask_ratio=0.2):
43
+ _, _, n_frame = labels.shape
44
+
45
+ if embeddings is not None:
46
+ embed_frames = embeddings.shape[-1]
47
+ embed_pool_fact = embed_frames / n_frame
48
+
49
+ t_width = torch.randint(low=int(n_frame * min_mask_ratio), high=int(n_frame * max_mask_ratio), size=(1,))
50
+ t_low = torch.randint(low=0, high=n_frame-t_width[0], size=(1,))
51
+ features[:, :, :, t_low * net_pooling:(t_low+t_width)*net_pooling] = 0
52
+ labels[:, :, t_low:t_low+t_width] = 0
53
+
54
+ if pseudo_labels is not None:
55
+ pseudo_labels[:, :, t_low:t_low + t_width] = 0
56
+
57
+ if embeddings is not None:
58
+ low = round((t_low * embed_pool_fact).item())
59
+ high = round(((t_low + t_width) * embed_pool_fact).item())
60
+ embeddings[..., low:high] = 0
61
+
62
+ out_args = [features]
63
+
64
+ if embeddings is not None:
65
+ out_args.append(embeddings)
66
+ out_args.append(labels)
67
+ if pseudo_labels is not None:
68
+ out_args.append(pseudo_labels)
69
+ return tuple(out_args)
70
+
71
+
72
+ def mixup(data, embeddings=None, targets=None, pseudo_strong=None, alpha=0.2, beta=0.2, return_mix_coef=False):
73
+ with torch.no_grad():
74
+ batch_size = data.size(0)
75
+ c = np.random.beta(alpha, beta, size=batch_size)
76
+ c = np.maximum(c, 1 - c)
77
+
78
+ perm = torch.randperm(batch_size)
79
+ cd = torch.tensor(c, dtype=data.dtype, device=data.device).view(batch_size, *([1] * (data.ndim - 1)))
80
+ mixed_data = cd * data + (1 - cd) * data[perm, :]
81
+
82
+ if embeddings is not None:
83
+ ce = torch.tensor(c, dtype=embeddings.dtype, device=embeddings.device).view(batch_size, *([1] * (embeddings.ndim - 1)))
84
+ mixed_embeddings = ce * embeddings + (1 - ce) * embeddings[perm, :]
85
+
86
+ if targets is not None:
87
+ ct = torch.tensor(c, dtype=data.dtype, device=data.device).view(batch_size, *([1] * (targets.ndim - 1)))
88
+ mixed_target = torch.clamp(
89
+ ct * targets + (1 - ct) * targets[perm, :], min=0, max=1
90
+ )
91
+
92
+ if pseudo_strong is not None:
93
+ cp = torch.tensor(c, dtype=pseudo_strong.dtype, device=pseudo_strong.device).view(batch_size,
94
+ *([1] * (pseudo_strong.ndim - 1)))
95
+ mixed_pseudo_strong = cp * pseudo_strong + (1 - cp) * pseudo_strong[perm, :]
96
+
97
+ out_args = [mixed_data]
98
+ if embeddings is not None:
99
+ out_args.append(mixed_embeddings)
100
+ if targets is not None:
101
+ out_args.append(mixed_target)
102
+ if pseudo_strong is not None:
103
+ out_args.append(mixed_pseudo_strong)
104
+
105
+ if return_mix_coef:
106
+ out_args.append(perm)
107
+ out_args.append(c)
108
+ return tuple(out_args)
109
+
110
+
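A hedged usage sketch for mixup; the shapes are illustrative and the targets are random placeholders:

import torch

mels = torch.randn(8, 1, 128, 250)                   # batch of log-mel spectrograms
labels = torch.randint(0, 2, (8, 10, 250)).float()   # frame-level multi-hot targets
mixed_mels, mixed_labels = mixup(mels, targets=labels)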
111
+ def filt_aug_(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6):
112
+ batch_size, channels, n_freq_bin, _ = features.shape
113
+ n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item() # [low, high)
114
+ if n_freq_band > 1:
115
+ while n_freq_bin - n_freq_band * min_bw + 1 < 0:
116
+ min_bw -= 1
117
+ band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * min_bw + 1,
118
+ (n_freq_band - 1,)))[0] + \
119
+ torch.arange(1, n_freq_band) * min_bw
120
+ band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs, torch.tensor([n_freq_bin])))
121
+
122
+ band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * (db_range[1] - db_range[0]) + db_range[0]
123
+ freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
124
+ for i in range(n_freq_band):
125
+ for j in range(batch_size):
126
+ freq_filt[j, band_bndry_freqs[i]:band_bndry_freqs[i+1], :] = \
127
+ torch.linspace(band_factors[j, i], band_factors[j, i+1],
128
+ band_bndry_freqs[i+1] - band_bndry_freqs[i]).unsqueeze(-1)
129
+ freq_filt = 10 ** (freq_filt / 20)
130
+ return features * freq_filt.unsqueeze(1)
131
+ else:
132
+ return features
133
+
134
+
135
+ def filter_augmentation(features, n_transform=1, filter_db_range=(-6, 6),
136
+ filter_bands=(3, 6), filter_minimum_bandwidth=6):
137
+ if n_transform == 2:
138
+ feature_list = []
139
+ for _ in range(n_transform):
140
+ features_temp = features
141
+ features_temp = filt_aug_(features_temp, db_range=filter_db_range, n_band=filter_bands,
142
+ min_bw=filter_minimum_bandwidth)
143
+ feature_list.append(features_temp)
144
+ return feature_list
145
+ elif n_transform == 1:
146
+ features = filt_aug_(features, db_range=filter_db_range, n_band=filter_bands,
147
+ min_bw=filter_minimum_bandwidth)
148
+ return [features, features]
149
+ else:
150
+ return [features, features]
151
+
152
+
153
+ def mixstyle(x, alpha=0.4, eps=1e-6):
154
+ batch_size = x.size(0)
155
+
156
+ # frequency-wise statistics
157
+ f_mu = x.mean(dim=3, keepdim=True)
158
+ f_var = x.var(dim=3, keepdim=True)
159
+
160
+ f_sig = (f_var + eps).sqrt() # compute instance standard deviation
161
+ f_mu, f_sig = f_mu.detach(), f_sig.detach() # block gradients
162
+ x_normed = (x - f_mu) / f_sig # normalize input
163
+ lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device, dtype=x.dtype) # sample instance-wise convex weights
164
+ lmda = torch.max(lmda, 1-lmda)
165
+ perm = torch.randperm(batch_size).to(x.device) # generate shuffling indices
166
+ f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm] # shuffling
167
+ mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda) # generate mixed mean
168
+ sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda) # generate mixed standard deviation
169
+ x = x_normed * sig_mix + mu_mix # denormalize input using the mixed statistics
170
+ return x
171
+
172
+
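mixstyle keeps the input shape and only mixes per-frequency-bin statistics across the batch; a minimal sketch with an illustrative shape:

import torch

x = torch.randn(8, 1, 128, 250)   # (batch, channels, mel bins, frames)
x_aug = mixstyle(x, alpha=0.4)    # same shape, statistics mixed across examples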
173
+ class RandomResizeCrop(nn.Module):
174
+ """Random Resize Crop block.
175
+
176
+ Args:
177
+ virtual_crop_scale: Virtual crop area `(F ratio, T ratio)` in ratio to input size.
178
+ freq_scale: Random frequency range `(min, max)`.
179
+ time_scale: Random time frame range `(min, max)`.
180
+ """
181
+
182
+ def __init__(self, virtual_crop_scale=(1.0, 1.5), freq_scale=(0.6, 1.0), time_scale=(0.6, 1.5)):
183
+ super().__init__()
184
+ self.virtual_crop_scale = virtual_crop_scale
185
+ self.freq_scale = freq_scale
186
+ self.time_scale = time_scale
187
+ self.interpolation = 'bicubic'
188
+ assert time_scale[1] >= 1.0 and freq_scale[1] >= 1.0
189
+
190
+ @staticmethod
191
+ def get_params(virtual_crop_size, in_size, time_scale, freq_scale):
192
+ canvas_h, canvas_w = virtual_crop_size
193
+ src_h, src_w = in_size
194
+ h = np.clip(int(np.random.uniform(*freq_scale) * src_h), 1, canvas_h)
195
+ w = np.clip(int(np.random.uniform(*time_scale) * src_w), 1, canvas_w)
196
+ i = random.randint(0, canvas_h - h) if canvas_h > h else 0
197
+ j = random.randint(0, canvas_w - w) if canvas_w > w else 0
198
+ return i, j, h, w
199
+
200
+ def forward(self, lms):
201
+ # create an empty virtual crop area and copy the input log-mel spectrogram to its center
205
+ virtual_crop_size = [int(s * c) for s, c in zip(lms.shape[-2:], self.virtual_crop_scale)]
206
+ virtual_crop_area = (torch.zeros((lms.shape[0], virtual_crop_size[0], virtual_crop_size[1]))
207
+ .to(torch.float).to(lms.device))
208
+ _, lh, lw = virtual_crop_area.shape
209
+ c, h, w = lms.shape
210
+ x, y = (lw - w) // 2, (lh - h) // 2
211
+ virtual_crop_area[:, y:y+h, x:x+w] = lms
212
+ # get random area
213
+ i, j, h, w = self.get_params(virtual_crop_area.shape[-2:], lms.shape[-2:], self.time_scale, self.freq_scale)
214
+ crop = virtual_crop_area[:, i:i+h, j:j+w]
216
+ lms = F.interpolate(crop.unsqueeze(1), size=lms.shape[-2:],
217
+ mode=self.interpolation, align_corners=True).squeeze(1)
218
+ return lms.float()
220
+
221
+ def __repr__(self):
222
+ format_string = self.__class__.__name__ + f'(virtual_crop_size={self.virtual_crop_scale}'
223
+ format_string += ', time_scale={0}'.format(tuple(round(s, 4) for s in self.time_scale))
224
+ format_string += ', freq_scale={0})'.format(tuple(round(r, 4) for r in self.freq_scale))
225
+ return format_string
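As configured in the training scripts above, RandomResizeCrop performs frequency warping on 3-D inputs (the channel dimension is squeezed out before the call); a hedged sketch with illustrative shapes:

import torch

freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))  # frequency warping only
mel = torch.randn(8, 128, 250)   # (batch, mel bins, frames)
mel = freq_warp(mel)             # same shape, bicubic resampling of a random crop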
helpers/decode.py ADDED
@@ -0,0 +1,72 @@
1
+ """
2
+ Code from:
3
+ https://github.com/DCASE-REPO/DESED_task
4
+ """
5
+
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import scipy.ndimage
11
+ from sed_scores_eval.base_modules.scores import create_score_dataframe
12
+
13
+
14
+ def batched_decode_preds(
15
+ strong_preds,
16
+ filenames,
17
+ encoder,
18
+ thresholds=[0.5],
19
+ median_filter=None,
20
+ pad_indx=None,
21
+ ):
22
+ """Decode a batch of predictions to dataframes. Each threshold gives a different dataframe and stored in a
23
+ dictionary
24
+
25
+ Args:
26
+ strong_preds: torch.Tensor, batch of strong predictions.
27
+ filenames: list, the list of filenames of the current batch.
28
+ encoder: ManyHotEncoder object, object used to decode predictions.
29
+ thresholds: list, the list of thresholds to be used for predictions.
30
+ median_filter: int, the number of frames for which to apply median window (smoothing).
31
+ pad_indx: list, the list of indexes which have been used for padding.
32
+
33
+ Returns:
34
+ dict of predictions, where each key is a threshold and the value is the DataFrame of predictions.
35
+ """
36
+ # Init a dataframe per threshold
37
+ scores_raw = {}
38
+ scores_postprocessed = {}
39
+ prediction_dfs = {}
40
+ for threshold in thresholds:
41
+ prediction_dfs[threshold] = pd.DataFrame()
42
+
43
+ for j in range(strong_preds.shape[0]): # over items in the batch
44
+ audio_id = Path(filenames[j]).stem
45
+ filename = audio_id + ".wav"
46
+ c_scores = strong_preds[j]
47
+ if pad_indx is not None:
48
+ true_len = int(c_scores.shape[-1] * pad_indx[j].item())
49
+ c_scores = c_scores[:true_len]
50
+ c_scores = c_scores.transpose(0, 1).detach().cpu().numpy()
51
+ scores_raw[audio_id] = create_score_dataframe(
52
+ scores=c_scores,
53
+ timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
54
+ event_classes=encoder.labels,
55
+ )
56
+ if median_filter is not None:
57
+ c_scores = scipy.ndimage.median_filter(c_scores, (median_filter, 1))
58
+ scores_postprocessed[audio_id] = create_score_dataframe(
59
+ scores=c_scores,
60
+ timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
61
+ event_classes=encoder.labels,
62
+ )
63
+ for c_th in thresholds:
64
+ pred = c_scores > c_th
65
+ pred = encoder.decode_strong(pred)
66
+ pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
67
+ pred["filename"] = filename
68
+ prediction_dfs[c_th] = pd.concat(
69
+ [prediction_dfs[c_th], pred], ignore_index=True
70
+ )
71
+
72
+ return scores_raw, scores_postprocessed, prediction_dfs
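A hedged, runnable sketch of batched_decode_preds; the class names and scores are made-up placeholders (ManyHotEncoder is defined in helpers/encode.py below):

import torch

strong_preds = torch.rand(2, 3, 250)          # (files, classes, frames) sigmoid scores
fnames = ["a.wav", "b.wav"]
encoder = ManyHotEncoder(["c1", "c2", "c3"])  # placeholder class names
scores_raw, scores_post, pred_dfs = batched_decode_preds(
    strong_preds, fnames, encoder, thresholds=[0.5], median_filter=9
)
events = pred_dfs[0.5]  # DataFrame with event_label, onset, offset, filename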
helpers/encode.py ADDED
@@ -0,0 +1,230 @@
1
+ """
2
+ Code from:
3
+ https://github.com/DCASE-REPO/DESED_task
4
+ """
5
+
6
+ from collections import OrderedDict
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from dcase_util.data import DecisionEncoder
11
+
12
+
13
+ class ManyHotEncoder:
14
+ """"
15
+ Adapted after DecisionEncoder.find_contiguous_regions method in
16
+ https://github.com/DCASE-REPO/dcase_util/blob/master/dcase_util/data/decisions.py
17
+
18
+ Encode labels into numpy arrays where 1 corresponds to the presence of a class and 0 to its absence.
19
+ Multiple 1s can appear in the same row, as this is a multi-label problem.
20
+ Args:
21
+ labels: list, the classes which will be encoded
22
+ n_frames: int, (Default value = None) only useful for strong labels. The number of frames of a segment.
23
+ Attributes:
24
+ labels: list, the classes which will be encoded
25
+ n_frames: int, only useful for strong labels. The number of frames of a segment.
26
+ """
27
+
28
+ def __init__(
29
+ self, labels, audio_len=10, frame_hop=160, net_pooling=4, fs=16000
30
+ ):
31
+ if isinstance(labels, np.ndarray):
32
+ labels = labels.tolist()
33
+ elif isinstance(labels, (dict, OrderedDict)):
34
+ labels = list(labels.keys())
35
+ self.labels = labels
36
+ self.audio_len = audio_len
37
+ self.frame_hop = frame_hop
38
+ self.fs = fs
39
+ self.net_pooling = net_pooling
40
+ n_frames = self.audio_len * self.fs
41
+ self.n_frames = int(int((n_frames / self.frame_hop)) / self.net_pooling)
42
+
43
+ def encode_weak(self, labels):
44
+ """ Encode a list of weak labels into a numpy array
45
+
46
+ Args:
47
+ labels: list, list of labels to encode (to a vector of 0 and 1)
48
+
49
+ Returns:
50
+ numpy.array
51
+ A vector containing 1 for each label, and 0 everywhere else
52
+ """
53
+ # "empty" is a sentinel string used for missing label sets
54
+ if type(labels) is str:
55
+ if labels == "empty":
56
+ y = np.zeros(len(self.labels)) - 1
57
+ return y
58
+ else:
59
+ labels = labels.split(",")
60
+ if type(labels) is pd.DataFrame:
61
+ if labels.empty:
62
+ labels = []
63
+ elif "event_label" in labels.columns:
64
+ labels = labels["event_label"]
65
+ y = np.zeros(len(self.labels))
66
+ for label in labels:
67
+ if not pd.isna(label):
68
+ i = self.labels.index(label)
69
+ y[i] = 1
70
+ return y
71
+
72
+ def _time_to_frame(self, time):
73
+ samples = time * self.fs
74
+ frame = samples / self.frame_hop
75
+ return np.clip(frame / self.net_pooling, a_min=0, a_max=self.n_frames)
76
+
77
+ def _frame_to_time(self, frame):
78
+ frame = frame * self.net_pooling / (self.fs / self.frame_hop)
79
+ return np.clip(frame, a_min=0, a_max=self.audio_len)
80
+
81
+ def encode_strong_df(self, label_df):
82
+ """Encode a list (or pandas Dataframe or Serie) of strong labels, they correspond to a given filename
83
+
84
+ Args:
85
+ label_df: pandas DataFrame or Series, contains filename, onset (in seconds) and offset (in seconds)
86
+ If only filename (no onset offset) is specified, it will return the event on all the frames
87
+ onset and offset should be in seconds
88
+ Returns:
89
+ numpy.array
90
+ Encoded labels, 1 where the label is present, 0 otherwise
91
+ """
92
+
93
+ assert all(
94
+ [x is not None for x in [self.audio_len, self.frame_hop]]
95
+ )
96
+
97
+ samples_len = self.n_frames
98
+ if type(label_df) is str:
99
+ if label_df == "empty":
100
+ y = np.zeros((samples_len, len(self.labels))) - 1
101
+ return y
102
+ y = np.zeros((samples_len, len(self.labels)))
103
+ if type(label_df) is pd.DataFrame:
104
+ if {"onset", "offset", "event_label"}.issubset(label_df.columns):
105
+ for _, row in label_df.iterrows():
106
+ if not pd.isna(row["event_label"]):
107
+ i = self.labels.index(row["event_label"])
108
+ onset = int(self._time_to_frame(row["onset"]))
109
+ offset = int(np.ceil(self._time_to_frame(row["offset"])))
110
+ if "confidence" in label_df.columns:
111
+ y[onset:offset, i] = row["confidence"] # support confidence
112
+ else:
113
+ y[
114
+ onset:offset, i
115
+ ] = 1 # means offset not included (hypothesis of overlapping frames, so ok)
116
+
117
+ elif type(label_df) in [
118
+ pd.Series,
119
+ list,
120
+ np.ndarray,
121
+ ]: # list of list or list of strings
122
+ if type(label_df) is pd.Series:
123
+ if {"onset", "offset", "event_label"}.issubset(
124
+ label_df.index
125
+ ): # means only one value
126
+ if not pd.isna(label_df["event_label"]):
127
+ i = self.labels.index(label_df["event_label"])
128
+ onset = int(self._time_to_frame(label_df["onset"]))
129
+ offset = int(np.ceil(self._time_to_frame(label_df["offset"])))
130
+
131
+ if "confidence" in label_df.columns:
132
+ y[onset:offset, i] = label_df["confidence"]
133
+ else:
134
+ y[onset:offset, i] = 1
135
+ return y
136
+
137
+ for event_label in label_df:
138
+ # List of string, so weak labels to be encoded in strong
139
+ if type(event_label) is str:
140
+ if event_label != "":
141
+ i = self.labels.index(event_label)
142
+ y[:, i] = 1
143
+
144
+ # List of list, with [label, onset, offset]
145
+ elif len(event_label) == 3:
146
+ if event_label[0] != "":
147
+ i = self.labels.index(event_label[0])
148
+ onset = int(self._time_to_frame(event_label[1]))
149
+ offset = int(np.ceil(self._time_to_frame(event_label[2])))
150
+ y[onset:offset, i] = 1
151
+ # List of list, with [label, onset, offset, confidence]
152
+ elif len(event_label) == 4:
153
+ if event_label[0] != "":
154
+ i = self.labels.index(event_label[0])
155
+ onset = int(self._time_to_frame(event_label[1]))
156
+ offset = int(np.ceil(self._time_to_frame(event_label[2])))
157
+ y[onset:offset, i] = event_label[3]
158
+
159
+ else:
160
+ raise NotImplementedError(
161
+ "cannot encode strong, type mismatch: {}".format(
162
+ type(event_label)
163
+ )
164
+ )
165
+
166
+ else:
167
+ raise NotImplementedError(
168
+ "To encode_strong, type is pandas.Dataframe with onset, offset and event_label"
169
+ "columns, or it is a list or pandas Series of event labels, "
170
+ "type given: {}".format(type(label_df))
171
+ )
172
+ return y
173
+
174
+ def decode_weak(self, labels):
175
+ """ Decode the encoded weak labels
176
+ Args:
177
+ labels: numpy.array, the encoded labels to be decoded
178
+
179
+ Returns:
180
+ list
181
+ Decoded labels, list of string
182
+
183
+ """
184
+ result_labels = []
185
+ for i, value in enumerate(labels):
186
+ if value == 1:
187
+ result_labels.append(self.labels[i])
188
+ return result_labels
189
+
190
+ def decode_strong(self, labels):
191
+ """ Decode the encoded strong labels
192
+ Args:
193
+ labels: numpy.array, the encoded labels to be decoded
194
+ Returns:
195
+ list
196
+ Decoded labels, list of lists: [[label, onset, offset], ...]
197
+
198
+ """
199
+ result_labels = []
200
+ for i, label_column in enumerate(labels.T):
201
+ change_indices = DecisionEncoder().find_contiguous_regions(label_column)
202
+
203
+ # append [label, onset, offset] in the result list
204
+ for row in change_indices:
205
+ result_labels.append(
206
+ [
207
+ self.labels[i],
208
+ self._frame_to_time(row[0]),
209
+ self._frame_to_time(row[1]),
210
+ ]
211
+ )
212
+ return result_labels
213
+
214
+ def state_dict(self):
215
+ return {
216
+ "labels": self.labels,
217
+ "audio_len": self.audio_len,
218
+ "frame_hop": self.frame_hop,
219
+ "net_pooling": self.net_pooling,
220
+ "fs": self.fs,
221
+ }
222
+
223
+ @classmethod
224
+ def load_state_dict(cls, state_dict):
225
+ labels = state_dict["labels"]
226
+ audio_len = state_dict["audio_len"]
227
+ frame_hop = state_dict["frame_hop"]
228
+ net_pooling = state_dict["net_pooling"]
229
+ fs = state_dict["fs"]
230
+ return cls(labels, audio_len, frame_hop, net_pooling, fs)
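A hedged round-trip sketch for ManyHotEncoder; the label set and event times are made up for illustration:

import pandas as pd

encoder = ManyHotEncoder(["Speech", "Dog"], audio_len=10, frame_hop=160,
                         net_pooling=4, fs=16000)
events = pd.DataFrame([
    {"event_label": "Speech", "onset": 0.5, "offset": 2.0},
    {"event_label": "Dog", "onset": 3.0, "offset": 3.5},
])
y = encoder.encode_strong_df(events)  # (250, 2) array: frames x classes
decoded = encoder.decode_strong(y)    # roughly [["Speech", 0.48, 2.0], ["Dog", 3.0, 3.52]]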
helpers/score.py ADDED
@@ -0,0 +1,384 @@
1
+ """
2
+ score functions from: https://hearbenchmark.com/hear-tasks.html
3
+ """
4
+
5
+ import json
6
+ from collections import ChainMap
7
+ from pathlib import Path
8
+ from typing import Dict, Optional, Tuple, Union, List, Any
9
+
10
+ import more_itertools
11
+ import numpy as np
12
+ import sed_eval
13
+ import torch
14
+ from dcase_util.containers import MetaDataContainer
15
+ from scipy.ndimage import median_filter
16
+ from sklearn.model_selection import ParameterGrid
17
+ from tqdm import tqdm
18
+
19
+
20
+ def validate_score_return_type(ret: Union[Tuple[Tuple[str, float], ...], float]):
21
+ """
22
+ Valid return types for the metric are
23
+ - tuple(tuple(string: name of the subtype, float: the value)): This is the
24
+ case with sed eval metrics. They can return (("f_measure", value),
25
+ ("precision", value), ...), depending on the scores
26
+ the metric is supposed to return. This is set as the `scores`
27
+ attribute in the metric.
28
+ - float: Standard metric behaviour
29
+
30
+ The downstream prediction pipeline is able to handle these two types.
31
+ In case of the tuple return type, the value of the first entry in the
32
+ tuple will be used as an optimisation criterion wherever required.
33
+ For instance, if the return is (("f_measure", value), ("precision", value)),
34
+ the value corresponding to the f_measure will be used (for instance in
35
+ early stopping if this metric is the primary score for the task )
36
+ """
37
+ if isinstance(ret, tuple):
38
+ assert all(
39
+ type(s) == tuple and type(s[0]) == str and type(s[1]) == float for s in ret
40
+ ), (
41
+ "If the return type of the score is a tuple, all the elements "
42
+ "in the tuple should be tuple of type (string, float)"
43
+ )
44
+ elif isinstance(ret, float):
45
+ pass
46
+ else:
47
+ raise ValueError(
48
+ f"Return type {type(ret)} is unexpected. Return type of "
49
+ "the score function should either be a "
50
+ "tuple(tuple) or float. "
51
+ )
52
+
53
+
54
+ class ScoreFunction:
55
+ """
56
+ A simple abstract base class for score functions
57
+ """
58
+
59
+ # TODO: Remove label_to_idx?
60
+ def __init__(
61
+ self,
62
+ label_to_idx: Dict[str, int],
63
+ name: Optional[str] = None,
64
+ maximize: bool = True,
65
+ ):
66
+ """
67
+ :param label_to_idx: Map from label string to integer index.
68
+ :param name: Override the name of this scoring function.
69
+ :param maximize: Maximize this score? (Otherwise, it's a loss or energy
70
+ we want to minimize, and I guess technically isn't a score.)
71
+ """
72
+ self.label_to_idx = label_to_idx
73
+ if name:
74
+ self.name = name
75
+ self.maximize = maximize
76
+
77
+ def __call__(self, *args, **kwargs) -> Union[Tuple[Tuple[str, float], ...], float]:
78
+ """
79
+ Calls the compute function of the metric, and after validating the output,
80
+ returns the metric score
81
+ """
82
+ ret = self._compute(*args, **kwargs)
83
+ validate_score_return_type(ret)
84
+ return ret
85
+
86
+ def _compute(
87
+ self, predictions: Any, targets: Any, **kwargs
88
+ ) -> Union[Tuple[Tuple[str, float], ...], float]:
89
+ """
90
+ Compute the score based on the predictions and targets.
91
+ This is a private function and the metric should be used as a functor
92
+ by calling the `__call__` method which calls this and also validates
93
+ the return type
94
+ """
95
+ raise NotImplementedError("Inheriting classes must implement this function")
96
+
97
+ def __str__(self):
98
+ return self.name
99
+
100
+
101
+ class SoundEventScore(ScoreFunction):
102
+ """
103
+ Scores for sound event detection tasks using sed_eval
104
+ """
105
+
106
+ # Score class must be defined in inheriting classes
107
+ score_class: sed_eval.sound_event.SoundEventMetrics = None
108
+
109
+ def __init__(
110
+ self,
111
+ label_to_idx: Dict[str, int],
112
+ scores: Tuple[str],
113
+ params: Dict = None,
114
+ name: Optional[str] = None,
115
+ maximize: bool = True,
116
+ ):
117
+ """
118
+ :param scores: Scores to use, from the list of overall SED eval scores.
119
+ The first score in the tuple will be the primary score for this metric
120
+ :param params: Parameters to pass to the scoring function,
121
+ see inheriting children for details.
122
+ """
123
+ if params is None:
124
+ params = {}
125
+ super().__init__(label_to_idx=label_to_idx, name=name, maximize=maximize)
126
+ self.scores = scores
127
+ self.params = params
128
+ assert self.score_class is not None
129
+
130
+ def _compute(
131
+ self, predictions: Dict, targets: Dict, **kwargs
132
+ ) -> Tuple[Tuple[str, float], ...]:
133
+ # Containers of events for sed_eval
134
+ reference_event_list = self.sed_eval_event_container(targets)
135
+ estimated_event_list = self.sed_eval_event_container(predictions)
136
+
137
+ # This will break in Python < 3.6 if the dict order is not
138
+ # the insertion order I think. I'm a little worried about this line
139
+ scores = self.score_class(
140
+ event_label_list=list(self.label_to_idx.keys()), **self.params
141
+ )
142
+
143
+ for filename in predictions:
144
+ scores.evaluate(
145
+ reference_event_list=reference_event_list.filter(filename=filename),
146
+ estimated_event_list=estimated_event_list.filter(filename=filename),
147
+ )
148
+
149
+ # results_overall_metrics return a pretty large nested selection of scores,
150
+ # with dicts of scores keyed on the type of scores, like f_measure, error_rate,
151
+ # accuracy
152
+ nested_overall_scores: Dict[
153
+ str, Dict[str, float]
154
+ ] = scores.results_overall_metrics()
155
+ # Open up nested overall scores
156
+ overall_scores: Dict[str, float] = dict(
157
+ ChainMap(*nested_overall_scores.values())
158
+ )
159
+ # Return the required scores as tuples. The scores are returned in the
160
+ # order they are passed in the `scores` argument
161
+ return tuple([(score, overall_scores[score]) for score in self.scores])
162
+
163
+ @staticmethod
164
+ def sed_eval_event_container(
165
+ x: Dict[str, List[Dict[str, Any]]]
166
+ ) -> MetaDataContainer:
167
+ # Reformat event list for sed_eval
168
+ reference_events = []
169
+ for filename, event_list in x.items():
170
+ for event in event_list:
171
+ reference_events.append(
172
+ {
173
+ # Convert from ms to seconds for sed_eval
174
+ "event_label": str(event["label"]),
175
+ "event_onset": event["start"] / 1000.0,
176
+ "event_offset": event["end"] / 1000.0,
177
+ "file": filename,
178
+ }
179
+ )
180
+ return MetaDataContainer(reference_events)
181
+
182
+
183
+ class EventBasedScore(SoundEventScore):
184
+ """
185
+ event-based scores - the ground truth and system output are compared at
186
+ event instance level;
187
+
188
+ See https://tut-arg.github.io/sed_eval/generated/sed_eval.sound_event.EventBasedMetrics.html # noqa: E501
189
+ for params.
190
+ """
191
+
192
+ score_class = sed_eval.sound_event.EventBasedMetrics
193
+
194
+
195
+ class SegmentBasedScore(SoundEventScore):
196
+ """
197
+ segment-based scores - the ground truth and system output are compared in a
198
+ fixed time grid; sound events are marked as active or inactive in each segment;
199
+
200
+ See https://tut-arg.github.io/sed_eval/sound_event.html#sed_eval.sound_event.SegmentBasedMetrics # noqa: E501
201
+ for params.
202
+ """
203
+
204
+ score_class = sed_eval.sound_event.SegmentBasedMetrics
205
+
206
+
207
+ def get_events_for_all_files(
208
+ predictions: torch.Tensor,
209
+ filenames: List[str],
210
+ timestamps: torch.Tensor,
211
+ idx_to_label: Dict[int, str],
212
+ postprocessing_grid: Dict[str, List[float]],
213
+ postprocessing: Optional[Tuple[Tuple[str, Any], ...]] = None,
214
+ ) -> Dict[Tuple[Tuple[str, Any], ...], Dict[str, List[Dict[str, Union[str, float]]]]]:
215
+ """
216
+ Produces lists of events from a set of frame based label probabilities.
217
+ The input prediction tensor may contain frame predictions from a set of different
218
+ files concatenated together. file_timestamps has a list of filenames and
219
+ timestamps for each frame in the predictions tensor.
220
+
221
+ We split the predictions into separate tensors based on the filename and compute
222
+ events based on those individually.
223
+
224
+ If no postprocessing is specified (during training), we try a
225
+ variety of ways of postprocessing the predictions into events,
226
+ from the postprocessing_grid including median filtering and
227
+ minimum event length.
228
+
229
+ If postprocessing is specified (during test, chosen at the best
230
+ validation epoch), we use this postprocessing.
231
+
232
+ Args:
233
+ predictions: a tensor of frame based multi-label predictions.
234
+ filenames: a list of filenames where each entry corresponds
235
+ to a frame in the predictions tensor.
236
+ timestamps: a list of timestamps where each entry corresponds
237
+ to a frame in the predictions tensor.
238
+ idx_to_label: Index to label mapping.
239
+ postprocessing: See above.
240
+
241
+ Returns:
242
+ A dictionary from filtering params to the following values:
243
+ A dictionary of lists of events keyed on the filename slug.
244
+ The event list is of dicts of the following format:
245
+ {"label": str, "start": float ms, "end": float ms}
246
+ """
247
+ # This probably could be more efficient if we make the assumption that
248
+ # timestamps are in sorted order. But this makes sure of it.
249
+ assert predictions.shape[0] == len(filenames)
250
+ assert predictions.shape[0] == len(timestamps)
251
+ event_files: Dict[str, Dict[float, torch.Tensor]] = {}
252
+ for i, (filename, timestamp) in enumerate(zip(filenames, timestamps)):
253
+ slug = Path(filename).name
254
+
255
+ # Key on the slug to be consistent with the ground truth
256
+ if slug not in event_files:
257
+ event_files[slug] = {}
258
+
259
+ # Save the predictions for the file keyed on the timestamp
260
+ event_files[slug][float(timestamp)] = predictions[i]
261
+
262
+ # Create events for all the different files. Store all the events as a dictionary
263
+ # with the same format as the ground truth from the luigi pipeline.
264
+ # Ex) { slug -> [{"label" : "woof", "start": 0.0, "end": 2.32}, ...], ...}
265
+ event_dict: Dict[
266
+ Tuple[Tuple[str, Any], ...], Dict[str, List[Dict[str, Union[float, str]]]]
267
+ ] = {}
268
+ if postprocessing:
269
+ postprocess = postprocessing
270
+ event_dict[postprocess] = {}
271
+ for slug, timestamp_predictions in event_files.items():
272
+ event_dict[postprocess][slug] = create_events_from_prediction(
273
+ timestamp_predictions, idx_to_label, **dict(postprocess)
274
+ )
275
+ else:
276
+ postprocessing_confs = list(ParameterGrid(postprocessing_grid))
277
+ for postprocess_dict in tqdm(postprocessing_confs):
278
+ postprocess = tuple(postprocess_dict.items())
279
+ event_dict[postprocess] = {}
280
+ for slug, timestamp_predictions in event_files.items():
281
+ event_dict[postprocess][slug] = create_events_from_prediction(
282
+ timestamp_predictions, idx_to_label, **postprocess_dict
283
+ )
284
+
285
+ return event_dict
286
+
287
+
288
+ def create_events_from_prediction(
289
+ prediction_dict: Dict[float, torch.Tensor],
290
+ idx_to_label: Dict[int, str],
291
+ threshold: float = 0.5,
292
+ median_filter_ms: float = 150,
293
+ min_duration: float = 60.0,
294
+ ) -> List[Dict[str, Union[float, str]]]:
295
+ """
296
+ Takes a set of prediction tensors keyed on timestamps and generates events.
297
+ (This is for one particular audio scene.)
298
+ We convert the prediction tensor to a binary label based on the threshold value. Any
299
+ events occurring at adjacent timestamps are considered to be part of the same event.
300
+ This loops through and creates events for each label class.
301
+ We optionally apply median filtering to predictions.
302
+ We disregard events that are less than the min_duration milliseconds.
303
+
304
+ Args:
305
+ prediction_dict: A dictionary of predictions keyed on timestamp
306
+ {timestamp -> prediction}. The prediction is a tensor of label
307
+ probabilities.
308
+ idx_to_label: Index to label mapping.
309
+ threshold: Threshold above which a frame is assigned a label.
310
+ median_filter_ms: Width in milliseconds of the optional median filter.
311
+ min_duration: the minimum duration in milliseconds for an event to be included.
312
+
313
+ Returns:
314
+ A list of dicts with keys "label", "start", and "end".
315
+ """
316
+ # Make sure the timestamps are in the correct order
317
+ timestamps = np.array(sorted(prediction_dict.keys()))
318
+
319
+ # Create a sorted numpy matrix of frame level predictions for this file. We convert
320
+ # to a numpy array here before applying a median filter.
321
+ predictions = np.stack(
322
+ [prediction_dict[t].detach().cpu().numpy() for t in timestamps]
323
+ )
324
+
325
+ # Optionally apply a median filter here to smooth out events.
326
+ ts_diff = np.mean(np.diff(timestamps))
327
+ if median_filter_ms:
328
+ filter_width = int(round(median_filter_ms / ts_diff))
329
+ if filter_width:
330
+ predictions = median_filter(predictions, size=(filter_width, 1))
331
+
332
+ # Convert probabilities to binary vectors based on threshold
333
+ predictions = (predictions > threshold).astype(np.int8)
334
+
335
+ events = []
336
+ for label in range(predictions.shape[1]):
337
+ for group in more_itertools.consecutive_groups(
338
+ np.where(predictions[:, label])[0]
339
+ ):
340
+ grouptuple = tuple(group)
341
+ assert (
342
+ tuple(sorted(grouptuple)) == grouptuple
343
+ ), f"{sorted(grouptuple)} != {grouptuple}"
344
+ startidx, endidx = (grouptuple[0], grouptuple[-1])
345
+
346
+ start = timestamps[startidx]
347
+ end = timestamps[endidx]
348
+ # Add event if greater than the minimum duration threshold
349
+ if end - start >= min_duration:
350
+ events.append(
351
+ {"label": idx_to_label[label], "start": start, "end": end}
352
+ )
353
+
354
+ # This is just for pretty output, not really necessary
355
+ events.sort(key=lambda k: k["start"])
356
+ return events
357
+
358
+
359
+ def combine_target_events(split_names: List[str], task_path):
360
+ """
361
+ This combines the target events from the list of splits and
362
+ returns the combined target events. This is useful when combining
363
+ multiple folds of data to create the training or validation
364
+ dataloader. For example, in 5-fold cross-validation, the training
365
+ dataloader might be built from four of the five folds; calling this
366
+ function with [fold00, fold01, fold02, fold03] returns the
367
+ aggregated target events across those folds.
368
+ """
369
+ combined_target_events: Dict = {}
370
+ for split_name in split_names:
371
+ target_events = json.load(
372
+ task_path.joinpath(f"{split_name}.json").open()
373
+ )
374
+ common_keys = set(combined_target_events.keys()).intersection(
375
+ target_events.keys()
376
+ )
377
+ assert len(common_keys) == 0, (
378
+ "Target events from one split should not override "
379
+ "target events from another. This is very unlikely as the "
380
+ "target_event is keyed on the files which are distinct for "
381
+ "each split"
382
+ )
383
+ combined_target_events.update(target_events)
384
+ return combined_target_events
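A minimal sketch of how `create_events_from_prediction` turns frame probabilities into events; the toy label map, the 50 ms frame hop, and the probabilities below are illustrative only, not values from this repo:

```python
import torch

# Hypothetical inputs: 2 classes, 20 frames at a 50 ms hop; class 0 is
# active (p=0.9) from frame 4 to frame 12, class 1 stays below threshold.
idx_to_label = {0: "Speech", 1: "Dog"}
prediction_dict = {
    float(t * 50): torch.tensor([0.9 if 4 <= t <= 12 else 0.1, 0.2])
    for t in range(20)
}

events = create_events_from_prediction(
    prediction_dict, idx_to_label,
    threshold=0.5, median_filter_ms=150, min_duration=60.0,
)
# -> [{"label": "Speech", "start": 200.0, "end": 600.0}]
```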
helpers/utils.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+
5
+
6
+ def worker_init_fn(x):
7
+ seed = (torch.initial_seed() + x * 1000) % 2 ** 31  # spread worker seeds apart to avoid nearly identical random streams
8
+
9
+ np.random.seed(seed)
10
+ random.seed(seed)
11
+ torch.manual_seed(seed)
12
+ return
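A minimal sketch of wiring `worker_init_fn` into a PyTorch `DataLoader` (the toy dataset is illustrative); the loader passes each worker's id as `x`, so every worker process gets a distinct numpy/random/torch seed:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 4))
loader = DataLoader(dataset, batch_size=2, num_workers=2,
                    worker_init_fn=worker_init_fn)  # x = worker id (0, 1, ...)
```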
images/downstream_task_results.png ADDED
inference.py ADDED
@@ -0,0 +1,126 @@
1
+ import argparse
2
+ import librosa
3
+ import torch
4
+
5
+ from data_util import audioset_classes
6
+ from helpers.decode import batched_decode_preds
7
+ from helpers.encode import ManyHotEncoder
8
+ from models.atstframe.ATSTF_wrapper import ATSTWrapper
9
+ from models.beats.BEATs_wrapper import BEATsWrapper
10
+ from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
11
+ from models.m2d.M2D_wrapper import M2DWrapper
12
+ from models.asit.ASIT_wrapper import ASiTWrapper
13
+ from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
14
+ from models.prediction_wrapper import PredictionsWrapper
15
+ from models.frame_mn.utils import NAME_TO_WIDTH
16
+
17
+
18
+ def sound_event_detection(args):
19
+ """
20
+ Running Sound Event Detection on an audio clip.
21
+ """
22
+ device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
23
+ model_name = args.model_name
24
+
25
+ if model_name == "BEATs":
26
+ beats = BEATsWrapper()
27
+ model = PredictionsWrapper(beats, checkpoint="BEATs_strong_1")
28
+ elif model_name == "ATST-F":
29
+ atst = ATSTWrapper()
30
+ model = PredictionsWrapper(atst, checkpoint="ATST-F_strong_1")
31
+ elif model_name == "fpasst":
32
+ fpasst = FPaSSTWrapper()
33
+ model = PredictionsWrapper(fpasst, checkpoint="fpasst_strong_1")
34
+ elif model_name == "M2D":
35
+ m2d = M2DWrapper()
36
+ model = PredictionsWrapper(m2d, checkpoint="M2D_strong_1", embed_dim=m2d.m2d.cfg.feature_d)
37
+ elif model_name == "ASIT":
38
+ asit = ASiTWrapper()
39
+ model = PredictionsWrapper(asit, checkpoint="ASIT_strong_1")
40
+ elif model_name.startswith("frame_mn"):
41
+ width = NAME_TO_WIDTH(model_name)
42
+ frame_mn = FrameMNWrapper(width)
43
+ embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
44
+ model = PredictionsWrapper(frame_mn, checkpoint=f"{model_name}_strong_1", embed_dim=embed_dim)
45
+ else:
46
+ raise NotImplementedError(f"Model {model_name} not (yet) implemented")
47
+
48
+ model.eval()
49
+ model.to(device)
50
+
51
+ sample_rate = 16_000 # all our models are trained on 16 kHz audio
52
+ segment_duration = 10 # all models are trained on 10-second pieces
53
+ segment_samples = segment_duration * sample_rate
54
+
55
+ # load audio
56
+ (waveform, _) = librosa.core.load(args.audio_file, sr=sample_rate, mono=True)
57
+ waveform = torch.from_numpy(waveform[None, :]).to(device)
58
+ waveform_len = waveform.shape[1]
59
+
60
+ audio_len = waveform_len / sample_rate # in seconds
61
+ print("Audio length (seconds): ", audio_len)
62
+
63
+ # encoder manages decoding of model predictions into dataframes
64
+ # containing event labels, onsets and offsets
65
+ encoder = ManyHotEncoder(audioset_classes.as_strong_train_classes, audio_len=audio_len)
66
+
67
+ # split audio file into 10-second chunks
68
+ num_chunks = waveform_len // segment_samples + (waveform_len % segment_samples != 0)
69
+ all_predictions = []
70
+
71
+ # Process each 10-second chunk
72
+ for i in range(num_chunks):
73
+ start_idx = i * segment_samples
74
+ end_idx = min((i + 1) * segment_samples, waveform_len)
75
+ waveform_chunk = waveform[:, start_idx:end_idx]
76
+
77
+ # Pad the last chunk if it's shorter than 10 seconds
78
+ if waveform_chunk.shape[1] < segment_samples:
79
+ pad_size = segment_samples - waveform_chunk.shape[1]
80
+ waveform_chunk = torch.nn.functional.pad(waveform_chunk, (0, pad_size))
81
+
82
+ # Run inference for each chunk
83
+ with torch.no_grad():
84
+ mel = model.mel_forward(waveform_chunk)
85
+ y_strong, _ = model(mel)
86
+
87
+ # Collect predictions
88
+ all_predictions.append(y_strong)
89
+
90
+ # Concatenate all predictions along the time axis
91
+ y_strong = torch.cat(all_predictions, dim=2)
92
+ # convert into probabilities
93
+ y_strong = torch.sigmoid(y_strong)
94
+
95
+ (
96
+ scores_unprocessed,
97
+ scores_postprocessed,
98
+ decoded_predictions
99
+ ) = batched_decode_preds(
100
+ y_strong.float(),
101
+ [args.audio_file],
102
+ encoder,
103
+ median_filter=args.median_window,
104
+ thresholds=args.detection_thresholds,
105
+ )
106
+
107
+ for th in decoded_predictions:
108
+ print("***************************************")
109
+ print(f"Detected events using threshold {th}:")
110
+ print(decoded_predictions[th].sort_values(by="onset"))
111
+ print("***************************************")
112
+
113
+
114
+ if __name__ == "__main__":
115
+ parser = argparse.ArgumentParser(description='Sound Event Detection inference on a single audio file.')
116
+ # model names: [BEATs, ASIT, ATST-F, fpasst, M2D]
117
+ parser.add_argument('--model_name', type=str, default='BEATs')
118
+ parser.add_argument('--audio_file', type=str,
119
+ default='test_files/752547__iscence__milan_metro_coming_in_station.wav')
120
+ parser.add_argument('--detection_thresholds', type=float, nargs='+', default=(0.1, 0.2, 0.5))
121
+ parser.add_argument('--median_window', type=float, default=9)
122
+ parser.add_argument('--cuda', action='store_true', default=False)
123
+ args = parser.parse_args()
124
+
125
+ assert args.model_name in ["BEATs", "ASIT", "ATST-F", "fpasst", "M2D"] or args.model_name.startswith("frame_mn")
126
+ sound_event_detection(args)
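Besides the CLI, `sound_event_detection` can be driven programmatically with an `argparse.Namespace`; a minimal sketch, where the audio path is a placeholder and the checkpoint is fetched by `PredictionsWrapper` exactly as in the CLI case:

```python
from argparse import Namespace

args = Namespace(
    model_name="BEATs",
    audio_file="test_files/example.wav",  # placeholder: any audio file on disk
    detection_thresholds=(0.1, 0.2, 0.5),
    median_window=9,
    cuda=False,
)
sound_event_detection(args)
```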
models/asit/ASIT_wrapper.py ADDED
@@ -0,0 +1,60 @@
1
+ from models.asit.data_transformations import DataAugmentation
2
+ from models.asit.vision_transformer import vit_base
3
+ from models.transformer_wrapper import BaseModelWrapper
4
+
5
+
6
+ class ASiTWrapper(BaseModelWrapper):
7
+ def __init__(self) -> None:
8
+ super().__init__()
9
+ self.asit_mel = DataAugmentation()
10
+ self.asit = vit_base(
11
+ patch_size=[16, 16],
12
+ audio_size=[128, 592],
13
+ stride=[16, 16],
14
+ in_chans=1,
15
+ num_classes=0
16
+ )
17
+
18
+ def mel_forward(self, x):
19
+ return self.asit_mel(x)
20
+
21
+ def forward(self, spec):
22
+ return self.asit(spec)
23
+
24
+ def separate_params(self):
25
+ pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
26
+ for k, p in self.named_parameters():
27
+ if any(['cls_token' in k,
28
+ 'pos_embed' in k,
29
+ 'norm_stats' in k,
30
+ 'patch_embed' in k]):
31
+ pt_params[0].append(p)
32
+ elif 'blocks.0.' in k:
33
+ pt_params[0].append(p)
34
+ elif 'blocks.1.' in k:
35
+ pt_params[1].append(p)
36
+ elif 'blocks.2.' in k:
37
+ pt_params[2].append(p)
38
+ elif 'blocks.3.' in k:
39
+ pt_params[3].append(p)
40
+ elif 'blocks.4.' in k:
41
+ pt_params[4].append(p)
42
+ elif 'blocks.5.' in k:
43
+ pt_params[5].append(p)
44
+ elif 'blocks.6.' in k:
45
+ pt_params[6].append(p)
46
+ elif 'blocks.7.' in k:
47
+ pt_params[7].append(p)
48
+ elif 'blocks.8.' in k:
49
+ pt_params[8].append(p)
50
+ elif 'blocks.9.' in k:
51
+ pt_params[9].append(p)
52
+ elif 'blocks.10.' in k:
53
+ pt_params[10].append(p)
54
+ elif 'blocks.11.' in k:
55
+ pt_params[11].append(p)
56
+ elif 'asit.norm.weight' in k or 'asit.norm.bias' in k:
57
+ pt_params[11].append(p)
58
+ else:
59
+ raise ValueError(f"Check separate params for ASiT! Unknown key: {k}")
60
+ return list(reversed(pt_params))
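`separate_params` returns the twelve transformer-block groups with the last block first, which makes layer-wise learning-rate decay a one-liner; a sketch, where the base learning rate and the decay factor 0.75 are illustrative choices, not values from this repo:

```python
import torch

wrapper = ASiTWrapper()
param_groups = wrapper.separate_params()  # deepest (last) block comes first
optimizer = torch.optim.AdamW(
    [{"params": params, "lr": 1e-4 * (0.75 ** i)}   # earlier blocks get smaller LRs
     for i, params in enumerate(param_groups)]
)
```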
models/asit/data_transformations.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+ import torch.nn.functional
3
+ import torchaudio
4
+
5
+
6
+ class DataAugmentation(object):
7
+ def __init__(self, data_mean=-4.2677393, data_std=4.5689974, num_mel_bins=128, sample_rate=16000):
8
+ self.data_mean = data_mean
9
+ self.data_std = data_std
10
+ self.num_mel_bins = num_mel_bins
11
+ self.sample_rate = sample_rate
12
+
13
+ def _wav2fbank(self, waveform):
14
+ waveform = (waveform - waveform.mean())
15
+ fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=self.sample_rate,
16
+ use_energy=False,
17
+ window_type='hanning', num_mel_bins=self.num_mel_bins, dither=0.0,
18
+ frame_shift=10)
19
+ return fbank
20
+
21
+ def convert_waveform(self, waveform):
22
+ w = self._wav2fbank(waveform)
23
+ fbank = (w - self.data_mean) / (self.data_std * 2)
24
+ fbank = fbank.unsqueeze(0)
25
+ return fbank
26
+
27
+ def __call__(self, batch):
28
+ # apply convert_waveform to each sample of the batch and return the result
29
+ return torch.stack([self.convert_waveform(sample.reshape(1, -1)) for sample in batch]).permute(0, 1, 3, 2)
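A shape sketch for the front end above: a batch of raw waveforms comes out as normalized Kaldi `fbank` features with mel bins on the frequency axis (≈998 frames for 10 s at the 10 ms frame shift, given Kaldi's default 25 ms window):

```python
import torch

mel = DataAugmentation()
batch = torch.randn(2, 160000)  # 2 clips, 10 s at 16 kHz
spec = mel(batch)
print(spec.shape)  # torch.Size([2, 1, 128, 998]) -> (batch, channel, mels, frames)
```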
models/asit/utils.py ADDED
@@ -0,0 +1,540 @@
1
+ import warnings
2
+
3
+ warnings.filterwarnings("ignore")
4
+
5
+ import os
6
+ import sys
7
+ import time
8
+ import math
9
+ import random
10
+ import datetime
11
+ import subprocess
12
+ from collections import defaultdict, deque
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.distributed as dist
17
+
18
+ import argparse
19
+
20
+ from numpy.random import randint
21
+
22
+
23
+ def GMML_replace_list(samples, corrup_prev, masks_prev, drop_type='noise', max_replace=0.35, align=16):
24
+ rep_drop = 1 if drop_type == '' else (1 / (len(drop_type.split('-')) + 1))
25
+
26
+ n_imgs = samples.size()[0]  # the batch size; read from the tensor in case a bad instance was dropped while loading
27
+ samples_aug = samples.detach().clone()
28
+ masks = torch.zeros_like(samples_aug)
29
+ for i in range(n_imgs):
30
+ idx_rnd = randint(0, n_imgs)
31
+ if random.random() < rep_drop:
32
+ samples_aug[i], masks[i] = GMML_drop_rand_patches(samples_aug[i], samples[idx_rnd], max_replace=max_replace,
33
+ align=align)
34
+ else:
35
+ samples_aug[i], masks[i] = corrup_prev[i], masks_prev[i]
36
+
37
+ return samples_aug, masks
38
+
39
+
40
+ def GMML_drop_rand_patches(X, X_rep=None, drop_type='noise', max_replace=0.7, align=16, max_block_sz=0.3):
41
+ #######################
42
+ # max_replace: percentage of image to be replaced
43
+ # align: align corruption with the patch sizes
44
+ # max_block_sz: percentage of the maximum block to be dropped
45
+ #######################
46
+
47
+ np.random.seed()
48
+ C, H, W = X.size()
49
+ n_drop_pix = np.random.uniform(min(0.5, max_replace), max_replace) * H * W
50
+ mx_blk_height = int(H * max_block_sz)
51
+ mx_blk_width = int(W * max_block_sz)
52
+
53
+ align = max(1, align)
54
+
55
+ mask = torch.zeros_like(X)
56
+ drop_t = np.random.choice(drop_type.split('-'))
57
+
58
+ while mask[0].sum() < n_drop_pix:
59
+
60
+ ####### get a random block to replace
61
+ rnd_r = (randint(0, H - align) // align) * align
62
+ rnd_c = (randint(0, W - align) // align) * align
63
+
64
+ rnd_h = min(randint(align, mx_blk_height), H - rnd_r)
65
+ rnd_h = round(rnd_h / align) * align
66
+ rnd_w = min(randint(align, mx_blk_width), W - rnd_c)
67
+ rnd_w = round(rnd_w / align) * align
68
+
69
+ if X_rep is not None:
70
+ X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = X_rep[:, rnd_r:rnd_r + rnd_h,
71
+ rnd_c:rnd_c + rnd_w].detach().clone()
72
+ else:
73
+ if drop_t == 'noise':
74
+ X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = torch.empty((C, rnd_h, rnd_w), dtype=X.dtype,
75
+ device=X.device).normal_()
76
+ elif drop_t == 'zeros':
77
+ X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = torch.zeros((C, rnd_h, rnd_w), dtype=X.dtype,
78
+ device=X.device)
79
+ else:
80
+ ####### get a random block to replace from
81
+ rnd_r2 = (randint(0, H - rnd_h) // align) * align
82
+ rnd_c2 = (randint(0, W - rnd_w) // align) * align
83
+
84
+ X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = X[:, rnd_r2:rnd_r2 + rnd_h,
85
+ rnd_c2:rnd_c2 + rnd_w].detach().clone()
86
+
87
+ mask[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = 1
88
+
89
+ return X, mask
90
+
91
+
92
+ class collate_batch(object): # replace from other images
93
+ def __init__(self, drop_replace=0., drop_align=1):
94
+ self.drop_replace = drop_replace
95
+ self.drop_align = drop_align
96
+
97
+ def __call__(self, batch):
98
+ batch = torch.utils.data.dataloader.default_collate(batch)
99
+
100
+ if self.drop_replace > 0:
101
+ batch[0][1][0], batch[0][2][0] = GMML_replace_list(batch[0][0][0], batch[0][1][0], batch[0][2][0],
102
+ max_replace=self.drop_replace, align=self.drop_align)
103
+ batch[0][1][1], batch[0][2][1] = GMML_replace_list(batch[0][0][1], batch[0][1][1], batch[0][2][1],
104
+ max_replace=self.drop_replace, align=self.drop_align)
105
+
106
+ return batch
107
+
108
+
109
+ def clip_gradients(model, clip):
110
+ norms = []
111
+ for name, p in model.named_parameters():
112
+ if p.grad is not None:
113
+ param_norm = p.grad.data.norm(2)
114
+ norms.append(param_norm.item())
115
+ clip_coef = clip / (param_norm + 1e-6)
116
+ if clip_coef < 1:
117
+ p.grad.data.mul_(clip_coef)
118
+ return norms
119
+
120
+
121
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
122
+ if epoch >= freeze_last_layer:
123
+ return
124
+ for n, p in model.named_parameters():
125
+ if "last_layer" in n:
126
+ p.grad = None
127
+
128
+
129
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
130
+ """
131
+ Re-start from checkpoint
132
+ """
133
+ if not os.path.isfile(ckp_path):
134
+ return
135
+ print("Found checkpoint at {}".format(ckp_path))
136
+
137
+ # open checkpoint file
138
+ checkpoint = torch.load(ckp_path, map_location="cpu")
139
+
140
+ # key is what to look for in the checkpoint file
141
+ # value is the object to load
142
+ # example: {'state_dict': model}
143
+ for key, value in kwargs.items():
144
+ if key in checkpoint and value is not None:
145
+ try:
146
+ msg = value.load_state_dict(checkpoint[key], strict=False)
147
+ print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
148
+ except TypeError:
149
+ try:
150
+ msg = value.load_state_dict(checkpoint[key])
151
+ print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
152
+ except ValueError:
153
+ print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
154
+ else:
155
+ print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
156
+
157
+ # reload variables important for the run
158
+ if run_variables is not None:
159
+ for var_name in run_variables:
160
+ if var_name in checkpoint:
161
+ run_variables[var_name] = checkpoint[var_name]
162
+
163
+
164
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
165
+ warmup_schedule = np.array([])
166
+ warmup_iters = warmup_epochs * niter_per_ep
167
+ if warmup_epochs > 0:
168
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
169
+
170
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
171
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
172
+
173
+ schedule = np.concatenate((warmup_schedule, schedule))
174
+ assert len(schedule) == epochs * niter_per_ep
175
+ return schedule
176
+
177
+
178
+ def bool_flag(s):
179
+ """
180
+ Parse boolean arguments from the command line.
181
+ """
182
+ FALSY_STRINGS = {"off", "false", "0"}
183
+ TRUTHY_STRINGS = {"on", "true", "1"}
184
+ if s.lower() in FALSY_STRINGS:
185
+ return False
186
+ elif s.lower() in TRUTHY_STRINGS:
187
+ return True
188
+ else:
189
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
190
+
191
+
192
+ def fix_random_seeds(seed=31):
193
+ """
194
+ Fix random seeds.
195
+ """
196
+ torch.manual_seed(seed)
197
+ torch.cuda.manual_seed_all(seed)
198
+ np.random.seed(seed)
199
+
200
+
201
+ class SmoothedValue(object):
202
+ """Track a series of values and provide access to smoothed values over a
203
+ window or the global series average.
204
+ """
205
+
206
+ def __init__(self, window_size=20, fmt=None):
207
+ if fmt is None:
208
+ fmt = "{median:.6f} ({global_avg:.6f})"
209
+ self.deque = deque(maxlen=window_size)
210
+ self.total = 0.0
211
+ self.count = 0
212
+ self.fmt = fmt
213
+
214
+ def update(self, value, n=1):
215
+ self.deque.append(value)
216
+ self.count += n
217
+ self.total += value * n
218
+
219
+ def synchronize_between_processes(self):
220
+ """
221
+ Warning: does not synchronize the deque!
222
+ """
223
+ if not is_dist_avail_and_initialized():
224
+ return
225
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
226
+ dist.barrier()
227
+ dist.all_reduce(t)
228
+ t = t.tolist()
229
+ self.count = int(t[0])
230
+ self.total = t[1]
231
+
232
+ @property
233
+ def median(self):
234
+ d = torch.tensor(list(self.deque))
235
+ return d.median().item()
236
+
237
+ @property
238
+ def avg(self):
239
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
240
+ return d.mean().item()
241
+
242
+ @property
243
+ def global_avg(self):
244
+ return self.total / self.count
245
+
246
+ @property
247
+ def max(self):
248
+ return max(self.deque)
249
+
250
+ @property
251
+ def value(self):
252
+ return self.deque[-1]
253
+
254
+ def __str__(self):
255
+ return self.fmt.format(
256
+ median=self.median,
257
+ avg=self.avg,
258
+ global_avg=self.global_avg,
259
+ max=self.max,
260
+ value=self.value)
261
+
262
+
263
+ def reduce_dict(input_dict, average=True):
264
+ """
265
+ Args:
266
+ input_dict (dict): all the values will be reduced
267
+ average (bool): whether to do average or sum
268
+ Reduce the values in the dictionary from all processes so that all processes
269
+ have the averaged results. Returns a dict with the same fields as
270
+ input_dict, after reduction.
271
+ """
272
+ world_size = get_world_size()
273
+ if world_size < 2:
274
+ return input_dict
275
+ with torch.no_grad():
276
+ names = []
277
+ values = []
278
+ # sort the keys so that they are consistent across processes
279
+ for k in sorted(input_dict.keys()):
280
+ names.append(k)
281
+ values.append(input_dict[k])
282
+ values = torch.stack(values, dim=0)
283
+ dist.all_reduce(values)
284
+ if average:
285
+ values /= world_size
286
+ reduced_dict = {k: v for k, v in zip(names, values)}
287
+ return reduced_dict
288
+
289
+
290
+ class MetricLogger(object):
291
+ def __init__(self, delimiter="\t"):
292
+ self.meters = defaultdict(SmoothedValue)
293
+ self.delimiter = delimiter
294
+
295
+ def update(self, **kwargs):
296
+ for k, v in kwargs.items():
297
+ if isinstance(v, torch.Tensor):
298
+ v = v.item()
299
+ assert isinstance(v, (float, int))
300
+ self.meters[k].update(v)
301
+
302
+ def __getattr__(self, attr):
303
+ if attr in self.meters:
304
+ return self.meters[attr]
305
+ if attr in self.__dict__:
306
+ return self.__dict__[attr]
307
+ raise AttributeError("'{}' object has no attribute '{}'".format(
308
+ type(self).__name__, attr))
309
+
310
+ def __str__(self):
311
+ loss_str = []
312
+ for name, meter in self.meters.items():
313
+ loss_str.append(
314
+ "{}: {}".format(name, str(meter))
315
+ )
316
+ return self.delimiter.join(loss_str)
317
+
318
+ def synchronize_between_processes(self):
319
+ for meter in self.meters.values():
320
+ meter.synchronize_between_processes()
321
+
322
+ def add_meter(self, name, meter):
323
+ self.meters[name] = meter
324
+
325
+ def log_every(self, iterable, print_freq, header=None):
326
+ i = 0
327
+ if not header:
328
+ header = ''
329
+ start_time = time.time()
330
+ end = time.time()
331
+ iter_time = SmoothedValue(fmt='{avg:.6f}')
332
+ data_time = SmoothedValue(fmt='{avg:.6f}')
333
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
334
+ if torch.cuda.is_available():
335
+ log_msg = self.delimiter.join([
336
+ header,
337
+ '[{0' + space_fmt + '}/{1}]',
338
+ 'eta: {eta}',
339
+ '{meters}',
340
+ 'time: {time}',
341
+ 'data: {data}',
342
+ 'max mem: {memory:.0f}'
343
+ ])
344
+ else:
345
+ log_msg = self.delimiter.join([
346
+ header,
347
+ '[{0' + space_fmt + '}/{1}]',
348
+ 'eta: {eta}',
349
+ '{meters}',
350
+ 'time: {time}',
351
+ 'data: {data}'
352
+ ])
353
+ MB = 1024.0 * 1024.0
354
+ for obj in iterable:
355
+ data_time.update(time.time() - end)
356
+ yield obj
357
+ iter_time.update(time.time() - end)
358
+ if i % print_freq == 0 or i == len(iterable) - 1:
359
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
360
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
361
+ if torch.cuda.is_available():
362
+ print(log_msg.format(
363
+ i, len(iterable), eta=eta_string,
364
+ meters=str(self),
365
+ time=str(iter_time), data=str(data_time),
366
+ memory=torch.cuda.max_memory_allocated() / MB))
367
+ else:
368
+ print(log_msg.format(
369
+ i, len(iterable), eta=eta_string,
370
+ meters=str(self),
371
+ time=str(iter_time), data=str(data_time)))
372
+ i += 1
373
+ end = time.time()
374
+ total_time = time.time() - start_time
375
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
376
+ print('{} Total time: {} ({:.6f} s / it)'.format(
377
+ header, total_time_str, total_time / len(iterable)))
378
+
379
+
380
+ def get_sha():
381
+ cwd = os.path.dirname(os.path.abspath(__file__))
382
+
383
+ def _run(command):
384
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
385
+
386
+ sha = 'N/A'
387
+ diff = "clean"
388
+ branch = 'N/A'
389
+ try:
390
+ sha = _run(['git', 'rev-parse', 'HEAD'])
391
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
392
+ diff = _run(['git', 'diff-index', 'HEAD'])
393
+ diff = "has uncommited changes" if diff else "clean"
394
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
395
+ except Exception:
396
+ pass
397
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
398
+ return message
399
+
400
+
401
+ def is_dist_avail_and_initialized():
402
+ if not dist.is_available():
403
+ return False
404
+ if not dist.is_initialized():
405
+ return False
406
+ return True
407
+
408
+
409
+ def get_world_size():
410
+ if not is_dist_avail_and_initialized():
411
+ return 1
412
+ return dist.get_world_size()
413
+
414
+
415
+ def get_rank():
416
+ if not is_dist_avail_and_initialized():
417
+ return 0
418
+ return dist.get_rank()
419
+
420
+
421
+ def is_main_process():
422
+ return get_rank() == 0
423
+
424
+
425
+ def save_on_master(*args, **kwargs):
426
+ if is_main_process():
427
+ torch.save(*args, **kwargs)
428
+
429
+
430
+ def setup_for_distributed(is_master):
431
+ """
432
+ This function disables printing when not in master process
433
+ """
434
+ import builtins as __builtin__
435
+ builtin_print = __builtin__.print
436
+
437
+ def print(*args, **kwargs):
438
+ force = kwargs.pop('force', False)
439
+ if is_master or force:
440
+ builtin_print(*args, **kwargs)
441
+
442
+ __builtin__.print = print
443
+
444
+
445
+ def init_distributed_mode(args):
446
+ # launched with torch.distributed.launch
447
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
448
+ args.rank = int(os.environ["RANK"])
449
+ args.world_size = int(os.environ['WORLD_SIZE'])
450
+ args.gpu = int(os.environ['LOCAL_RANK'])
451
+ # launched with submitit on a slurm cluster
452
+ elif 'SLURM_PROCID' in os.environ:
453
+ args.rank = int(os.environ['SLURM_PROCID'])
454
+ args.gpu = args.rank % torch.cuda.device_count()
455
+ elif torch.cuda.is_available():
456
+ print('Will run the code on one GPU.')
457
+ args.rank, args.gpu, args.world_size = 0, 0, 1
458
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
459
+ os.environ['MASTER_PORT'] = '29500'
460
+ else:
461
+ print('Does not support training without GPU.')
462
+ sys.exit(1)
463
+
464
+ args.distributed = True
465
+ dist.init_process_group(
466
+ backend="nccl",
467
+ init_method=args.dist_url,
468
+ world_size=args.world_size,
469
+ rank=args.rank,
470
+ )
471
+
472
+ torch.cuda.set_device(args.gpu)
473
+ print('| distributed init (rank {}): {}'.format(
474
+ args.rank, args.dist_url), flush=True)
475
+ dist.barrier()
476
+ setup_for_distributed(args.rank == 0)
477
+
478
+
479
+ def accuracy(output, target, topk=(1,)):
480
+ """Computes the accuracy over the k top predictions for the specified values of k"""
481
+ maxk = max(topk)
482
+ batch_size = target.size(0)
483
+ _, pred = output.topk(maxk, 1, True, True)
484
+ pred = pred.t()
485
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
486
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
487
+
488
+
489
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
490
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
491
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
492
+ def norm_cdf(x):
493
+ # Computes standard normal cumulative distribution function
494
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
495
+
496
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
497
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
498
+ "The distribution of values may be incorrect.",
499
+ stacklevel=2)
500
+
501
+ with torch.no_grad():
502
+ # Values are generated by using a truncated uniform distribution and
503
+ # then using the inverse CDF for the normal distribution.
504
+ # Get upper and lower cdf values
505
+ l = norm_cdf((a - mean) / std)
506
+ u = norm_cdf((b - mean) / std)
507
+
508
+ # Uniformly fill tensor with values from [l, u], then translate to
509
+ # [2l-1, 2u-1].
510
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
511
+
512
+ # Use inverse cdf transform for normal distribution to get truncated
513
+ # standard normal
514
+ tensor.erfinv_()
515
+
516
+ # Transform to proper mean, std
517
+ tensor.mul_(std * math.sqrt(2.))
518
+ tensor.add_(mean)
519
+
520
+ # Clamp to ensure it's in the proper range
521
+ tensor.clamp_(min=a, max=b)
522
+ return tensor
523
+
524
+
525
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
526
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
527
+
528
+
529
+ def get_params_groups(model):
530
+ regularized = []
531
+ not_regularized = []
532
+ for name, param in model.named_parameters():
533
+ if not param.requires_grad:
534
+ continue
535
+ # we do not regularize biases nor Norm parameters
536
+ if name.endswith(".bias") or len(param.shape) == 1:
537
+ not_regularized.append(param)
538
+ else:
539
+ regularized.append(param)
540
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
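A short sketch of `cosine_scheduler`: it precomputes one value per training iteration (linear warmup, then cosine decay), so the training loop just indexes it with the global step; the numbers below are illustrative:

```python
lr_schedule = cosine_scheduler(base_value=1e-3, final_value=1e-5,
                               epochs=10, niter_per_ep=100, warmup_epochs=1)
assert len(lr_schedule) == 10 * 100

# inside the training loop (sketch):
# it = epoch * niter_per_ep + step
# for g in optimizer.param_groups:
#     g["lr"] = lr_schedule[it]
```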
models/asit/vision_transformer.py ADDED
@@ -0,0 +1,316 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from models.asit.utils import trunc_normal_
7
+
8
+
9
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
10
+ if drop_prob == 0. or not training:
11
+ return x
12
+ keep_prob = 1 - drop_prob
13
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
14
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
15
+ random_tensor.floor_() # binarize
16
+ output = x.div(keep_prob) * random_tensor
17
+ return output
18
+
19
+
20
+ class DropPath(nn.Module):
21
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
22
+ """
23
+
24
+ def __init__(self, drop_prob=None):
25
+ super(DropPath, self).__init__()
26
+ self.drop_prob = drop_prob
27
+
28
+ def forward(self, x):
29
+ return drop_path(x, self.drop_prob, self.training)
30
+
31
+
32
+ class Mlp(nn.Module):
33
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
34
+ super().__init__()
35
+ out_features = out_features or in_features
36
+ hidden_features = hidden_features or in_features
37
+ self.fc1 = nn.Linear(in_features, hidden_features)
38
+ self.act = act_layer()
39
+ self.fc2 = nn.Linear(hidden_features, out_features)
40
+ self.drop = nn.Dropout(drop)
41
+
42
+ def forward(self, x):
43
+ x = self.fc1(x)
44
+ x = self.act(x)
45
+ x = self.drop(x)
46
+ x = self.fc2(x)
47
+ x = self.drop(x)
48
+ return x
49
+
50
+
51
+ class Attention(nn.Module):
52
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
53
+ super().__init__()
54
+ self.num_heads = num_heads
55
+ head_dim = dim // num_heads
56
+ self.scale = qk_scale or head_dim ** -0.5
57
+
58
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
59
+ self.attn_drop = nn.Dropout(attn_drop)
60
+ self.proj = nn.Linear(dim, dim)
61
+ self.proj_drop = nn.Dropout(proj_drop)
62
+
63
+ def forward(self, x):
64
+ B, N, C = x.shape
65
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
66
+ q, k, v = qkv[0], qkv[1], qkv[2]
67
+
68
+ attn = (q @ k.transpose(-2, -1)) * self.scale
69
+ attn = attn.softmax(dim=-1)
70
+ attn = self.attn_drop(attn)
71
+
72
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
73
+ x = self.proj(x)
74
+ x = self.proj_drop(x)
75
+ return x, attn
76
+
77
+
78
+ class Block(nn.Module):
79
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
80
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
81
+ super().__init__()
82
+ self.norm1 = norm_layer(dim)
83
+ self.attn = Attention(
84
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
85
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
86
+ self.norm2 = norm_layer(dim)
87
+ mlp_hidden_dim = int(dim * mlp_ratio)
88
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
89
+
90
+ def forward(self, x, return_attention=False):
91
+ y, attn = self.attn(self.norm1(x))
92
+ if return_attention:
93
+ return attn
94
+ x = x + self.drop_path(y)
95
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
96
+ return x
97
+
98
+
99
+ class PatchEmbed(nn.Module):
100
+ """ Image to Patch Embedding
101
+ """
102
+
103
+ def __init__(self, img_size=[1024, 128], patch_size=[16, 16], in_chans=3, embed_dim=768):
104
+ super().__init__()
105
+ num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
106
+ self.img_size = img_size
107
+ self.patch_size = patch_size
108
+ self.num_patches = num_patches
109
+
110
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
111
+
112
+ def forward(self, x):
113
+ B, C, H, W = x.shape
114
+ x = self.proj(x).flatten(2).transpose(1, 2)
115
+ return x
116
+
117
+
118
+ class VisionTransformer(nn.Module):
119
+ """ Vision Transformer """
120
+
121
+ def __init__(self, audio_size=[1024, 128], patch_size=[16, 16], in_chans=3, num_classes=0, embed_dim=768, depth=12,
122
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
123
+ drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
124
+ super().__init__()
125
+ self.num_features = self.embed_dim = embed_dim
126
+ self.audio_size = audio_size
127
+ self.patch_size = patch_size
128
+ self.patch_embed = PatchEmbed(
129
+ img_size=audio_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
130
+ num_patches = self.patch_embed.num_patches
131
+
132
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
133
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
134
+ self.pos_drop = nn.Dropout(p=drop_rate)
135
+
136
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
137
+ self.blocks = nn.ModuleList([
138
+ Block(
139
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
140
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
141
+ for i in range(depth)])
142
+ self.norm = norm_layer(embed_dim)
143
+
144
+ # Classifier head
145
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
146
+
147
+ trunc_normal_(self.pos_embed, std=.02)
148
+ trunc_normal_(self.cls_token, std=.02)
149
+ self.apply(self._init_weights)
150
+
151
+ def _init_weights(self, m):
152
+ if isinstance(m, nn.Linear):
153
+ trunc_normal_(m.weight, std=.02)
154
+ if isinstance(m, nn.Linear) and m.bias is not None:
155
+ nn.init.constant_(m.bias, 0)
156
+ elif isinstance(m, nn.LayerNorm):
157
+ nn.init.constant_(m.bias, 0)
158
+ nn.init.constant_(m.weight, 1.0)
159
+
160
+ def interpolate_pos_encoding(self, x, w, h):
161
+ npatch = (w / 16) * (h / 16)
162
+ N = self.pos_embed.shape[1] - 1
163
+ if npatch == N:
164
+ return self.pos_embed
165
+
166
+ class_pos_embed = self.pos_embed[:, 0]
167
+ patch_pos_embed = self.pos_embed[:, 1:]
168
+
169
+ sz1 = w // self.patch_size[0]
170
+ sz2 = h // self.patch_size[1]
171
+
172
+ prev_sz1 = self.audio_size[0] // self.patch_size[0]
173
+ prev_sz2 = self.audio_size[1] // self.patch_size[1]
174
+ patch_pos_embed = torch.nn.functional.interpolate(
175
+ patch_pos_embed.transpose(1, 2).reshape(1, self.embed_dim, prev_sz1, prev_sz2), size=(sz1, sz2),
176
+ mode='bicubic', align_corners=False)
177
+
178
+ patch_pos_embed = patch_pos_embed.reshape(1, self.embed_dim, sz1 * sz2).transpose(1, 2)
179
+
180
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
181
+
182
+ def prepare_tokens(self, x):
183
+ B, nc, w, h = x.shape
184
+ x = self.patch_embed(x) # patch linear embedding
185
+
186
+ # add the [CLS] token to the embed patch tokens
187
+ cls_tokens = self.cls_token.expand(B, -1, -1)
188
+ x = torch.cat((cls_tokens, x), dim=1)
189
+
190
+ # add positional encoding to each token
191
+ x = x + self.interpolate_pos_encoding(x, w, h)
192
+ # x = x + self.pos_embed
193
+ return self.pos_drop(x)
194
+
195
+ def forward(self, x, classify=False):
196
+ x = x.permute(0, 1, 3, 2)
197
+ x = self.prepare_tokens(x)
198
+ for blk in self.blocks:
199
+ x = blk(x)
200
+ x = self.norm(x)
201
+ if classify:
202
+ return self.head(x[:, 0])
203
+ return x
204
+
205
+ def get_last_selfattention(self, x):
206
+ x = self.prepare_tokens(x)
207
+ for i, blk in enumerate(self.blocks):
208
+ if i < len(self.blocks) - 1:
209
+ x = blk(x)
210
+ else:
211
+ # return attention of the last block
212
+ return blk(x, return_attention=True)
213
+
214
+ def get_intermediate_layers(self, x, n=1):
215
+ x = self.prepare_tokens(x)
216
+ # we return the output tokens from the `n` last blocks
217
+ output = []
218
+ for i, blk in enumerate(self.blocks):
219
+ x = blk(x)
220
+ if len(self.blocks) - i <= n:
221
+ output.append(self.norm(x))
222
+ return output
223
+
224
+
225
+ def vit_tiny(patch_size=16, **kwargs):
226
+ model = VisionTransformer(
227
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
228
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
229
+ return model
230
+
231
+
232
+ def vit_small(patch_size=[16, 16], audio_size=[1024, 128], stride=[16, 16], **kwargs):
233
+ model = VisionTransformer(
234
+ patch_size=patch_size, audio_size=audio_size, stride=stride, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
235
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
236
+ return model
237
+
238
+
239
+ def vit_base(patch_size=[16, 16], audio_size=[1024, 128], stride=[16, 16], **kwargs):
240
+ model = VisionTransformer(
241
+ patch_size=patch_size, audio_size=audio_size, stride=stride, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
242
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
243
+ return model
244
+
245
+
246
+ class CLSHead(nn.Module):
247
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
248
+ bottleneck_dim=256):
249
+ super().__init__()
250
+ nlayers = max(nlayers, 1)
251
+ if nlayers == 1:
252
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
253
+ else:
254
+ layers = [nn.Linear(in_dim, hidden_dim)]
255
+ if use_bn:
256
+ layers.append(nn.BatchNorm1d(hidden_dim))
257
+ layers.append(nn.GELU())
258
+ for _ in range(nlayers - 2):
259
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
260
+ if use_bn:
261
+ layers.append(nn.BatchNorm1d(hidden_dim))
262
+ layers.append(nn.GELU())
263
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
264
+ self.mlp = nn.Sequential(*layers)
265
+ self.apply(self._init_weights)
266
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
268
+ self.last_layer.weight_g.data.fill_(1)
269
+
270
+ def _init_weights(self, m):
271
+ if isinstance(m, nn.Linear):
272
+ trunc_normal_(m.weight, std=.02)
273
+ if isinstance(m, nn.Linear) and m.bias is not None:
274
+ nn.init.constant_(m.bias, 0)
275
+
276
+ def forward(self, x):
277
+ x = self.mlp(x)
278
+ x = nn.functional.normalize(x, dim=-1, p=2)
279
+ return self.last_layer(x)
280
+
281
+
282
+ class RECHead(nn.Module):
283
+ def __init__(self, in_dim, audio_size, in_chans=3, patch_size=16):
284
+ super().__init__()
285
+
286
+ self.audio_size = audio_size
287
+ self.patch_size = patch_size
288
+
289
+ layers = [nn.Linear(in_dim, in_dim)]
290
+ layers.append(nn.GELU())
291
+ layers.append(nn.Linear(in_dim, in_dim))
292
+ layers.append(nn.GELU())
293
+ layers.append(nn.Linear(in_dim, in_dim))
294
+ layers.append(nn.GELU())
295
+
296
+ self.mlp = nn.Sequential(*layers)
297
+ self.apply(self._init_weights)
298
+
299
+ self.convTrans = nn.ConvTranspose2d(in_dim, in_chans, kernel_size=(patch_size, patch_size),
300
+ stride=(patch_size, patch_size))
301
+
302
+ def _init_weights(self, m):
303
+ if isinstance(m, nn.Linear):
304
+ trunc_normal_(m.weight, std=.02)
305
+ if isinstance(m, nn.Linear) and m.bias is not None:
306
+ nn.init.constant_(m.bias, 0)
307
+
308
+ def forward(self, x):
309
+ x = self.mlp(x)
310
+
311
+ x_rec = x.transpose(1, 2)
312
+ out_sz = (self.audio_size[0] // self.patch_size, self.audio_size[1] // self.patch_size)
314
+ x_rec = self.convTrans(x_rec.unflatten(2, out_sz))
315
+
316
+ return x_rec
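A shape sketch for the backbone defined above, using the default `audio_size=[1024, 128]`: the forward pass permutes the input to (batch, channel, frames, mels), cuts it into 16x16 patches, and returns the CLS token plus the patch tokens:

```python
import torch

model = vit_base(patch_size=[16, 16], audio_size=[1024, 128],
                 in_chans=1, num_classes=0)
spec = torch.randn(2, 1, 128, 1024)  # (batch, channel, mels, frames)
tokens = model(spec)
print(tokens.shape)  # torch.Size([2, 513, 768]) -> CLS + 64*8 patch tokens
```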
models/atstframe/ATSTF_wrapper.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
3
+
4
+ from models.atstframe.audio_transformer import FrameASTModel
5
+ from models.transformer_wrapper import BaseModelWrapper
6
+
7
+
8
+ class ATSTWrapper(BaseModelWrapper):
9
+ def __init__(self, atst_dropout=0.0) -> None:
10
+ super().__init__()
11
+ self.atst_mel = ATSTMel()
12
+ self.atst = FrameASTModel(atst_dropout=atst_dropout)
13
+ self.fake_length = torch.tensor([1001])
14
+ self.cls_embed = None
15
+
16
+ def mel_forward(self, x):
17
+ return self.atst_mel(x)
18
+
19
+ def forward(self, spec):
20
+ atst_x = self.atst.get_intermediate_layers(
21
+ spec,
22
+ self.fake_length.to(spec).repeat(len(spec)),
23
+ 1,
24
+ scene=False
25
+ )
26
+ return atst_x
27
+
28
+ def separate_params(self):
29
+ pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
30
+ for k, p in self.named_parameters():
31
+ if k in ['atst.mask_embed', 'atst.pos_embed', 'atst.patch_embed.patch_embed.weight',
32
+ 'atst.patch_embed.patch_embed.bias'] or "blocks.0." in k:
33
+ pt_params[0].append(p)
34
+ elif "blocks.1." in k:
35
+ pt_params[1].append(p)
36
+ elif "blocks.2." in k:
37
+ pt_params[2].append(p)
38
+ elif "blocks.3." in k:
39
+ pt_params[3].append(p)
40
+ elif "blocks.4." in k:
41
+ pt_params[4].append(p)
42
+ elif "blocks.5." in k:
43
+ pt_params[5].append(p)
44
+ elif "blocks.6." in k:
45
+ pt_params[6].append(p)
46
+ elif "blocks.7." in k:
47
+ pt_params[7].append(p)
48
+ elif "blocks.8" in k:
49
+ pt_params[8].append(p)
50
+ elif "blocks.9." in k:
51
+ pt_params[9].append(p)
52
+ elif "blocks.10." in k:
53
+ pt_params[10].append(p)
54
+ elif "blocks.11." in k or ".norm_frame." in k:
55
+ pt_params[11].append(p)
56
+ else:
57
+ raise ValueError(f"Check separate params for ATST! Unknown key: {k}")
58
+ return list(reversed(pt_params))
59
+
60
+
61
+ class ATSTMel(torch.nn.Module):
62
+ def __init__(self) -> None:
63
+ super().__init__()
64
+ self.mel_transform = MelSpectrogram(
65
+ 16000,
66
+ f_min=60,
67
+ f_max=7800,
68
+ hop_length=160,
69
+ win_length=1024,
70
+ n_fft=1024,
71
+ n_mels=64
72
+ )
73
+ self.amp_to_db = AmplitudeToDB(stype="power", top_db=80)
74
+ self.scaler = MinMax(min=-79.6482, max=50.6842)
75
+
76
+ def amp2db(self, spec):
77
+ return self.amp_to_db(spec).clamp(min=-50, max=80)
78
+
79
+ def forward(self, audio):
80
+ with torch.autocast(device_type="cuda", enabled=False):
81
+ spec = self.mel_transform(audio)
82
+ spec = self.scaler(self.amp2db(spec))
83
+ spec = spec.unsqueeze(1)
84
+ return spec
85
+
86
+
87
+ class CustomAudioTransform:
88
+ def __repr__(self):
89
+ return self.__class__.__name__ + '()'
90
+
91
+
92
+ class MinMax(CustomAudioTransform):
93
+ def __init__(self, min, max):
94
+ self.min = min
95
+ self.max = max
96
+
97
+ def __call__(self, input):
98
+ if self.min is None:
99
+ min_ = torch.min(input)
100
+ max_ = torch.max(input)
101
+ else:
102
+ min_ = self.min
103
+ max_ = self.max
104
+ input = (input - min_) / (max_ - min_) * 2. - 1.
105
+ return input
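A shape sketch for `ATSTMel`: 10 s of 16 kHz audio becomes a min-max-scaled log-mel spectrogram with 64 mel bins and 1001 frames (160000 samples at a 160-sample hop, with centered frames):

```python
import torch

mel = ATSTMel()
audio = torch.randn(2, 160000)  # (batch, samples)
spec = mel(audio)
print(spec.shape)  # torch.Size([2, 1, 64, 1001])
```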
models/atstframe/audio_transformer.py ADDED
@@ -0,0 +1,253 @@
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+ import torch
5
+ from torch import nn
6
+
7
+ from .transformer import Block
8
+
9
+
10
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
11
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
12
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
13
+ def norm_cdf(x):
14
+ # Computes standard normal cumulative distribution function
15
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
16
+
17
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
18
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
19
+ "The distribution of values may be incorrect.",
20
+ stacklevel=2)
21
+
22
+ with torch.no_grad():
23
+ # Values are generated by using a truncated uniform distribution and
24
+ # then using the inverse CDF for the normal distribution.
25
+ # Get upper and lower cdf values
26
+ l = norm_cdf((a - mean) / std)
27
+ u = norm_cdf((b - mean) / std)
28
+
29
+ # Uniformly fill tensor with values from [l, u], then translate to
30
+ # [2l-1, 2u-1].
31
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
32
+
33
+ # Use inverse cdf transform for normal distribution to get truncated
34
+ # standard normal
35
+ tensor.erfinv_()
36
+
37
+ # Transform to proper mean, std
38
+ tensor.mul_(std * math.sqrt(2.))
39
+ tensor.add_(mean)
40
+
41
+ # Clamp to ensure it's in the proper range
42
+ tensor.clamp_(min=a, max=b)
43
+ return tensor
44
+
45
+
46
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
47
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
48
+
49
+
50
+ def get_num_patches(height=64, width=1001, patch_height=16, patch_width=16):
51
+ return (height // patch_height) * (width // patch_width)
52
+
53
+
54
+ from einops.layers.torch import Rearrange
55
+
56
+
57
+ class PatchEmbed_v2(nn.Module):
58
+ def __init__(self, patch_height=64, patch_width=4, embed_dim=768, input_dim=1):
59
+ super().__init__()
60
+ self.patch_height = patch_height
61
+ self.patch_width = patch_width
62
+ self.patch_maker = Rearrange('b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=patch_height, p2=patch_width)
63
+ self.patch_embed = nn.Linear(patch_height * patch_width * input_dim, embed_dim)
64
+
65
+ def forward(self, melspec, length=None):
66
+ height = melspec.shape[2] - melspec.shape[2] % self.patch_height
67
+ width = melspec.shape[3] - melspec.shape[3] % self.patch_width
68
+ patch = self.patch_maker(melspec[:, :, :height, :width])
69
+ patch_embed = self.patch_embed(patch)
70
+
71
+ if length is not None:
72
+ patch_length = (torch.div(height, self.patch_height, rounding_mode='trunc')) * torch.div(
73
+ (length - length % self.patch_width), self.patch_width, rounding_mode='trunc')
74
+ else:
75
+ patch_length = None
76
+
77
+ return patch, patch_embed, patch_length
78
+
79
+
80
+ class FrameAST(nn.Module):
81
+ """ Vision Transformer """
82
+
83
+ def __init__(self, nprompt=0, spec_h=64, spec_w=1001, patch_w=16, patch_h=16, pos_type="cut", in_chans=1,
84
+ num_classes=0, embed_dim=768, depth=12,
85
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.,
86
+ drop_path_rate=0.0, norm_layer=nn.LayerNorm, **kwargs):
87
+ super().__init__()
88
+ self.num_features = self.embed_dim = embed_dim
89
+ self.spec_w = spec_w
90
+ self.spec_h = spec_h
91
+ self.embed_dim = embed_dim
92
+ self.patch_w = patch_w
93
+ self.patch_h = patch_h
94
+
95
+ self.pos_type = pos_type
96
+
97
+ self.patch_embed = PatchEmbed_v2(patch_h, patch_w, embed_dim)
98
+ self.mask_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
99
+
100
+ # hack
101
+ self.nprompt = nprompt
102
+ if self.nprompt > 0:
103
+ self.prompt_embed = nn.Parameter(torch.zeros(1, self.nprompt, self.embed_dim))
104
+ trunc_normal_(self.prompt_embed, std=.02)
105
+
106
+ num_patches = get_num_patches(spec_h, spec_w, patch_h, patch_w)
107
+ self.num_patches = num_patches
108
+
109
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
110
+ self.pos_drop = nn.Dropout(p=drop_rate)
111
+
112
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
113
+ self.blocks = nn.ModuleList([
114
+ Block(
115
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
116
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
117
+ for i in range(depth)])
118
+ self.norm_frame = norm_layer(embed_dim)
119
+
120
+ trunc_normal_(self.pos_embed, std=.02)
121
+ trunc_normal_(self.mask_embed, std=.02)
122
+ self.apply(self._init_weights)
123
+
124
+ def _init_weights(self, m):
125
+ if isinstance(m, nn.Linear):
126
+ trunc_normal_(m.weight, std=.02)
127
+ if isinstance(m, nn.Linear) and m.bias is not None:
128
+ nn.init.constant_(m.bias, 0)
129
+ elif isinstance(m, nn.LayerNorm):
130
+ nn.init.constant_(m.bias, 0)
131
+ nn.init.constant_(m.weight, 1.0)
132
+
133
+ def prepare_tokens(self, x, mask_index, length, mask=True):
134
+ B, nc, h, w = x.shape
135
+ mel_patches, x, patch_length = self.patch_embed(x, length) # patch linear embedding
136
+ B, T, C = x.shape
137
+
138
+ if (mask_index is not None) and mask:
139
+ mask_index_expand = mask_index.unsqueeze(2).expand(B, T, self.embed_dim).float()
140
+ x = (1 - mask_index_expand) * x + mask_index_expand * self.mask_embed.expand(B, T, C)
141
+
142
+ # add positional encoding to each token
143
+ if self.pos_type == "cut":
144
+ pos = self.pos_embed[:, 1:T + 1, :].expand(B, -1, -1)
145
+ x = x + pos
146
+ else:
147
+ pos = self.interpolate_pos_encoding(x, h, w)
148
+ x = x + pos[:, 1:]
149
+
150
+ # pos = self.pos_embed[:,1:T+1,:].expand(B,-1,-1)
151
+ # x = x + pos
152
+
153
+ return self.pos_drop(x), pos, mel_patches, h, w, patch_length
154
+
155
+ def forward(self, x, mask_index=None, mask_input=True, length=None):
156
+ x, pos, mel_patches, h, w, patch_length = self.prepare_tokens(x, mask_index, length, mask_input)
157
+
158
+ length_mask = torch.arange(mel_patches.shape[1]).to(x.device) < patch_length.unsqueeze(1)
159
+ length_mask = length_mask.to(x.device)
160
+ mask_index = mask_index & length_mask
161
+
162
+ if self.nprompt > 0:
163
+ x = torch.cat([self.prompt_embed.expand(x.shape[0], -1, -1), x], dim=1)
164
+
165
+ for i, blk in enumerate(self.blocks):
166
+ x = blk(x, patch_length + self.nprompt)
167
+
168
+ frame_repr = self.norm_frame(x)
169
+
170
+ return frame_repr[:, self.nprompt:][mask_index]
171
+
172
+ def interpolate_pos_encoding(self, x, h, w):
173
+ npatch = x.shape[1] - 1
174
+ N = self.pos_embed.shape[1] - 1
175
+ if npatch == N and w == self.spec_w and h == self.spec_h:
176
+ return self.pos_embed
177
+ class_pos_embed = self.pos_embed[:, 0]
178
+ patch_pos_embed = self.pos_embed[:, 1:]
179
+ dim = x.shape[-1]
180
+ w0 = w // self.patch_embed.patch_width
181
+ h0 = h // self.patch_embed.patch_height
182
+ # we add a small number to avoid floating point error in the interpolation
183
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
184
+ w0, h0 = w0 + 0.1, h0 + 0.1
185
+ patch_pos_embed = nn.functional.interpolate(
186
+ patch_pos_embed.reshape(1, self.spec_h // self.patch_h, self.spec_w // self.patch_w, dim).permute(0, 3, 1,
187
+ 2),
188
+ scale_factor=(h0 / (self.spec_h // self.patch_h), w0 / (self.spec_w // self.patch_w)),
189
+ mode='bicubic',
190
+ )
191
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
192
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
193
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
194
+
195
+ def get_last_selfattention(self, x):
196
+ x, _, _, _, _, _ = self.prepare_tokens(x, mask_index=None, length=None, mask=False)
197
+ atts = []
198
+ for i, blk in enumerate(self.blocks):
199
+ if i < len(self.blocks) - 1:
200
+ x, att = blk(x, return_attention=True)
201
+ atts.append(att)
202
+ else:
203
+ x, att = blk(x, return_attention=True)
204
+ atts.append(att)
205
+ return atts
206
+ # return attention of the last block
207
+
208
+ def get_intermediate_layers(self, x, length, n=1, scene=True, other_emb=None):
209
+ x, _, _, _, _, patch_length = self.prepare_tokens(x, mask_index=None, length=length, mask=False)
210
+ # we return the output tokens from the `n` last blocks
211
+ if other_emb is not None:
212
+ x = torch.cat([other_emb, x], dim=1)
213
+ output = []
214
+ if self.nprompt > 0:
215
+ x = torch.cat([self.prompt_embed.expand(x.shape[0], -1, -1), x], dim=1)
216
+ for i, blk in enumerate(self.blocks):
217
+ x = blk(x, patch_length + self.nprompt)
218
+ if len(self.blocks) - i <= n:
219
+ norm_x = self.norm_frame(x)
220
+ if scene:
221
+ length_mask = torch.arange(x.shape[1] - self.nprompt).to(x.device) < patch_length.unsqueeze(1)
222
+ avg = torch.sum(norm_x[:, self.nprompt:] * length_mask.unsqueeze(-1), dim=1) / (
223
+ patch_length.unsqueeze(-1) + 1e-6)
224
+ negative = (~length_mask) * -1e10
225
+ # max = torch.max(norm_x[:,self.nprompt:]+negative.unsqueeze(-1),1).values
226
+ output.append(avg)
227
+ if self.nprompt > 0:
228
+ output.append(torch.mean(norm_x[:, :self.nprompt], dim=1))
229
+ else:
230
+ output.append(norm_x[:, self.nprompt:])
231
+
232
+ return torch.cat(output, dim=-1)
233
+
234
+
235
+ def get_cls_avg(output_i, cur_len, use_cls):
236
+ length_mask = torch.arange(output_i[0].shape[1]).to(output_i[0].device) < cur_len.unsqueeze(1)
237
+ cls = [torch.zeros_like(x[:, 0]) for x in output_i]
238
+ avg = [torch.sum(x * length_mask.unsqueeze(-1), dim=1) / (cur_len.unsqueeze(1) + 1e-6) for x in output_i]
239
+ return cls, avg
240
+
241
+
242
+ def FrameASTModel(patch_h=64, patch_w=4, atst_dropout=0.1, **kwargs):
243
+ return FrameAST(
244
+ patch_h=patch_h,
245
+ patch_w=patch_w,
246
+ embed_dim=768,
247
+ depth=12,
248
+ num_heads=12,
249
+ qkv_bias=False,
250
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
251
+ drop_path_rate=atst_dropout,
252
+ drop_rate=atst_dropout,
253
+ **kwargs)
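
For orientation, here is a hedged usage sketch for the `FrameASTModel` factory above. The input layout `(batch, 1, 64 mel bins, frames)` and the meaning of `length` (valid frames per clip before padding) are assumptions inferred from the code, not a documented API:

```python
import torch

# build the ViT-Base-sized frame encoder defined above
model = FrameASTModel(patch_h=64, patch_w=4, atst_dropout=0.0)
model.eval()

spec = torch.randn(2, 1, 64, 1000)   # assumed layout: (batch, 1, mel bins, frames)
length = torch.tensor([1000, 640])   # assumed: number of valid frames per clip

with torch.no_grad():
    # frame-level tokens from the last block; scene=True would return clip-level averages instead
    frames = model.get_intermediate_layers(spec, length, n=1, scene=False)
print(frames.shape)                   # expected: (2, num_patches, 768)
```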
models/atstframe/transformer.py ADDED
@@ -0,0 +1,112 @@
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
6
+ if drop_prob == 0. or not training:
7
+ return x
8
+ keep_prob = 1 - drop_prob
9
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
10
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
11
+ random_tensor.floor_() # binarize
12
+ output = x.div(keep_prob) * random_tensor
13
+ return output
14
+
15
+
16
+ class DropPath(nn.Module):
17
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
18
+ """
19
+
20
+ def __init__(self, drop_prob=None):
21
+ super(DropPath, self).__init__()
22
+ self.drop_prob = drop_prob
23
+
24
+ def forward(self, x):
25
+ return drop_path(x, self.drop_prob, self.training)
26
+
27
+
28
+ class Mlp(nn.Module):
29
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
30
+ super().__init__()
31
+ out_features = out_features or in_features
32
+ hidden_features = hidden_features or in_features
33
+ self.fc1 = nn.Linear(in_features, hidden_features)
34
+ self.act = act_layer()
35
+ self.fc2 = nn.Linear(hidden_features, out_features)
36
+ self.drop = nn.Dropout(drop)
37
+
38
+ def forward(self, x):
39
+ x = self.fc1(x)
40
+ x = self.act(x)
41
+ x = self.drop(x)
42
+ x = self.fc2(x)
43
+ x = self.drop(x)
44
+ return x
45
+
46
+
47
+ class Attention(nn.Module):
48
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
49
+ super().__init__()
50
+ self.num_heads = num_heads
51
+ head_dim = dim // num_heads
52
+ self.scale = qk_scale or head_dim ** -0.5
53
+
54
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
55
+ self.attn_drop = nn.Dropout(attn_drop)
56
+ self.proj = nn.Linear(dim, dim)
57
+ self.proj_drop = nn.Dropout(proj_drop)
58
+
59
+ def forward(self, x, mask):
60
+ B, N, C = x.shape
61
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
62
+ q, k, v = qkv[0], qkv[1], qkv[2]
63
+ attn = (q @ k.transpose(-2, -1)) * self.scale
64
+ if mask is not None:
65
+ attn += mask
66
+
67
+ attn = attn.softmax(dim=-1)
68
+ attn = self.attn_drop(attn)
69
+
70
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
71
+ x = self.proj(x)
72
+ x = self.proj_drop(x)
73
+ return x, attn
74
+
75
+
76
+ class Block(nn.Module):
77
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
78
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
79
+ super().__init__()
80
+ self.norm1 = norm_layer(dim)
81
+ self.attn = Attention(
82
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
83
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
84
+ self.norm2 = norm_layer(dim)
85
+ mlp_hidden_dim = int(dim * mlp_ratio)
86
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
87
+
88
+ def forward(self, x, length=None, return_attention=False):
89
+
90
+ # if length is not None:
91
+ # print(length)
92
+ # mask_att = get_attention_mask(x,length)
93
+ # else:
94
+ mask_att = None
95
+
96
+ y, attn = self.attn(self.norm1(x), mask_att)
97
+ x = x + self.drop_path(y)
98
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
99
+ if return_attention:
100
+ return x, attn
101
+ else:
102
+ return x
103
+
104
+
105
+ def get_attention_mask(x, length):
106
+ batch_size, max_len, _ = x.shape
107
+ # create mask for padded elements and zero-out them
108
+ mask = torch.arange(max_len, device=length.device).expand(batch_size, max_len) >= length[:, None]
109
+ # extend the mask to attention shape and set weight
110
+ mask = -10000.0 * mask[:, None, None, :]
111
+ mask = mask.expand(batch_size, 1, max_len, max_len).to(x.device)
112
+ return mask
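
A quick sanity check of the masking convention in `get_attention_mask`: padded key positions receive a large negative additive bias, which drives their attention weight to effectively zero after the softmax. Illustrative snippet only:

```python
import torch

x = torch.randn(2, 5, 8)              # (batch, max_len, dim)
length = torch.tensor([5, 3])         # the second clip has two padded frames
mask = get_attention_mask(x, length)
print(mask.shape)                     # torch.Size([2, 1, 5, 5])
print(mask[1, 0, 0])                  # [0, 0, 0, -10000, -10000]
```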
models/beats/BEATs.py ADDED
@@ -0,0 +1,183 @@
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from models.beats.backbone import (
17
+ TransformerEncoder,
18
+ )
19
+
20
+ import logging
21
+ from typing import Optional
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class BEATsConfig:
27
+ def __init__(self, cfg=None):
28
+ self.input_patch_size: int = 16 # patch size of patch embedding
29
+ self.embed_dim: int = 512 # patch embedding dimension
30
+ self.conv_bias: bool = False # include bias in conv encoder
31
+
32
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
33
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
34
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
35
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
36
+ self.activation_fn: str = "gelu" # activation function to use
37
+
38
+ self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
39
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
40
+ self.deep_norm: bool = True # apply deep_norm first in the transformer
41
+
42
+ # dropouts
43
+ self.dropout: float = 0.1 # dropout probability for the transformer
44
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
45
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
46
+ self.encoder_layerdrop: float = 0.05 # probability of dropping a transformer layer
47
+ self.dropout_input: float = 0.1 # dropout to apply to the input (after feat extr)
48
+
49
+ # positional embeddings
50
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
51
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
52
+
53
+ # relative position embedding
54
+ self.relative_position_embedding: bool = True # apply relative position embedding
55
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
56
+ self.max_distance: int = 800 # maximum distance for relative position embedding
57
+ self.gru_rel_pos: bool = True # apply gated relative position embedding
58
+
59
+ # label predictor
60
+ self.finetuned_model: bool = False # whether the model is a fine-tuned model.
61
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor
62
+ self.predictor_class: int = 527 # target class number for the predictor
63
+
64
+ if cfg is not None:
65
+ self.update(cfg)
66
+
67
+ def update(self, cfg: dict):
68
+ self.__dict__.update(cfg)
69
+
70
+
71
+ class BEATs(nn.Module):
72
+ def __init__(
73
+ self,
74
+ cfg: BEATsConfig,
75
+ ) -> None:
76
+ super().__init__()
77
+ logger.info(f"BEATs Config: {cfg.__dict__}")
78
+
79
+ self.cfg = cfg
80
+
81
+ self.embed = cfg.embed_dim
82
+ self.post_extract_proj = (
83
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
84
+ if self.embed != cfg.encoder_embed_dim
85
+ else None
86
+ )
87
+
88
+ self.input_patch_size = cfg.input_patch_size
89
+ self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
90
+ bias=cfg.conv_bias)
91
+
92
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
93
+
94
+ assert not cfg.deep_norm or not cfg.layer_norm_first
95
+ self.encoder = TransformerEncoder(cfg)
96
+ self.layer_norm = LayerNorm(self.embed)
97
+
98
+ if cfg.finetuned_model:
99
+ self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
100
+ self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
101
+ else:
102
+ self.predictor = None
103
+
104
+ def forward_padding_mask(
105
+ self,
106
+ features: torch.Tensor,
107
+ padding_mask: torch.Tensor,
108
+ ) -> torch.Tensor:
109
+ extra = padding_mask.size(1) % features.size(1)
110
+ if extra > 0:
111
+ padding_mask = padding_mask[:, :-extra]
112
+ padding_mask = padding_mask.view(
113
+ padding_mask.size(0), features.size(1), -1
114
+ )
115
+ padding_mask = padding_mask.all(-1)
116
+ return padding_mask
117
+
118
+ def preprocess(
119
+ self,
120
+ source: torch.Tensor,
121
+ fbank_mean: float = 15.41663,
122
+ fbank_std: float = 6.55582,
123
+ ) -> torch.Tensor:
124
+ fbanks = []
125
+ for waveform in source:
126
+ waveform = waveform.unsqueeze(0) * 2 ** 15
127
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
128
+ fbanks.append(fbank)
129
+ fbank = torch.stack(fbanks, dim=0)
130
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
131
+ return fbank
132
+
133
+ def extract_features(
134
+ self,
135
+ source: torch.Tensor,
136
+ padding_mask: Optional[torch.Tensor] = None,
137
+ fbank_mean: float = 15.41663,
138
+ fbank_std: float = 6.55582,
139
+ do_preprocess: bool = True,
140
+ ):
141
+ if do_preprocess:
142
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
143
+
144
+ if padding_mask is not None:
145
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
146
+
147
+ fbank = fbank.unsqueeze(1)
148
+ else:
149
+ fbank = source
150
+ features = self.patch_embedding(fbank)
151
+ features = features.reshape(features.shape[0], features.shape[1], -1)
152
+ features = features.transpose(1, 2)
153
+ features = self.layer_norm(features)
154
+
155
+ if padding_mask is not None:
156
+ padding_mask = self.forward_padding_mask(features, padding_mask)
157
+
158
+ if self.post_extract_proj is not None:
159
+ features = self.post_extract_proj(features)
160
+
161
+ x = self.dropout_input(features)
162
+
163
+ x, layer_results = self.encoder(
164
+ x,
165
+ padding_mask=padding_mask,
166
+ )
167
+
168
+ if self.predictor is not None:
169
+ x = self.predictor_dropout(x)
170
+ logits = self.predictor(x)
171
+
172
+ if padding_mask is not None and padding_mask.any():
173
+ logits[padding_mask] = 0
174
+ logits = logits.sum(dim=1)
175
+ logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
176
+ else:
177
+ logits = logits.mean(dim=1)
178
+
179
+ lprobs = torch.sigmoid(logits)
180
+
181
+ return lprobs, padding_mask
182
+ else:
183
+ return x, padding_mask
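
As a usage sketch: the BEATs checkpoints released at the GitHub link in the header store a dict with a `cfg` entry and a `model` state dict, so loading and feature extraction might look as follows (the filename is illustrative):

```python
import torch

ckpt = torch.load("BEATs_iter3_plus_AS2M.pt", map_location="cpu")  # example filename
cfg = BEATsConfig(ckpt["cfg"])
model = BEATs(cfg)
model.load_state_dict(ckpt["model"])
model.eval()

waveform = torch.randn(1, 16000 * 10)   # 10 s of 16 kHz audio
with torch.no_grad():
    # for a pretrained (not fine-tuned) config this returns encoder features
    features, _ = model.extract_features(waveform)
print(features.shape)                    # (1, num_patches, 768)
```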
models/beats/BEATs_wrapper.py ADDED
@@ -0,0 +1,56 @@
+ import torch
2
+
3
+ from models.beats.BEATs import BEATsConfig, BEATs
4
+ from models.transformer_wrapper import BaseModelWrapper
5
+
6
+
7
+ class BEATsWrapper(BaseModelWrapper):
8
+ def __init__(self):
9
+ super().__init__()
10
+ cfg = BEATsConfig()
11
+ self.beats = BEATs(cfg)
12
+
13
+ def mel_forward(self, x):
14
+ with torch.autocast(device_type="cuda", enabled=False):
15
+ mel = self.beats.preprocess(x)
16
+ mel = mel.unsqueeze(1).transpose(2, 3)
17
+ return mel
18
+
19
+ def forward(self, x):
20
+ x = x.transpose(2, 3)
21
+ features = self.beats.extract_features(x, do_preprocess=False)[0]
22
+ return features
23
+
24
+ def separate_params(self):
25
+ pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
26
+ for k, p in self.named_parameters():
27
+ if ".layers.0." in k:
28
+ pt_params[0].append(p)
29
+ elif ".layers.1." in k:
30
+ pt_params[1].append(p)
31
+ elif ".layers.2." in k:
32
+ pt_params[2].append(p)
33
+ elif ".layers.3." in k:
34
+ pt_params[3].append(p)
35
+ elif ".layers.4." in k:
36
+ pt_params[4].append(p)
37
+ elif ".layers.5." in k:
38
+ pt_params[5].append(p)
39
+ elif ".layers.6." in k:
40
+ pt_params[6].append(p)
41
+ elif ".layers.7." in k:
42
+ pt_params[7].append(p)
43
+ elif ".layers.8." in k:
44
+ pt_params[8].append(p)
45
+ elif ".layers.9." in k:
46
+ pt_params[9].append(p)
47
+ elif ".layers.10." in k:
48
+ pt_params[10].append(p)
49
+ elif ".layers.11." in k:
50
+ pt_params[11].append(p)
51
+ elif (".post_extract_proj." in k or ".patch_embedding." in k or '.pos_conv.' in k
52
+ or 'beats.layer_norm.' in k or "beats.encoder.layer_norm." in k):
53
+ pt_params[0].append(p)
54
+ else:
55
+ raise ValueError(f"Check separate params for BEATs! Unknown key: {k}")
56
+ return list(reversed(pt_params))
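
`separate_params` buckets the parameters per transformer block and reverses the list, so index 0 holds the topmost layer. That ordering is what layer-wise learning-rate decay expects; a hedged sketch with illustrative hyperparameters (not necessarily how the training scripts in this repo wire it up):

```python
import torch

wrapper = BEATsWrapper()
groups = wrapper.separate_params()   # index 0: layer 11, ..., index 11: layer 0 + stem

base_lr, decay = 1e-4, 0.75          # illustrative values
optimizer = torch.optim.AdamW(
    [{"params": p, "lr": base_lr * decay ** i} for i, p in enumerate(groups)]
)
```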
models/beats/Tokenizers.py ADDED
@@ -0,0 +1,172 @@
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from models.beats.backbone import (
17
+ TransformerEncoder,
18
+ )
19
+ from models.beats.quantizer import (
20
+ NormEMAVectorQuantizer,
21
+ )
22
+
23
+ import logging
24
+ from typing import Optional
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class TokenizersConfig:
30
+ def __init__(self, cfg=None):
31
+ self.input_patch_size: int = -1 # patch size of patch embedding
32
+ self.embed_dim: int = 512 # patch embedding dimension
33
+ self.conv_bias: bool = False # include bias in conv encoder
34
+
35
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
36
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
37
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
38
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
39
+ self.activation_fn: str = "gelu" # activation function to use
40
+
41
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
42
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
43
+
44
+ # dropouts
45
+ self.dropout: float = 0.1 # dropout probability for the transformer
46
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
47
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
48
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
49
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
50
+
51
+ # positional embeddings
52
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
53
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
54
+
55
+ # relative position embedding
56
+ self.relative_position_embedding: bool = False # apply relative position embedding
57
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
58
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
59
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
60
+
61
+ # quantizer
62
+ self.quant_n: int = 1024 # codebook number in quantizer
63
+ self.quant_dim: int = 256 # codebook dimension in quantizer
64
+
65
+ if cfg is not None:
66
+ self.update(cfg)
67
+
68
+ def update(self, cfg: dict):
69
+ self.__dict__.update(cfg)
70
+
71
+
72
+ class Tokenizers(nn.Module):
73
+ def __init__(
74
+ self,
75
+ cfg: TokenizersConfig,
76
+ ) -> None:
77
+ super().__init__()
78
+ logger.info(f"Tokenizers Config: {cfg.__dict__}")
79
+
80
+ self.cfg = cfg
81
+
82
+ self.embed = cfg.embed_dim
83
+ self.post_extract_proj = (
84
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
85
+ if self.embed != cfg.encoder_embed_dim
86
+ else None
87
+ )
88
+
89
+ self.input_patch_size = cfg.input_patch_size
90
+ self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
91
+ bias=cfg.conv_bias)
92
+
93
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
94
+
95
+ assert not cfg.deep_norm or not cfg.layer_norm_first
96
+ self.encoder = TransformerEncoder(cfg)
97
+ self.layer_norm = LayerNorm(self.embed)
98
+
99
+ self.quantize = NormEMAVectorQuantizer(
100
+ n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
101
+ )
102
+ self.quant_n = cfg.quant_n
103
+ self.quantize_layer = nn.Sequential(
104
+ nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
105
+ nn.Tanh(),
106
+ nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
107
+ )
108
+
109
+ def forward_padding_mask(
110
+ self,
111
+ features: torch.Tensor,
112
+ padding_mask: torch.Tensor,
113
+ ) -> torch.Tensor:
114
+ extra = padding_mask.size(1) % features.size(1)
115
+ if extra > 0:
116
+ padding_mask = padding_mask[:, :-extra]
117
+ padding_mask = padding_mask.view(
118
+ padding_mask.size(0), features.size(1), -1
119
+ )
120
+ padding_mask = padding_mask.all(-1)
121
+ return padding_mask
122
+
123
+ def preprocess(
124
+ self,
125
+ source: torch.Tensor,
126
+ fbank_mean: float = 15.41663,
127
+ fbank_std: float = 6.55582,
128
+ ) -> torch.Tensor:
129
+ fbanks = []
130
+ for waveform in source:
131
+ waveform = waveform.unsqueeze(0) * 2 ** 15
132
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
133
+ fbanks.append(fbank)
134
+ fbank = torch.stack(fbanks, dim=0)
135
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
136
+ return fbank
137
+
138
+ def extract_labels(
139
+ self,
140
+ source: torch.Tensor,
141
+ padding_mask: Optional[torch.Tensor] = None,
142
+ fbank_mean: float = 15.41663,
143
+ fbank_std: float = 6.55582,
144
+ ):
145
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
146
+
147
+ if padding_mask is not None:
148
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
149
+
150
+ fbank = fbank.unsqueeze(1)
151
+ features = self.patch_embedding(fbank)
152
+ features = features.reshape(features.shape[0], features.shape[1], -1)
153
+ features = features.transpose(1, 2)
154
+ features = self.layer_norm(features)
155
+
156
+ if padding_mask is not None:
157
+ padding_mask = self.forward_padding_mask(features, padding_mask)
158
+
159
+ if self.post_extract_proj is not None:
160
+ features = self.post_extract_proj(features)
161
+
162
+ x = self.dropout_input(features)
163
+
164
+ x, layer_results = self.encoder(
165
+ x,
166
+ padding_mask=padding_mask,
167
+ )
168
+
169
+ quantize_input = self.quantize_layer(x)
170
+ quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
171
+
172
+ return embed_ind
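
In BEATs pre-training, these tokenizers convert audio into discrete training targets. A hedged usage sketch, assuming a released tokenizer checkpoint with the same `{cfg, model}` layout as the BEATs checkpoints (the filename is illustrative):

```python
import torch

ckpt = torch.load("Tokenizer_iter3_plus_AS2M.pt", map_location="cpu")  # example filename
cfg = TokenizersConfig(ckpt["cfg"])
tokenizer = Tokenizers(cfg)
tokenizer.load_state_dict(ckpt["model"])
tokenizer.eval()

waveform = torch.randn(1, 16000 * 10)
with torch.no_grad():
    labels = tokenizer.extract_labels(waveform)   # one codebook index per patch
print(labels.shape, int(labels.max()))            # indices lie in [0, quant_n)
```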
models/beats/backbone.py ADDED
@@ -0,0 +1,783 @@
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import numpy as np
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import LayerNorm, Parameter
17
+ from models.beats.modules import (
18
+ GradMultiply,
19
+ SamePad,
20
+ get_activation_fn,
21
+ GLU_Linear,
22
+ quant_noise,
23
+ )
24
+
25
+
26
+ class TransformerEncoder(nn.Module):
27
+ def __init__(self, args):
28
+ super().__init__()
29
+
30
+ self.dropout = args.dropout
31
+ self.embedding_dim = args.encoder_embed_dim
32
+
33
+ self.pos_conv = nn.Conv1d(
34
+ self.embedding_dim,
35
+ self.embedding_dim,
36
+ kernel_size=args.conv_pos,
37
+ padding=args.conv_pos // 2,
38
+ groups=args.conv_pos_groups,
39
+ )
40
+ dropout = 0
41
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
42
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
43
+ nn.init.constant_(self.pos_conv.bias, 0)
44
+
45
+ self.pos_conv = torch.nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
46
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
47
+
48
+ if hasattr(args, "relative_position_embedding"):
49
+ self.relative_position_embedding = args.relative_position_embedding
50
+ self.num_buckets = args.num_buckets
51
+ self.max_distance = args.max_distance
52
+ else:
53
+ self.relative_position_embedding = False
54
+ self.num_buckets = 0
55
+ self.max_distance = 0
56
+
57
+ self.layers = nn.ModuleList(
58
+ [
59
+ TransformerSentenceEncoderLayer(
60
+ embedding_dim=self.embedding_dim,
61
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
62
+ num_attention_heads=args.encoder_attention_heads,
63
+ dropout=self.dropout,
64
+ attention_dropout=args.attention_dropout,
65
+ activation_dropout=args.activation_dropout,
66
+ activation_fn=args.activation_fn,
67
+ layer_norm_first=args.layer_norm_first,
68
+ deep_norm=args.deep_norm,
69
+ has_relative_attention_bias=self.relative_position_embedding,
70
+ num_buckets=self.num_buckets,
71
+ max_distance=self.max_distance,
72
+ gru_rel_pos=args.gru_rel_pos,
73
+ encoder_layers=args.encoder_layers,
74
+ )
75
+ for i in range(args.encoder_layers)
76
+ ]
77
+ )
78
+ if self.relative_position_embedding:
79
+ for i in range(1, args.encoder_layers):
80
+ del self.layers[i].self_attn.relative_attention_bias
81
+ self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
82
+
83
+ self.layer_norm_first = args.layer_norm_first
84
+ self.layer_norm = LayerNorm(self.embedding_dim)
85
+ self.layerdrop = args.encoder_layerdrop
86
+
87
+ self.apply(init_bert_params)
88
+
89
+ if args.deep_norm:
90
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
91
+ for i in range(args.encoder_layers):
92
+ nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
93
+ nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
94
+ nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
95
+ nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
96
+ nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
97
+ nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
98
+
99
+ self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
100
+
101
+ def forward(self, x, padding_mask=None, layer=None):
102
+ x, layer_results = self.extract_features(x, padding_mask, layer)
103
+
104
+ if self.layer_norm_first and layer is None:
105
+ x = self.layer_norm(x)
106
+
107
+ return x, layer_results
108
+
109
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
110
+
111
+ if padding_mask is not None:
112
+ x[padding_mask] = 0
113
+
114
+ x_conv = self.pos_conv(x.transpose(1, 2))
115
+ x_conv = x_conv.transpose(1, 2)
116
+ x = x + x_conv
117
+
118
+ if not self.layer_norm_first:
119
+ x = self.layer_norm(x)
120
+
121
+ x = F.dropout(x, p=self.dropout, training=self.training)
122
+
123
+ # B x T x C -> T x B x C
124
+ x = x.transpose(0, 1)
125
+
126
+ layer_results = []
127
+ z = None
128
+ if tgt_layer is not None:
129
+ layer_results.append((x, z))
130
+ r = None
131
+ pos_bias = None
132
+ for i, layer in enumerate(self.layers):
133
+ if self.layer_wise_gradient_decay_ratio != 1.0:
134
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
135
+ dropout_probability = np.random.random()
136
+ if not self.training or (dropout_probability > self.layerdrop):
137
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
138
+ if tgt_layer is not None:
139
+ layer_results.append((x, z))
140
+ if i == tgt_layer:
141
+ r = x
142
+ break
143
+
144
+ if r is not None:
145
+ x = r
146
+
147
+ # T x B x C -> B x T x C
148
+ x = x.transpose(0, 1)
149
+
150
+ return x, layer_results
151
+
152
+
153
+ class TransformerSentenceEncoderLayer(nn.Module):
154
+ def __init__(
155
+ self,
156
+ embedding_dim: float = 768,
157
+ ffn_embedding_dim: float = 3072,
158
+ num_attention_heads: float = 8,
159
+ dropout: float = 0.1,
160
+ attention_dropout: float = 0.1,
161
+ activation_dropout: float = 0.1,
162
+ activation_fn: str = "relu",
163
+ layer_norm_first: bool = False,
164
+ deep_norm: bool = False,
165
+ has_relative_attention_bias: bool = False,
166
+ num_buckets: int = 0,
167
+ max_distance: int = 0,
168
+ rescale_init: bool = False,
169
+ gru_rel_pos: bool = False,
170
+ encoder_layers: int = 0,
171
+ ) -> None:
172
+
173
+ super().__init__()
174
+ self.embedding_dim = embedding_dim
175
+ self.dropout = dropout
176
+ self.activation_dropout = activation_dropout
177
+
178
+ self.activation_name = activation_fn
179
+ self.activation_fn = get_activation_fn(activation_fn)
180
+ self.self_attn = MultiheadAttention(
181
+ self.embedding_dim,
182
+ num_attention_heads,
183
+ dropout=attention_dropout,
184
+ self_attention=True,
185
+ has_relative_attention_bias=has_relative_attention_bias,
186
+ num_buckets=num_buckets,
187
+ max_distance=max_distance,
188
+ rescale_init=rescale_init,
189
+ gru_rel_pos=gru_rel_pos,
190
+ )
191
+
192
+ self.dropout1 = nn.Dropout(dropout)
193
+ self.dropout2 = nn.Dropout(self.activation_dropout)
194
+ self.dropout3 = nn.Dropout(dropout)
195
+
196
+ self.layer_norm_first = layer_norm_first
197
+
198
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
199
+
200
+ if self.activation_name == "glu":
201
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
202
+ else:
203
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
204
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
205
+
206
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
207
+
208
+ self.deep_norm = deep_norm
209
+ if self.deep_norm:
210
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
211
+ else:
212
+ self.deep_norm_alpha = 1
213
+
214
+ def forward(
215
+ self,
216
+ x: torch.Tensor,
217
+ self_attn_mask: torch.Tensor = None,
218
+ self_attn_padding_mask: torch.Tensor = None,
219
+ need_weights: bool = False,
220
+ pos_bias=None
221
+ ):
222
+ residual = x
223
+
224
+ if self.layer_norm_first:
225
+ x = self.self_attn_layer_norm(x)
226
+ x, attn, pos_bias = self.self_attn(
227
+ query=x,
228
+ key=x,
229
+ value=x,
230
+ key_padding_mask=self_attn_padding_mask,
231
+ need_weights=False,
232
+ attn_mask=self_attn_mask,
233
+ position_bias=pos_bias
234
+ )
235
+ x = self.dropout1(x)
236
+ x = residual + x
237
+
238
+ residual = x
239
+ x = self.final_layer_norm(x)
240
+ if self.activation_name == "glu":
241
+ x = self.fc1(x)
242
+ else:
243
+ x = self.activation_fn(self.fc1(x))
244
+ x = self.dropout2(x)
245
+ x = self.fc2(x)
246
+ x = self.dropout3(x)
247
+ x = residual + x
248
+ else:
249
+ x, attn, pos_bias = self.self_attn(
250
+ query=x,
251
+ key=x,
252
+ value=x,
253
+ key_padding_mask=self_attn_padding_mask,
254
+ need_weights=need_weights,
255
+ attn_mask=self_attn_mask,
256
+ position_bias=pos_bias
257
+ )
258
+
259
+ x = self.dropout1(x)
260
+ x = residual * self.deep_norm_alpha + x
261
+
262
+ x = self.self_attn_layer_norm(x)
263
+
264
+ residual = x
265
+ if self.activation_name == "glu":
266
+ x = self.fc1(x)
267
+ else:
268
+ x = self.activation_fn(self.fc1(x))
269
+ x = self.dropout2(x)
270
+ x = self.fc2(x)
271
+ x = self.dropout3(x)
272
+ x = residual * self.deep_norm_alpha + x
273
+ x = self.final_layer_norm(x)
274
+
275
+ return x, attn, pos_bias
276
+
277
+
278
+ class MultiheadAttention(nn.Module):
279
+ """Multi-headed attention.
280
+
281
+ See "Attention Is All You Need" for more details.
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ embed_dim,
287
+ num_heads,
288
+ kdim=None,
289
+ vdim=None,
290
+ dropout=0.0,
291
+ bias=True,
292
+ add_bias_kv=False,
293
+ add_zero_attn=False,
294
+ self_attention=False,
295
+ encoder_decoder_attention=False,
296
+ q_noise=0.0,
297
+ qn_block_size=8,
298
+ has_relative_attention_bias=False,
299
+ num_buckets=32,
300
+ max_distance=128,
301
+ gru_rel_pos=False,
302
+ rescale_init=False,
303
+ ):
304
+ super().__init__()
305
+ self.embed_dim = embed_dim
306
+ self.kdim = kdim if kdim is not None else embed_dim
307
+ self.vdim = vdim if vdim is not None else embed_dim
308
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
309
+
310
+ self.num_heads = num_heads
311
+ self.dropout_module = nn.Dropout(dropout)
312
+
313
+ self.has_relative_attention_bias = has_relative_attention_bias
314
+ self.num_buckets = num_buckets
315
+ self.max_distance = max_distance
316
+ if self.has_relative_attention_bias:
317
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
318
+
319
+ self.head_dim = embed_dim // num_heads
320
+ self.q_head_dim = self.head_dim
321
+ self.k_head_dim = self.head_dim
322
+ assert (
323
+ self.head_dim * num_heads == self.embed_dim
324
+ ), "embed_dim must be divisible by num_heads"
325
+ self.scaling = self.head_dim ** -0.5
326
+
327
+ self.self_attention = self_attention
328
+ self.encoder_decoder_attention = encoder_decoder_attention
329
+
330
+ assert not self.self_attention or self.qkv_same_dim, (
331
+ "Self-attention requires query, key and " "value to be of the same size"
332
+ )
333
+
334
+ k_bias = True
335
+ if rescale_init:
336
+ k_bias = False
337
+
338
+ k_embed_dim = embed_dim
339
+ q_embed_dim = embed_dim
340
+
341
+ self.k_proj = quant_noise(
342
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
343
+ )
344
+ self.v_proj = quant_noise(
345
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
346
+ )
347
+ self.q_proj = quant_noise(
348
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
349
+ )
350
+
351
+ self.out_proj = quant_noise(
352
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
353
+ )
354
+
355
+ if add_bias_kv:
356
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
357
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
358
+ else:
359
+ self.bias_k = self.bias_v = None
360
+
361
+ self.add_zero_attn = add_zero_attn
362
+
363
+ self.gru_rel_pos = gru_rel_pos
364
+ if self.gru_rel_pos:
365
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
366
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
367
+
368
+ self.reset_parameters()
369
+
370
+ def reset_parameters(self):
371
+ if self.qkv_same_dim:
372
+ # Empirically observed the convergence to be much better with
373
+ # the scaled initialization
374
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
375
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
376
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
377
+ else:
378
+ nn.init.xavier_uniform_(self.k_proj.weight)
379
+ nn.init.xavier_uniform_(self.v_proj.weight)
380
+ nn.init.xavier_uniform_(self.q_proj.weight)
381
+
382
+ nn.init.xavier_uniform_(self.out_proj.weight)
383
+ if self.out_proj.bias is not None:
384
+ nn.init.constant_(self.out_proj.bias, 0.0)
385
+ if self.bias_k is not None:
386
+ nn.init.xavier_normal_(self.bias_k)
387
+ if self.bias_v is not None:
388
+ nn.init.xavier_normal_(self.bias_v)
389
+ if self.has_relative_attention_bias:
390
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
391
+
392
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
393
+ num_buckets = self.num_buckets
394
+ max_distance = self.max_distance
395
+ relative_buckets = 0
396
+
397
+ if bidirectional:
398
+ num_buckets = num_buckets // 2
399
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
400
+ relative_positions = torch.abs(relative_positions)
401
+ else:
402
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
403
+
404
+ max_exact = num_buckets // 2
405
+ is_small = relative_positions < max_exact
406
+
407
+ relative_position_if_large = max_exact + (
408
+ torch.log(relative_positions.float() / max_exact)
409
+ / math.log(max_distance / max_exact)
410
+ * (num_buckets - max_exact)
411
+ ).to(torch.long)
412
+ relative_position_if_large = torch.min(
413
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
414
+ )
415
+
416
+ relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
417
+ return relative_buckets
418
+
419
+ def compute_bias(self, query_length, key_length):
420
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
421
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
422
+ relative_position = memory_position - context_position
423
+ relative_position_bucket = self._relative_positions_bucket(
424
+ relative_position,
425
+ bidirectional=True
426
+ )
427
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
428
+ values = self.relative_attention_bias(relative_position_bucket)
429
+ values = values.permute([2, 0, 1])
430
+ return values
431
+
432
+ def forward(
433
+ self,
434
+ query,
435
+ key: Optional[Tensor],
436
+ value: Optional[Tensor],
437
+ key_padding_mask: Optional[Tensor] = None,
438
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
439
+ need_weights: bool = True,
440
+ static_kv: bool = False,
441
+ attn_mask: Optional[Tensor] = None,
442
+ before_softmax: bool = False,
443
+ need_head_weights: bool = False,
444
+ position_bias: Optional[Tensor] = None
445
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
446
+ """Input shape: Time x Batch x Channel
447
+
448
+ Args:
449
+ key_padding_mask (ByteTensor, optional): mask to exclude
450
+ keys that are pads, of shape `(batch, src_len)`, where
451
+ padding elements are indicated by 1s.
452
+ need_weights (bool, optional): return the attention weights,
453
+ averaged over heads (default: False).
454
+ attn_mask (ByteTensor, optional): typically used to
455
+ implement causal attention, where the mask prevents the
456
+ attention from looking forward in time (default: None).
457
+ before_softmax (bool, optional): return the raw attention
458
+ weights and values before the attention softmax.
459
+ need_head_weights (bool, optional): return the attention
460
+ weights for each head. Implies *need_weights*. Default:
461
+ return the average attention weights over all heads.
462
+ """
463
+ if need_head_weights:
464
+ need_weights = True
465
+
466
+ is_tpu = query.device.type == "xla"
467
+
468
+ tgt_len, bsz, embed_dim = query.size()
469
+ src_len = tgt_len
470
+ assert embed_dim == self.embed_dim
471
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
472
+ if key is not None:
473
+ src_len, key_bsz, _ = key.size()
474
+ if not torch.jit.is_scripting():
475
+ assert key_bsz == bsz
476
+ assert value is not None
477
+ assert (src_len, bsz) == value.shape[:2]
478
+
479
+ if self.has_relative_attention_bias and position_bias is None:
480
+ position_bias = self.compute_bias(tgt_len, src_len)
481
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
482
+
483
+ if incremental_state is not None:
484
+ saved_state = self._get_input_buffer(incremental_state)
485
+ if saved_state is not None and "prev_key" in saved_state:
486
+ # previous time steps are cached - no need to recompute
487
+ # key and value if they are static
488
+ if static_kv:
489
+ assert self.encoder_decoder_attention and not self.self_attention
490
+ key = value = None
491
+ else:
492
+ saved_state = None
493
+
494
+ if self.self_attention:
495
+ q = self.q_proj(query)
496
+ k = self.k_proj(query)
497
+ v = self.v_proj(query)
498
+ elif self.encoder_decoder_attention:
499
+ # encoder-decoder attention
500
+ q = self.q_proj(query)
501
+ if key is None:
502
+ assert value is None
503
+ k = v = None
504
+ else:
505
+ k = self.k_proj(key)
506
+ v = self.v_proj(key)
507
+
508
+ else:
509
+ assert key is not None and value is not None
510
+ q = self.q_proj(query)
511
+ k = self.k_proj(key)
512
+ v = self.v_proj(value)
513
+ q *= self.scaling
514
+ alpha = 32
515
+ q *= 1 / alpha
516
+
517
+ if self.bias_k is not None:
518
+ assert self.bias_v is not None
519
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
520
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
521
+ if attn_mask is not None:
522
+ attn_mask = torch.cat(
523
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
524
+ )
525
+ if key_padding_mask is not None:
526
+ key_padding_mask = torch.cat(
527
+ [
528
+ key_padding_mask,
529
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
530
+ ],
531
+ dim=1,
532
+ )
533
+
534
+ q = (
535
+ q.contiguous()
536
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
537
+ .transpose(0, 1)
538
+ )
539
+ if k is not None:
540
+ k = (
541
+ k.contiguous()
542
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
543
+ .transpose(0, 1)
544
+ )
545
+ if v is not None:
546
+ v = (
547
+ v.contiguous()
548
+ .view(-1, bsz * self.num_heads, self.head_dim)
549
+ .transpose(0, 1)
550
+ )
551
+
552
+ if saved_state is not None:
553
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
554
+ if "prev_key" in saved_state:
555
+ _prev_key = saved_state["prev_key"]
556
+ assert _prev_key is not None
557
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
558
+ if static_kv:
559
+ k = prev_key
560
+ else:
561
+ assert k is not None
562
+ k = torch.cat([prev_key, k], dim=1)
563
+ src_len = k.size(1)
564
+ if "prev_value" in saved_state:
565
+ _prev_value = saved_state["prev_value"]
566
+ assert _prev_value is not None
567
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
568
+ if static_kv:
569
+ v = prev_value
570
+ else:
571
+ assert v is not None
572
+ v = torch.cat([prev_value, v], dim=1)
573
+ prev_key_padding_mask: Optional[Tensor] = None
574
+ if "prev_key_padding_mask" in saved_state:
575
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
576
+ assert k is not None and v is not None
577
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
578
+ key_padding_mask=key_padding_mask,
579
+ prev_key_padding_mask=prev_key_padding_mask,
580
+ batch_size=bsz,
581
+ src_len=k.size(1),
582
+ static_kv=static_kv,
583
+ )
584
+
585
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
586
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
587
+ saved_state["prev_key_padding_mask"] = key_padding_mask
588
+ # In this branch incremental_state is never None
589
+ assert incremental_state is not None
590
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
591
+ assert k is not None
592
+ assert k.size(1) == src_len
593
+
594
+ # This is part of a workaround to get around fork/join parallelism
595
+ # not supporting Optional types.
596
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
597
+ key_padding_mask = None
598
+
599
+ if key_padding_mask is not None:
600
+ assert key_padding_mask.size(0) == bsz
601
+ assert key_padding_mask.size(1) == src_len
602
+
603
+ if self.add_zero_attn:
604
+ assert v is not None
605
+ src_len += 1
606
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
607
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
608
+ if attn_mask is not None:
609
+ attn_mask = torch.cat(
610
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
611
+ )
612
+ if key_padding_mask is not None:
613
+ key_padding_mask = torch.cat(
614
+ [
615
+ key_padding_mask,
616
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
617
+ key_padding_mask
618
+ ),
619
+ ],
620
+ dim=1,
621
+ )
622
+
623
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
624
+ attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
625
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
626
+
627
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
628
+
629
+ if attn_mask is not None:
630
+ attn_mask = attn_mask.unsqueeze(0)
631
+ attn_weights += attn_mask
632
+
633
+ if key_padding_mask is not None:
634
+ # don't attend to padding symbols
635
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
636
+ if not is_tpu:
637
+ attn_weights = attn_weights.masked_fill(
638
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
639
+ float("-inf"),
640
+ )
641
+ else:
642
+ attn_weights = attn_weights.transpose(0, 2)
643
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
644
+ attn_weights = attn_weights.transpose(0, 2)
645
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
646
+
647
+ if before_softmax:
648
+ return attn_weights, v, position_bias
649
+
650
+ if position_bias is not None:
651
+ attn_mask_rel_pos = position_bias
652
+ if self.gru_rel_pos == 1:
653
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
654
+ _B, _H, _L, __ = query_layer.size()
655
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
656
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
657
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
658
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
659
+
660
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
661
+
662
+ attn_weights = attn_weights + attn_mask_rel_pos
663
+
664
+ attn_weights_float = F.softmax(
665
+ attn_weights, dim=-1
666
+ )
667
+ attn_weights = attn_weights_float.type_as(attn_weights)
668
+ attn_probs = self.dropout_module(attn_weights)
669
+
670
+ assert v is not None
671
+ attn = torch.bmm(attn_probs, v)
672
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
673
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
674
+ attn = self.out_proj(attn)
675
+ attn_weights: Optional[Tensor] = None
676
+ if need_weights:
677
+ attn_weights = attn_weights_float.view(
678
+ bsz, self.num_heads, tgt_len, src_len
679
+ ).transpose(1, 0)
680
+ if not need_head_weights:
681
+ # average attention weights over heads
682
+ attn_weights = attn_weights.mean(dim=0)
683
+
684
+ return attn, attn_weights, position_bias
685
+
686
+ @staticmethod
687
+ def _append_prev_key_padding_mask(
688
+ key_padding_mask: Optional[Tensor],
689
+ prev_key_padding_mask: Optional[Tensor],
690
+ batch_size: int,
691
+ src_len: int,
692
+ static_kv: bool,
693
+ ) -> Optional[Tensor]:
694
+ # saved key padding masks have shape (bsz, seq_len)
695
+ if prev_key_padding_mask is not None and static_kv:
696
+ new_key_padding_mask = prev_key_padding_mask
697
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
698
+ new_key_padding_mask = torch.cat(
699
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
700
+ )
701
+ # During incremental decoding, as the padding token enters and
702
+ # leaves the frame, there will be a time when prev or current
703
+ # is None
704
+ elif prev_key_padding_mask is not None:
705
+ if src_len > prev_key_padding_mask.size(1):
706
+ filler = torch.zeros(
707
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
708
+ device=prev_key_padding_mask.device,
709
+ )
710
+ new_key_padding_mask = torch.cat(
711
+ [prev_key_padding_mask.float(), filler.float()], dim=1
712
+ )
713
+ else:
714
+ new_key_padding_mask = prev_key_padding_mask.float()
715
+ elif key_padding_mask is not None:
716
+ if src_len > key_padding_mask.size(1):
717
+ filler = torch.zeros(
718
+ (batch_size, src_len - key_padding_mask.size(1)),
719
+ device=key_padding_mask.device,
720
+ )
721
+ new_key_padding_mask = torch.cat(
722
+ [filler.float(), key_padding_mask.float()], dim=1
723
+ )
724
+ else:
725
+ new_key_padding_mask = key_padding_mask.float()
726
+ else:
727
+ new_key_padding_mask = prev_key_padding_mask
728
+ return new_key_padding_mask
729
+
730
+ def _get_input_buffer(
731
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
732
+ ) -> Dict[str, Optional[Tensor]]:
733
+ result = self.get_incremental_state(incremental_state, "attn_state")
734
+ if result is not None:
735
+ return result
736
+ else:
737
+ empty_result: Dict[str, Optional[Tensor]] = {}
738
+ return empty_result
739
+
740
+ def _set_input_buffer(
741
+ self,
742
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
743
+ buffer: Dict[str, Optional[Tensor]],
744
+ ):
745
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
746
+
747
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
748
+ return attn_weights
749
+
750
+
751
+ def init_bert_params(module):
752
+ """
753
+ Initialize the weights specific to the BERT Model.
754
+ This overrides the default initializations depending on the specified arguments.
755
+ 1. If normal_init_linear_weights is set then weights of linear
756
+ layer will be initialized using the normal distribution and
757
+ bias will be set to the specified value.
758
+ 2. If normal_init_embed_weights is set then weights of embedding
759
+ layer will be initialized using the normal distribution.
760
+ 3. If normal_init_proj_weights is set then weights of
761
+ in_project_weight for MultiHeadAttention initialized using
762
+ the normal distribution (to be validated).
763
+ """
764
+
765
+ def normal_(data):
766
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
767
+ # so that the RNG is consistent with and without FSDP
768
+ data.copy_(
769
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
770
+ )
771
+
772
+ if isinstance(module, nn.Linear):
773
+ normal_(module.weight.data)
774
+ if module.bias is not None:
775
+ module.bias.data.zero_()
776
+ if isinstance(module, nn.Embedding):
777
+ normal_(module.weight.data)
778
+ if module.padding_idx is not None:
779
+ module.weight.data[module.padding_idx].zero_()
780
+ if isinstance(module, MultiheadAttention):
781
+ normal_(module.q_proj.weight.data)
782
+ normal_(module.k_proj.weight.data)
783
+ normal_(module.v_proj.weight.data)
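
The bucketed relative-position bias above maps each (query, key) offset into one of `num_buckets` learned per-head embeddings. A shape check, constructing the attention module directly with illustrative values:

```python
import torch

attn = MultiheadAttention(
    embed_dim=768, num_heads=12,
    has_relative_attention_bias=True, num_buckets=320, max_distance=800,
)
bias = attn.compute_bias(query_length=10, key_length=10)
print(bias.shape)   # (num_heads, query_length, key_length) = (12, 10, 10)
```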
models/beats/modules.py ADDED
@@ -0,0 +1,218 @@
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ import torch
13
+ from torch import Tensor, nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class GradMultiply(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, x, scale):
20
+ ctx.scale = scale
21
+ res = x.new(x)
22
+ return res
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad):
26
+ return grad * ctx.scale, None
27
+
28
+
29
+ class SamePad(nn.Module):
30
+ def __init__(self, kernel_size, causal=False):
31
+ super().__init__()
32
+ if causal:
33
+ self.remove = kernel_size - 1
34
+ else:
35
+ self.remove = 1 if kernel_size % 2 == 0 else 0
36
+
37
+ def forward(self, x):
38
+ if self.remove > 0:
39
+ x = x[:, :, : -self.remove]
40
+ return x
41
+
42
+
43
+ class Swish(nn.Module):
44
+ def __init__(self):
45
+ super(Swish, self).__init__()
46
+ self.act = torch.nn.Sigmoid()
47
+
48
+ def forward(self, x):
49
+ return x * self.act(x)
50
+
51
+
52
+ class GLU_Linear(nn.Module):
53
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
54
+ super(GLU_Linear, self).__init__()
55
+
56
+ self.glu_type = glu_type
57
+ self.output_dim = output_dim
58
+
59
+ if glu_type == "sigmoid":
60
+ self.glu_act = torch.nn.Sigmoid()
61
+ elif glu_type == "swish":
62
+ self.glu_act = Swish()
63
+ elif glu_type == "relu":
64
+ self.glu_act = torch.nn.ReLU()
65
+ elif glu_type == "gelu":
66
+ self.glu_act = torch.nn.GELU()
67
+
68
+ if bias_in_glu:
69
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
70
+ else:
71
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
72
+
73
+ def forward(self, x):
74
+ # inputs are assumed to be channel-last (features in the last dimension); for the 1D-conv case the dimensions must be switched before this layer
75
+ x = self.linear(x)
76
+
77
+ if self.glu_type == "bilinear":
78
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
79
+ else:
80
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
81
+
82
+ return x
83
+
84
+
85
+ def gelu_accurate(x):
86
+ if not hasattr(gelu_accurate, "_a"):
87
+ gelu_accurate._a = math.sqrt(2 / math.pi)
88
+ return (
89
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
90
+ )
91
+
92
+
93
+ def gelu(x: torch.Tensor) -> torch.Tensor:
94
+ return torch.nn.functional.gelu(x.float()).type_as(x)
95
+
96
+
97
+ def get_activation_fn(activation: str):
98
+ """Returns the activation function corresponding to `activation`"""
99
+
100
+ if activation == "relu":
101
+ return F.relu
102
+ elif activation == "gelu":
103
+ return gelu
104
+ elif activation == "gelu_fast":
105
+ warnings.warn(
106
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
107
+ )
108
+ return gelu_accurate
109
+ elif activation == "gelu_accurate":
110
+ return gelu_accurate
111
+ elif activation == "tanh":
112
+ return torch.tanh
113
+ elif activation == "linear":
114
+ return lambda x: x
115
+ elif activation == "glu":
116
+ return lambda x: x
117
+ else:
118
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
119
+
120
+
121
+ def quant_noise(module, p, block_size):
122
+ """
123
+ Wraps modules and applies quantization noise to the weights for
124
+ subsequent quantization with Iterative Product Quantization as
125
+ described in "Training with Quantization Noise for Extreme Model Compression"
126
+
127
+ Args:
128
+ - module: nn.Module
129
+ - p: amount of Quantization Noise
130
+ - block_size: size of the blocks for subsequent quantization with iPQ
131
+
132
+ Remarks:
133
+ - Module weights must have the right sizes wrt the block size
134
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
135
+ - For more detail on how to quantize by blocks with convolutional weights,
136
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
137
+ - We implement the simplest form of noise here as stated in the paper
138
+ which consists in randomly dropping blocks
139
+ """
140
+
141
+ # if no quantization noise, don't register hook
142
+ if p <= 0:
143
+ return module
144
+
145
+ # supported modules
146
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
147
+
148
+ # test whether module.weight has the right sizes wrt block_size
149
+ is_conv = module.weight.ndim == 4
150
+
151
+ # 2D matrix
152
+ if not is_conv:
153
+ assert (
154
+ module.weight.size(1) % block_size == 0
155
+ ), "Input features must be a multiple of block sizes"
156
+
157
+ # 4D matrix
158
+ else:
159
+ # 1x1 convolutions
160
+ if module.kernel_size == (1, 1):
161
+ assert (
162
+ module.in_channels % block_size == 0
163
+ ), "Input channels must be a multiple of block sizes"
164
+ # regular convolutions
165
+ else:
166
+ k = module.kernel_size[0] * module.kernel_size[1]
167
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
168
+
169
+ def _forward_pre_hook(mod, input):
170
+ # no noise for evaluation
171
+ if mod.training:
172
+ if not is_conv:
173
+ # gather weight and sizes
174
+ weight = mod.weight
175
+ in_features = weight.size(1)
176
+ out_features = weight.size(0)
177
+
178
+ # split weight matrix into blocks and randomly drop selected blocks
179
+ mask = torch.zeros(
180
+ in_features // block_size * out_features, device=weight.device
181
+ )
182
+ mask.bernoulli_(p)
183
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
184
+
185
+ else:
186
+ # gather weight and sizes
187
+ weight = mod.weight
188
+ in_channels = mod.in_channels
189
+ out_channels = mod.out_channels
190
+
191
+ # split weight matrix into blocks and randomly drop selected blocks
192
+ if mod.kernel_size == (1, 1):
193
+ mask = torch.zeros(
194
+ int(in_channels // block_size * out_channels),
195
+ device=weight.device,
196
+ )
197
+ mask.bernoulli_(p)
198
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
199
+ else:
200
+ mask = torch.zeros(
201
+ weight.size(0), weight.size(1), device=weight.device
202
+ )
203
+ mask.bernoulli_(p)
204
+ mask = (
205
+ mask.unsqueeze(2)
206
+ .unsqueeze(3)
207
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
208
+ )
209
+
210
+ # scale weights and apply mask
211
+ mask = mask.to(
212
+ torch.bool
213
+ ) # x.bool() is not currently supported in TorchScript
214
+ s = 1 / (1 - p)
215
+ mod.weight.data = s * weight.masked_fill(mask, 0)
216
+
217
+ module.register_forward_pre_hook(_forward_pre_hook)
218
+ return module
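A minimal usage sketch of `quant_noise` as defined above, assuming a toy `nn.Linear` whose input features (16) are a multiple of the block size (8); during training the forward pre-hook zeroes random weight blocks and rescales the rest by `1/(1-p)`:

```python
import torch
import torch.nn as nn

# Wrap a Linear layer with quantization noise (p = 10%, block size 8).
layer = quant_noise(nn.Linear(16, 8), p=0.1, block_size=8)

layer.train()
out = layer(torch.randn(4, 16))   # pre-hook drops random blocks, scales the rest

layer.eval()
out = layer(torch.randn(4, 16))   # no noise at evaluation time
```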
models/beats/quantizer.py ADDED
@@ -0,0 +1,215 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as distributed
14
+
15
+ try:
16
+ from einops import rearrange, repeat
17
+ except ImportError:
18
+ pass
19
+
20
+
21
+ def l2norm(t):
22
+ return F.normalize(t, p=2, dim=-1)
23
+
24
+
25
+ def ema_inplace(moving_avg, new, decay):
26
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
27
+
28
+
29
+ def sample_vectors(samples, num):
30
+ num_samples, device = samples.shape[0], samples.device
31
+
32
+ if num_samples >= num:
33
+ indices = torch.randperm(num_samples, device=device)[:num]
34
+ else:
35
+ indices = torch.randint(0, num_samples, (num,), device=device)
36
+
37
+ return samples[indices]
38
+
39
+
40
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
41
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
42
+
43
+ means = sample_vectors(samples, num_clusters)
44
+
45
+ for _ in range(num_iters):
46
+ if use_cosine_sim:
47
+ dists = samples @ means.t()
48
+ else:
49
+ diffs = rearrange(samples, 'n d -> n () d') \
50
+ - rearrange(means, 'c d -> () c d')
51
+ dists = -(diffs ** 2).sum(dim=-1)
52
+
53
+ buckets = dists.max(dim=-1).indices
54
+ bins = torch.bincount(buckets, minlength=num_clusters)
55
+ zero_mask = bins == 0
56
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
57
+
58
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
59
+ new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
60
+ new_means = new_means / bins_min_clamped[..., None]
61
+
62
+ if use_cosine_sim:
63
+ new_means = l2norm(new_means)
64
+
65
+ means = torch.where(zero_mask[..., None], means, new_means)
66
+
67
+ return means, bins
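A minimal sketch of calling the `kmeans` routine above on random unit vectors (the cosine-similarity branch shown here still needs `einops` for the centroid update); cluster count and dimensions are illustrative:

```python
import torch

samples = torch.randn(1000, 32)
samples = samples / samples.norm(dim=-1, keepdim=True)  # unit-length vectors

means, bins = kmeans(samples, num_clusters=16, num_iters=10, use_cosine_sim=True)
print(means.shape, bins.shape)  # torch.Size([16, 32]) torch.Size([16])
```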
68
+
69
+
70
+ class EmbeddingEMA(nn.Module):
71
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
72
+ super().__init__()
73
+ self.num_tokens = num_tokens
74
+ self.codebook_dim = codebook_dim
75
+ self.decay = decay
76
+ self.eps = eps
77
+ if codebook_init_path == '':
78
+ if not kmeans_init:
79
+ weight = torch.randn(num_tokens, codebook_dim)
80
+ weight = l2norm(weight)
81
+ else:
82
+ weight = torch.zeros(num_tokens, codebook_dim)
83
+ self.register_buffer('initted', torch.Tensor([not kmeans_init]))
84
+ else:
85
+ print(f"load init codebook weight from {codebook_init_path}")
86
+ codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
87
+ weight = codebook_ckpt_weight.clone()
88
+ self.register_buffer('initted', torch.Tensor([True]))
89
+
90
+ self.weight = nn.Parameter(weight, requires_grad=False)
91
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
92
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
93
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
94
+ self.update = True
95
+
96
+ @torch.jit.ignore
97
+ def init_embed_(self, data):
98
+ if self.initted:
99
+ return
100
+ print("Performing Kemans init for codebook")
101
+ embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
102
+ self.weight.data.copy_(embed)
103
+ self.cluster_size.data.copy_(cluster_size)
104
+ self.initted.data.copy_(torch.Tensor([True]))
105
+
106
+ def forward(self, embed_id):
107
+ return F.embedding(embed_id, self.weight)
108
+
109
+ def cluster_size_ema_update(self, new_cluster_size):
110
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
111
+
112
+ def embed_avg_ema_update(self, new_embed_avg):
113
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
114
+
115
+ def weight_update(self, num_tokens):
116
+ n = self.cluster_size.sum()
117
+ smoothed_cluster_size = (
118
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
119
+ )
120
+ # normalize embedding average with smoothed cluster size
121
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
122
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
123
+ self.weight.data.copy_(embed_normalized)
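A worked example of the Laplace smoothing inside `weight_update` above: empty clusters receive a tiny nonzero size so the division never produces infinities, while the total mass is preserved. The cluster sizes below are illustrative:

```python
import torch

cluster_size = torch.tensor([10., 0., 5.])      # hypothetical EMA cluster sizes
num_tokens, eps = 3, 1e-5
n = cluster_size.sum()                           # 15.0
smoothed = (cluster_size + eps) / (n + num_tokens * eps) * n
print(smoothed)  # ~[10.0, 1e-5, 5.0]: the empty cluster gets a small floor
```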
124
+
125
+
126
+ def norm_ema_inplace(moving_avg, new, decay):
127
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
128
+ moving_avg.data.copy_(l2norm(moving_avg.data))
129
+
130
+
131
+ class NormEMAVectorQuantizer(nn.Module):
132
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
133
+ statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
134
+ super().__init__()
135
+ self.codebook_dim = embedding_dim
136
+ self.num_tokens = n_embed
137
+ self.beta = beta
138
+ self.decay = decay
139
+
140
+ # learnable = True if orthogonal_reg_weight > 0 else False
141
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
142
+
143
+ self.statistic_code_usage = statistic_code_usage
144
+ if statistic_code_usage:
145
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
146
+ if distributed.is_available() and distributed.is_initialized():
147
+ print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
148
+ self.all_reduce_fn = distributed.all_reduce
149
+ else:
150
+ self.all_reduce_fn = nn.Identity()
151
+
152
+ def reset_cluster_size(self, device):
153
+ if self.statistic_code_usage:
154
+ self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
155
+ self.cluster_size = self.cluster_size.to(device)
156
+
157
+ def forward(self, z):
158
+ # reshape z -> (batch, height, width, channel) and flatten
159
+ # z, 'b c h w -> b h w c'
160
+ # z = rearrange(z, 'b c h w -> b h w c')
161
+ # z = z.transpose(1, 2)
162
+ z = l2norm(z)
163
+ z_flattened = z.reshape(-1, self.codebook_dim)
164
+
165
+ self.embedding.init_embed_(z_flattened)
166
+
167
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
168
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
169
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
170
+
171
+ encoding_indices = torch.argmin(d, dim=1)
172
+
173
+ z_q = self.embedding(encoding_indices).view(z.shape)
174
+
175
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
176
+
177
+ if not self.training:
178
+ with torch.no_grad():
179
+ cluster_size = encodings.sum(0)
180
+ self.all_reduce_fn(cluster_size)
181
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
182
+
183
+ if self.training and self.embedding.update:
184
+ # EMA cluster size
185
+
186
+ bins = encodings.sum(0)
187
+ self.all_reduce_fn(bins)
188
+
189
+ # self.embedding.cluster_size_ema_update(bins)
190
+ ema_inplace(self.cluster_size, bins, self.decay)
191
+
192
+ zero_mask = (bins == 0)
193
+ bins = bins.masked_fill(zero_mask, 1.)
194
+
195
+ embed_sum = z_flattened.t() @ encodings
196
+ self.all_reduce_fn(embed_sum)
197
+
198
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
199
+ embed_normalized = l2norm(embed_normalized)
200
+
201
+ embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
202
+ embed_normalized)
203
+ norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
204
+
205
+ # compute loss for embedding
206
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
207
+
208
+ # preserve gradients
209
+ z_q = z + (z_q - z).detach()
210
+
211
+ # reshape back to match original input shape
212
+ # z_q, 'b h w c -> b c h w'
213
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
214
+ # z_q = z_q.transpose(1, 2)
215
+ return z_q, loss, encoding_indices
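A minimal sketch of quantizing a batch of frame embeddings with `NormEMAVectorQuantizer` as defined above; `n_embed`, `embedding_dim` and `beta` are illustrative values, not the BEATs training configuration:

```python
import torch

vq = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=1.0, kmeans_init=False)
z = torch.randn(8, 100, 256)       # batch x frames x channels
z_q, loss, indices = vq(z)
print(z_q.shape, indices.shape)    # torch.Size([8, 100, 256]) torch.Size([800])
# In train mode the codebook is EMA-updated; z_q passes gradients back to z
# through the straight-through estimator z + (z_q - z).detach().
```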
models/frame_mn/Frame_MN_wrapper.py ADDED
@@ -0,0 +1,75 @@
1
+ from models.frame_passt.preprocess import AugmentMelSTFT
2
+ from models.transformer_wrapper import BaseModelWrapper
3
+ from models.frame_mn.model import get_model
4
+
5
+
6
+ class FrameMNWrapper(BaseModelWrapper):
7
+ def __init__(self, width_mult=1.0) -> None:
8
+ super().__init__()
9
+ self.mel = AugmentMelSTFT(
10
+ n_mels=128,
11
+ sr=16_000,
12
+ win_length=400,
13
+ hopsize=160,
14
+ n_fft=512,
15
+ freqm=0,
16
+ timem=0,
17
+ htk=False,
18
+ fmin=0.0,
19
+ fmax=None,
20
+ norm=1,
21
+ fmin_aug_range=10,
22
+ fmax_aug_range=2000,
23
+ fast_norm=True,
24
+ preamp=True,
25
+ padding="center",
26
+ periodic_window=False,
27
+ )
28
+
29
+ self.frame_mn = get_model(
30
+ width_mult=width_mult
31
+ )
32
+
33
+ def mel_forward(self, x):
34
+ return self.mel(x)
35
+
36
+ def forward(self, x):
37
+ return self.frame_mn(x)
38
+
39
+ def separate_params(self):
40
+ pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
41
+ for k, p in self.named_parameters():
42
+ if any(['cls_token' in k,
43
+ 'pos_embed' in k,
44
+ 'norm_stats' in k,
45
+ 'patch_embed' in k]):
46
+ pt_params[0].append(p)
47
+ elif 'blocks.0.' in k:
48
+ pt_params[0].append(p)
49
+ elif 'blocks.1.' in k:
50
+ pt_params[1].append(p)
51
+ elif 'blocks.2.' in k:
52
+ pt_params[2].append(p)
53
+ elif 'blocks.3.' in k:
54
+ pt_params[3].append(p)
55
+ elif 'blocks.4.' in k:
56
+ pt_params[4].append(p)
57
+ elif 'blocks.5.' in k:
58
+ pt_params[5].append(p)
59
+ elif 'blocks.6.' in k:
60
+ pt_params[6].append(p)
61
+ elif 'blocks.7.' in k:
62
+ pt_params[7].append(p)
63
+ elif 'blocks.8.' in k:
64
+ pt_params[8].append(p)
65
+ elif 'blocks.9.' in k:
66
+ pt_params[9].append(p)
67
+ elif 'blocks.10.' in k:
68
+ pt_params[10].append(p)
69
+ elif 'blocks.11.' in k:
70
+ pt_params[11].append(p)
71
+ elif 'asit.norm.weight' in k or 'asit.norm.bias' in k:
72
+ pt_params[11].append(p)
73
+ else:
74
+ raise ValueError(f"Check separate params for ASiT! Unknown key: {k}")
75
+ return list(reversed(pt_params))
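The list returned by `separate_params` is ordered from last to first pre-trained block, which is the shape expected for layer-wise learning-rate decay. A sketch of turning it into optimizer parameter groups (wrapper instance, base learning rate and decay factor are illustrative assumptions):

```python
import torch

def build_param_groups(pt_params, base_lr=1e-4, decay=0.75):
    # Earlier groups in the list sit closer to the output and keep a higher lr;
    # each step toward the input multiplies the lr by the decay factor.
    return [
        {"params": group, "lr": base_lr * decay ** i}
        for i, group in enumerate(pt_params)
        if len(group) > 0
    ]

# optimizer = torch.optim.AdamW(build_param_groups(wrapper.separate_params()))
```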
models/frame_mn/block_types.py ADDED
@@ -0,0 +1,189 @@
1
+ from typing import Dict, Callable, List
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+ from torchvision.ops.misc import ConvNormActivation
6
+
7
+ from models.frame_mn.utils import make_divisible, cnn_out_size
8
+
9
+
10
+ class ConcurrentSEBlock(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ c_dim: int,
14
+ f_dim: int,
15
+ t_dim: int,
16
+ se_cnf: Dict
17
+ ) -> None:
18
+ super().__init__()
19
+ dims = [c_dim, f_dim, t_dim]
20
+ self.conc_se_layers = nn.ModuleList()
21
+ for d in se_cnf['se_dims']:
22
+ input_dim = dims[d-1]
23
+ squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8)
24
+ self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
25
+ if se_cnf['se_agg'] == "max":
26
+ self.agg_op = lambda x: torch.max(x, dim=0)[0]
27
+ elif se_cnf['se_agg'] == "avg":
28
+ self.agg_op = lambda x: torch.mean(x, dim=0)
29
+ elif se_cnf['se_agg'] == "add":
30
+ self.agg_op = lambda x: torch.sum(x, dim=0)
31
+ elif se_cnf['se_agg'] == "min":
32
+ self.agg_op = lambda x: torch.min(x, dim=0)[0]
33
+ else:
34
+ raise NotImplementedError(f"SE aggregation operation '{se_cnf['se_agg']}' not implemented")
35
+
36
+ def forward(self, input: Tensor) -> Tensor:
37
+ # apply all concurrent se layers
38
+ se_outs = []
39
+ for se_layer in self.conc_se_layers:
40
+ se_outs.append(se_layer(input))
41
+ out = self.agg_op(torch.stack(se_outs, dim=0))
42
+ return out
43
+
44
+
45
+ class SqueezeExcitation(torch.nn.Module):
46
+ """
47
+ This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507.
48
+ Args:
49
+ input_dim (int): Input dimension
50
+ squeeze_dim (int): Size of Bottleneck
51
+ activation (Callable): activation applied to bottleneck
52
+ scale_activation (Callable): activation applied to the output
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ input_dim: int,
58
+ squeeze_dim: int,
59
+ se_dim: int,
60
+ activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
61
+ scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
62
+ ) -> None:
63
+ super().__init__()
64
+ self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
65
+ self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
66
+ assert se_dim in [1, 2, 3]
67
+ self.se_dim = [1, 2, 3]
68
+ self.se_dim.remove(se_dim)
69
+ self.activation = activation()
70
+ self.scale_activation = scale_activation()
71
+
72
+ def _scale(self, input: Tensor) -> Tensor:
73
+ scale = torch.mean(input, self.se_dim, keepdim=True)
74
+ shape = scale.size()
75
+ scale = self.fc1(scale.squeeze(2).squeeze(2))
76
+ scale = self.activation(scale)
77
+ scale = self.fc2(scale)
79
+ return self.scale_activation(scale).view(shape)
80
+
81
+ def forward(self, input: Tensor) -> Tensor:
82
+ scale = self._scale(input)
83
+ return scale * input
84
+
85
+
86
+ class InvertedResidualConfig:
87
+ # Stores information listed in Tables 1 and 2 of the MobileNetV3 paper
88
+ def __init__(
89
+ self,
90
+ input_channels: int,
91
+ kernel: int,
92
+ expanded_channels: int,
93
+ out_channels: int,
94
+ use_se: bool,
95
+ activation: str,
96
+ stride: tuple[int],
97
+ dilation: tuple[int],
98
+ width_mult: float,
99
+ ):
100
+ self.input_channels = self.adjust_channels(input_channels, width_mult)
101
+ self.kernel = kernel
102
+ self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
103
+ self.out_channels = self.adjust_channels(out_channels, width_mult)
104
+ self.use_se = use_se
105
+ self.use_hs = activation == "HS"
106
+ self.stride = stride
107
+ self.dilation = dilation
108
+ self.f_dim = None
109
+ self.t_dim = None
110
+
111
+ @staticmethod
112
+ def adjust_channels(channels: int, width_mult: float):
113
+ return make_divisible(channels * width_mult, 8)
114
+
115
+ def out_size(self, in_size, idx=None):
116
+ dilation = self.dilation if idx is None else self.dilation[idx]
117
+ padding = (self.kernel - 1) // 2 * dilation
118
+ stride = self.stride if idx is None else self.stride[idx]
119
+ return cnn_out_size(in_size, padding, dilation, self.kernel, stride)
120
+
121
+
122
+ class InvertedResidual(nn.Module):
123
+ def __init__(
124
+ self,
125
+ cnf: InvertedResidualConfig,
126
+ se_cnf: Dict,
127
+ norm_layer: Callable[..., nn.Module],
128
+ depthwise_norm_layer: Callable[..., nn.Module]
129
+ ):
130
+ super().__init__()
131
+
132
+ if not (1 <= cnf.stride[0] <= 2 or 1 <= cnf.stride[1] <= 2):
133
+ raise ValueError("illegal stride value")
134
+
135
+ self.use_res_connect = cnf.stride[0] == 1 and cnf.stride[1] == 1 and cnf.input_channels == cnf.out_channels
136
+
137
+ layers: List[nn.Module] = []
138
+ activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
139
+
140
+ # expand
141
+ if cnf.expanded_channels != cnf.input_channels:
142
+ layers.append(
143
+ ConvNormActivation(
144
+ cnf.input_channels,
145
+ cnf.expanded_channels,
146
+ kernel_size=1,
147
+ norm_layer=norm_layer,
148
+ activation_layer=activation_layer,
149
+ )
150
+ )
151
+
152
+ # depthwise
153
+ d = cnf.dilation > 1 if isinstance(cnf.dilation, int) else cnf.dilation[1] > 1
154
+ stride = [cnf.stride, cnf.stride] if isinstance(cnf.stride, int) else list(cnf.stride)
155
+
156
+ if d:
157
+ stride[1] = 1
158
+
159
+ layers.append(
160
+ ConvNormActivation(
161
+ cnf.expanded_channels,
162
+ cnf.expanded_channels,
163
+ kernel_size=cnf.kernel,
164
+ stride=tuple(stride),
165
+ dilation=cnf.dilation,
166
+ groups=cnf.expanded_channels,
167
+ norm_layer=depthwise_norm_layer,
168
+ activation_layer=activation_layer,
169
+ )
170
+ )
171
+ if cnf.use_se and se_cnf['se_dims'] is not None:
172
+ layers.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))
173
+
174
+ # project
175
+ layers.append(
176
+ ConvNormActivation(
177
+ cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
178
+ )
179
+ )
180
+
181
+ self.block = nn.Sequential(*layers)
182
+ self.out_channels = cnf.out_channels
183
+ # self._is_cn = cnf.stride[0] > 1 and cnf.stride[1] > 1
184
+
185
+ def forward(self, inp: Tensor) -> Tensor:
186
+ result = self.block(inp)
187
+ if self.use_res_connect:
188
+ result += inp
189
+ return result
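A minimal sketch of the channel variant of `SqueezeExcitation` above (`se_dim=1` selects the channel axis, so the block averages over frequency and time and rescales each channel); all sizes are illustrative:

```python
import torch

se = SqueezeExcitation(input_dim=32, squeeze_dim=8, se_dim=1)
x = torch.randn(4, 32, 16, 100)   # batch x channels x freq x time
y = se(x)
print(y.shape)  # torch.Size([4, 32, 16, 100]): input rescaled per channel
```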
models/frame_mn/model.py ADDED
@@ -0,0 +1,356 @@
1
+ import os
2
+ import urllib.parse
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Sequence, Tuple
5
+
6
+ import torch
7
+ from torch import nn, Tensor
8
+ from torch.hub import load_state_dict_from_url
9
+ from torchvision.ops.misc import ConvNormActivation
10
+
11
+ from models.frame_mn.block_types import InvertedResidualConfig, InvertedResidual
12
+ from models.frame_mn.utils import cnn_out_size
13
+
14
+ # Adapted version of MobileNetV3 pytorch implementation
15
+ # https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py
16
+
17
+ # points to github releases
18
+ model_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/"
19
+ # folder to store downloaded models to
20
+ model_dir = "resources"
21
+
22
+ pretrained_models = {
23
+ # pytorch ImageNet pre-trained model
24
+ # own ImageNet pre-trained models will follow
25
+ # NOTE: for easy loading we provide the adapted state dict ready for AudioSet training (1 input channel,
26
+ # 527 output classes)
27
+ # NOTE: the classifier is just a random initialization, feature extractor (conv layers) is pre-trained
28
+ "mn10_im_pytorch": urllib.parse.urljoin(model_url, "mn10_im_pytorch.pt"),
29
+ # self-trained models on ImageNet
30
+ "mn01_im": urllib.parse.urljoin(model_url, "mn01_im.pt"),
31
+ "mn02_im": urllib.parse.urljoin(model_url, "mn02_im.pt"),
32
+ "mn04_im": urllib.parse.urljoin(model_url, "mn04_im.pt"),
33
+ "mn05_im": urllib.parse.urljoin(model_url, "mn05_im.pt"),
34
+ "mn06_im": urllib.parse.urljoin(model_url, "mn06_im.pt"),
35
+ "mn10_im": urllib.parse.urljoin(model_url, "mn10_im.pt"),
36
+ "mn20_im": urllib.parse.urljoin(model_url, "mn20_im.pt"),
37
+ "mn30_im": urllib.parse.urljoin(model_url, "mn30_im.pt"),
38
+ "mn40_im": urllib.parse.urljoin(model_url, "mn40_im.pt"),
39
+ # Models trained on AudioSet
40
+ "mn01_as": urllib.parse.urljoin(model_url, "mn01_as_mAP_298.pt"),
41
+ "mn02_as": urllib.parse.urljoin(model_url, "mn02_as_mAP_378.pt"),
42
+ "mn04_as": urllib.parse.urljoin(model_url, "mn04_as_mAP_432.pt"),
43
+ "mn05_as": urllib.parse.urljoin(model_url, "mn05_as_mAP_443.pt"),
44
+ "mn10_as": urllib.parse.urljoin(model_url, "mn10_as_mAP_471.pt"),
45
+ "mn20_as": urllib.parse.urljoin(model_url, "mn20_as_mAP_478.pt"),
46
+ "mn30_as": urllib.parse.urljoin(model_url, "mn30_as_mAP_482.pt"),
47
+ "mn40_as": urllib.parse.urljoin(model_url, "mn40_as_mAP_484.pt"),
48
+ "mn40_as(2)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483.pt"),
49
+ "mn40_as(3)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483(2).pt"),
50
+ "mn40_as_no_im_pre": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483.pt"),
51
+ "mn40_as_no_im_pre(2)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483(2).pt"),
52
+ "mn40_as_no_im_pre(3)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_482.pt"),
53
+ "mn40_as_ext": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_487.pt"),
54
+ "mn40_as_ext(2)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_486.pt"),
55
+ "mn40_as_ext(3)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_485.pt"),
56
+ # varying hop size (time resolution)
57
+ "mn10_as_hop_5": urllib.parse.urljoin(model_url, "mn10_as_hop_5_mAP_475.pt"),
58
+ "mn10_as_hop_15": urllib.parse.urljoin(model_url, "mn10_as_hop_15_mAP_463.pt"),
59
+ "mn10_as_hop_20": urllib.parse.urljoin(model_url, "mn10_as_hop_20_mAP_456.pt"),
60
+ "mn10_as_hop_25": urllib.parse.urljoin(model_url, "mn10_as_hop_25_mAP_447.pt"),
61
+ # varying n_mels (frequency resolution)
62
+ "mn10_as_mels_40": urllib.parse.urljoin(model_url, "mn10_as_mels_40_mAP_453.pt"),
63
+ "mn10_as_mels_64": urllib.parse.urljoin(model_url, "mn10_as_mels_64_mAP_461.pt"),
64
+ "mn10_as_mels_256": urllib.parse.urljoin(model_url, "mn10_as_mels_256_mAP_474.pt"),
65
+ # fully-convolutional head
66
+ "mn10_as_fc": urllib.parse.urljoin(model_url, "mn10_as_fc_mAP_465.pt"),
67
+ "mn10_as_fc_s2221": urllib.parse.urljoin(model_url, "mn10_as_fc_s2221_mAP_466.pt"),
68
+ "mn10_as_fc_s2211": urllib.parse.urljoin(model_url, "mn10_as_fc_s2211_mAP_466.pt"),
69
+ }
70
+
71
+
72
+ class MN(nn.Module):
73
+ def __init__(
74
+ self,
75
+ inverted_residual_setting: List[InvertedResidualConfig],
76
+ block: Optional[Callable[..., nn.Module]] = None,
77
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
78
+ in_conv_kernel: int = 3,
79
+ in_conv_stride: int = 2,
80
+ in_channels: int = 1,
81
+ **kwargs: Any,
82
+ ) -> None:
83
+ """
84
+ MobileNet V3 main class
85
+
86
+ Args:
87
+ inverted_residual_setting (List[InvertedResidualConfig]): Network structure
88
+ block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for models
89
+ norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
90
+ in_conv_kernel (int): Size of kernel for first convolution
91
+ in_conv_stride (int): Size of stride for first convolution
92
+ in_channels (int): Number of input channels
93
+ """
94
+ super(MN, self).__init__()
95
+
96
+ if not inverted_residual_setting:
97
+ raise ValueError("The inverted_residual_setting should not be empty")
98
+ elif not (
99
+ isinstance(inverted_residual_setting, Sequence)
100
+ and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
101
+ ):
102
+ raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
103
+
104
+ if block is None:
105
+ block = InvertedResidual
106
+
107
+ depthwise_norm_layer = norm_layer = \
108
+ norm_layer if norm_layer is not None else partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
109
+
110
+ layers: List[nn.Module] = []
111
+
112
+ kernel_sizes = [in_conv_kernel]
113
+ strides = [in_conv_stride]
114
+
115
+ # building first layer
116
+ firstconv_output_channels = inverted_residual_setting[0].input_channels
117
+ layers.append(
118
+ ConvNormActivation(
119
+ in_channels,
120
+ firstconv_output_channels,
121
+ kernel_size=in_conv_kernel,
122
+ stride=in_conv_stride,
123
+ norm_layer=norm_layer,
124
+ activation_layer=nn.Hardswish,
125
+ )
126
+ )
127
+
128
+ # get squeeze excitation config
129
+ se_cnf = kwargs.get('se_conf', None)
130
+
131
+ # building inverted residual blocks
132
+ # - keep track of size of frequency and time dimensions for possible application of Squeeze-and-Excitation
133
+ # on the frequency/time dimension
134
+ # - applying Squeeze-and-Excitation on the time dimension is not recommended as this constrains the network to
135
+ # a particular length of the audio clip, whereas Squeeze-and-Excitation on the frequency bands is fine,
136
+ # as the number of frequency bands is usually not changing
137
+ f_dim, t_dim = kwargs.get('input_dims', (128, 1000))
138
+ # take into account first conv layer
139
+ f_dim = cnn_out_size(f_dim, 1, 1, 3, 2)
140
+ t_dim = cnn_out_size(t_dim, 1, 1, 3, 2)
141
+ for cnf in inverted_residual_setting:
142
+ f_dim = cnf.out_size(f_dim, idx=0)
143
+ t_dim = cnf.out_size(t_dim, idx=1)
144
+ cnf.f_dim, cnf.t_dim = f_dim, t_dim # update dimensions in block config
145
+ layers.append(block(cnf, se_cnf, norm_layer, depthwise_norm_layer))
146
+ kernel_sizes.append(cnf.kernel)
147
+ strides.append(cnf.stride)
148
+
149
+ # building last several layers
150
+ lastconv_input_channels = inverted_residual_setting[-1].out_channels
151
+ lastconv_output_channels = 6 * lastconv_input_channels
152
+ self.lastconv_output_channels = lastconv_output_channels
153
+ layers.append(
154
+ ConvNormActivation(
155
+ lastconv_input_channels,
156
+ lastconv_output_channels,
157
+ kernel_size=1,
158
+ norm_layer=norm_layer,
159
+ activation_layer=nn.Hardswish,
160
+ )
161
+ )
162
+
163
+ self.features = nn.Sequential(*layers)
164
+
165
+ # no prediction head needed - we want to use Frame-MobileNet to extract a 3D sequence
166
+ # i.e.: batch size x sequence length x channel dimension
167
+
168
+ for m in self.modules():
169
+ if isinstance(m, nn.Conv2d):
170
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
171
+ if m.bias is not None:
172
+ nn.init.zeros_(m.bias)
173
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
174
+ nn.init.ones_(m.weight)
175
+ nn.init.zeros_(m.bias)
176
+ elif isinstance(m, nn.Linear):
177
+ nn.init.normal_(m.weight, 0, 0.01)
178
+ if m.bias is not None:
179
+ nn.init.zeros_(m.bias)
180
+
181
+ def _forward_impl(self, x: Tensor, return_fmaps: bool = False) -> Tensor:
182
+ fmaps = []
183
+
184
+ for i, layer in enumerate(self.features):
185
+ x = layer(x)
186
+ if return_fmaps:
187
+ fmaps.append(x)
188
+
189
+ # reshape: batch size x channels x frequency bands x time -> batch size x time x channels
190
+ # works, because frequency dimension is exactly 1
191
+ x = x.squeeze(2).permute(0, 2, 1)
192
+ return x
193
+
194
+ def forward(self, x: Tensor) -> Tensor:
195
+ return self._forward_impl(x)
196
+
197
+ def load_model(self, path, wandb_id):
198
+ ckpt_path = os.path.join(path, wandb_id + ".ckpt")
199
+
200
+ pretrained_weights = torch.load(ckpt_path, map_location="cpu")["state_dict"]
201
+ pretrained_weights = {k[10:]: v for k, v in pretrained_weights.items() if k[:10] == "net.model."}
202
+ self.load_state_dict(pretrained_weights)
203
+
204
+ print("Loaded model successfully. Wandb_id:", wandb_id)
205
+
206
+
207
+ def _mobilenet_v3_conf(
208
+ width_mult: float = 1.0,
209
+ reduced_tail: bool = False,
210
+ dilated: bool = False,
211
+ strides: Tuple[int] = None,
212
+ dilation_list_t_dim: Optional[List[int]] = None,
213
+ **kwargs
214
+ ):
215
+ reduce_divider = 2 if reduced_tail else 1
216
+ if dilation_list_t_dim is None:
217
+ dilation_list_t_dim = [1] * 15
218
+ if dilated:
219
+ dilation_list_t_dim[-3:] = [2] * 3
220
+
221
+ print("dilation_list_t_dim: ")
222
+ print(dilation_list_t_dim)
223
+
224
+ bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
225
+ adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
226
+
227
+ if strides is None:
228
+ # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
229
+ f_strides = (1, 2, 2, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2)
230
+ t_strides = (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
231
+
232
+ strides = tuple(zip(f_strides, t_strides))
233
+
234
+ # InvertedResidualConfig:
235
+ # input_channels, kernel, expanded_channels, out_channels, use_se, activation, stride, dilation
236
+ inverted_residual_setting = [
237
+ bneck_conf(16, 3, 16, 16, False, "RE", strides[0], (1, dilation_list_t_dim[0])), # 0
238
+ bneck_conf(16, 3, 64, 24, False, "RE", strides[1], (1, dilation_list_t_dim[1])), # 1 - C1
239
+ bneck_conf(24, 3, 72, 24, False, "RE", strides[2], (1, dilation_list_t_dim[2])), # 2
240
+ bneck_conf(24, 5, 72, 40, True, "RE", strides[3], (1, dilation_list_t_dim[3])), # 3 - C2
241
+ bneck_conf(40, 5, 120, 40, True, "RE", strides[4], (1, dilation_list_t_dim[4])), # 4
242
+ bneck_conf(40, 5, 120, 40, True, "RE", strides[5], (1, dilation_list_t_dim[5])), # 5
243
+ bneck_conf(40, 3, 240, 80, False, "HS", strides[6], (1, dilation_list_t_dim[6])), # 6 - C3
244
+ bneck_conf(80, 3, 200, 80, False, "HS", strides[7], (1, dilation_list_t_dim[7])), # 7
245
+ bneck_conf(80, 3, 184, 80, False, "HS", strides[8], (1, dilation_list_t_dim[8])), # 8
246
+ bneck_conf(80, 3, 184, 80, False, "HS", strides[9], (1, dilation_list_t_dim[9])), # 9
247
+ bneck_conf(80, 3, 480, 112, True, "HS", strides[10], (1, dilation_list_t_dim[10])), # 10
248
+ bneck_conf(112, 3, 672, 112, True, "HS", strides[11], (1, dilation_list_t_dim[11])), # 11
249
+ bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", strides[12], (1, dilation_list_t_dim[12])),
250
+ # 12 - C4 # dilation
251
+ bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", strides[13],
252
+ (1, dilation_list_t_dim[13])), # 13 # dilation
253
+ bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", strides[14],
254
+ (1, dilation_list_t_dim[14])), # 14 # dilation
255
+ ]
256
+ last_channel = adjust_channels(1280 // reduce_divider)
257
+
258
+ return inverted_residual_setting, last_channel
259
+
260
+
261
+ def _mobilenet_v3(
262
+ inverted_residual_setting: List[InvertedResidualConfig],
263
+ pretrained_name: str,
264
+ **kwargs: Any,
265
+ ):
266
+ model = MN(inverted_residual_setting, **kwargs)
267
+
268
+ if pretrained_name in pretrained_models:
269
+ model_url = pretrained_models.get(pretrained_name)
270
+ state_dict = load_state_dict_from_url(model_url, model_dir=model_dir, map_location="cpu")
271
+ if kwargs['head_type'] == "mlp":
272
+ num_classes = state_dict['classifier.5.bias'].size(0)
273
+ elif kwargs['head_type'] == "fully_convolutional":
274
+ num_classes = state_dict['classifier.1.bias'].size(0)
275
+ else:
276
+ print("Loading weights for classifier only implemented for head types 'mlp' and 'fully_convolutional'")
277
+ num_classes = -1
278
+ if kwargs['num_classes'] != num_classes:
279
+ # if the number of logits is not matching the state dict,
280
+ # drop the corresponding pre-trained part
281
+ pretrain_logits = state_dict['classifier.5.bias'].size(0) if kwargs['head_type'] == "mlp" \
282
+ else state_dict['classifier.1.bias'].size(0)
283
+ print(f"Number of classes defined: {kwargs['num_classes']}, "
284
+ f"but try to load pre-trained layer with logits: {pretrain_logits}\n"
285
+ "Dropping last layer.")
286
+ if kwargs['head_type'] == "mlp":
287
+ del state_dict['classifier.5.weight']
288
+ del state_dict['classifier.5.bias']
289
+ else:
290
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith('classifier')}
291
+ try:
292
+ model.load_state_dict(state_dict)
293
+ except RuntimeError as e:
294
+ print(str(e))
295
+ print("Loading weights pre-trained weights in a non-strict manner.")
296
+ model.load_state_dict(state_dict, strict=False)
297
+ elif pretrained_name:
298
+ raise NotImplementedError(f"Model name '{pretrained_name}' unknown.")
299
+ return model
300
+
301
+
302
+ def mobilenet_v3(pretrained_name: str = None, **kwargs: Any) \
303
+ -> MN:
304
+ """
305
+ Constructs a MobileNetV3 architecture from
306
+ "Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>".
307
+ """
308
+ inverted_residual_setting, last_channel = _mobilenet_v3_conf(**kwargs)
309
+ return _mobilenet_v3(inverted_residual_setting, pretrained_name, **kwargs)
310
+
311
+
312
+ def get_model(pretrained_name: str = None, width_mult: float = 1.0,
313
+ reduced_tail: bool = False, dilated: bool = False, dilation_list_t_dim=None,
314
+ strides: Tuple[int, int, int, int] = None,
315
+ head_type: str = "mlp", multihead_attention_heads: int = 4, input_dim_f: int = 128,
316
+ input_dim_t: int = 1000, se_dims: str = 'c', se_agg: str = "max", se_r: int = 4):
317
+ """
318
+ Arguments to modify the instantiation of a MobileNetv3
319
+
320
+ Args:
321
+ pretrained_name (str): Specifies name of pre-trained model to load
322
+ width_mult (float): Scales width of network
323
+ reduced_tail (bool): Scales down network tail
324
+ dilated (bool): Applies dilated convolution to network tail
325
+ dilation_list_t_dim (List): List of dilation factors to apply to network tail
326
+ strides (Tuple): Strides that are set to '2' in original implementation;
327
+ might be changed to modify the size of receptive field and the downsampling factor in
328
+ time and frequency dimension
329
+ head_type (str): decides which classification head to use
330
+ multihead_attention_heads (int): number of heads in case head_type 'multihead_attention' is used
331
+ input_dim_f (int): number of frequency bands
332
+ input_dim_t (int): number of time frames
333
+ se_dims (str): dimensions to apply squeeze-excitation on, given as letters 'c' - channel,
334
+ 'f' - frequency, 't' - time; if multiple dimensions are chosen, squeeze-excitation is applied
335
+ concurrently and the se layer outputs are fused by the se_agg operation
336
+ se_agg (str): operation to fuse output of concurrent se layers
337
+ se_r (int): squeeze excitation bottleneck size
338
+ """
339
+
340
+ dim_map = {'c': 1, 'f': 2, 't': 3}
341
+ assert len(se_dims) <= 3 and all([s in dim_map.keys() for s in se_dims]) or se_dims == 'none'
342
+ input_dims = (input_dim_f, input_dim_t)
343
+ if se_dims == 'none':
344
+ se_dims = None
345
+ else:
346
+ se_dims = [dim_map[s] for s in se_dims]
347
+ se_conf = dict(se_dims=se_dims, se_agg=se_agg, se_r=se_r)
348
+ m = mobilenet_v3(pretrained_name=pretrained_name,
349
+ width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated,
350
+ dilation_list_t_dim=dilation_list_t_dim,
351
+ strides=strides,
352
+ head_type=head_type, multihead_attention_heads=multihead_attention_heads,
353
+ input_dims=input_dims, se_conf=se_conf
354
+ )
355
+ print(m)
356
+ return m
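A minimal sketch of building a width-1.0 Frame-MobileNet with `get_model` and running a spectrogram through it, using the default (128 mel bands, 1000 frames) input dimensions; with these defaults the frequency axis collapses to 1 and time is downsampled by a factor of 4:

```python
import torch

model = get_model(width_mult=1.0)    # note: get_model also prints the module tree
model.eval()
spec = torch.randn(1, 1, 128, 1000)  # batch x channel x mel x time
with torch.no_grad():
    seq = model(spec)
print(seq.shape)  # batch x time frames x channels, e.g. torch.Size([1, 250, 960])
```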
models/frame_mn/utils.py ADDED
@@ -0,0 +1,93 @@
1
+ import math
2
+ from typing import Optional, Callable
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+
7
+
8
+ def NAME_TO_WIDTH(name):
9
+ frame_mn_map = {
10
+ 'frame_mn01': 0.1,
11
+ 'frame_mn02': 0.2,
12
+ 'frame_mn04': 0.4,
13
+ 'frame_mn05': 0.5,
14
+ 'frame_mn06': 0.6,
15
+ 'frame_mn08': 0.8,
16
+ 'frame_mn10': 1.0,
17
+ 'frame_mn12': 1.2,
18
+ 'frame_mn14': 1.4,
19
+ 'frame_mn16': 1.6,
20
+ 'frame_mn20': 2.0,
21
+ 'frame_mn30': 3.0,
22
+ 'frame_mn40': 4.0,
23
+ }
24
+
25
+ frame_dymn_map = {
26
+ 'frame_dymn04': 0.4,
27
+ 'frame_dymn10': 1.0,
28
+ 'frame_dymn20': 2.0,
29
+ }
30
+
31
+ try:
32
+ if name.startswith('frame_dymn'):
33
+ w = frame_dymn_map[name[:len('frame_dymnxx')]]
34
+ else:
35
+ w = frame_mn_map[name[:len('frame_mnxx')]]
36
+ except (KeyError, AttributeError):
37
+ w = 1.0
38
+
39
+ return w
40
+
41
+
42
+ def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
43
+ """
44
+ This function is taken from the original tf repo.
45
+ It ensures that all layers have a channel number that is divisible by 8.
46
+ It can be seen here:
47
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
48
+ """
49
+ if min_value is None:
50
+ min_value = divisor
51
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
52
+ # Make sure that round down does not go down by more than 10%.
53
+ if new_v < 0.9 * v:
54
+ new_v += divisor
55
+ return new_v
56
+
57
+
58
+ def cnn_out_size(in_size, padding, dilation, kernel, stride):
59
+ s = in_size + 2 * padding - dilation * (kernel - 1) - 1
60
+ return math.floor(s / stride + 1)
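Worked examples for the two helpers above: `make_divisible` rounds a scaled channel count to a multiple of the divisor without dropping more than 10% below the requested value, and `cnn_out_size` is the standard convolution output-size formula:

```python
print(make_divisible(32 * 0.4, 8))  # 12.8 -> 16 (nearest multiple of 8)
print(make_divisible(24 * 0.4, 8))  # 9.6 -> 8 would lose >10%, so bumped to 16
print(cnn_out_size(128, padding=1, dilation=1, kernel=3, stride=2))  # 64
```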
61
+
62
+
63
+ def collapse_dim(x: Tensor, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
64
+ combine_dim: int = None):
65
+ """
66
+ Collapses dimension of multi-dimensional tensor by pooling or combining dimensions
67
+ :param x: input Tensor
68
+ :param dim: dimension to collapse
69
+ :param mode: 'pool' or 'combine'
70
+ :param pool_fn: function to be applied in case of pooling
71
+ :param combine_dim: dimension to join 'dim' to
72
+ :return: collapsed tensor
73
+ """
74
+ if mode == "pool":
75
+ return pool_fn(x, dim)
76
+ elif mode == "combine":
77
+ s = list(x.size())
78
+ s[combine_dim] *= s[dim]
79
+ s[dim] = 1
80
+ return x.view(s)
81
+
82
+
83
+ class CollapseDim(nn.Module):
84
+ def __init__(self, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
85
+ combine_dim: int = None):
86
+ super(CollapseDim, self).__init__()
87
+ self.dim = dim
88
+ self.mode = mode
89
+ self.pool_fn = pool_fn
90
+ self.combine_dim = combine_dim
91
+
92
+ def forward(self, x):
93
+ return collapse_dim(x, dim=self.dim, mode=self.mode, pool_fn=self.pool_fn, combine_dim=self.combine_dim)
models/frame_passt/fpasst.py ADDED
@@ -0,0 +1,963 @@
1
+ """
2
+ Most of this code comes from the timm library.
3
+ We tried to disentangle it from the timm library version.
4
+
5
+ Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
6
+
7
+ """
8
+ import collections
9
+ import logging
10
+ import math
11
+ import os
12
+ import warnings
13
+ from collections import OrderedDict
14
+ from functools import partial
15
+ from itertools import repeat
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from models.frame_passt.vit_helpers import (DropPath, trunc_normal_,
21
+ build_model_with_cfg, adapt_input_conv)
22
+
23
+ _logger = logging.getLogger()
24
+
25
+
26
+ # From PyTorch internals
27
+ def _ntuple(n):
28
+ def parse(x):
29
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
30
+ return tuple(x)
31
+ return tuple(repeat(x, n))
32
+
33
+ return parse
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
39
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
40
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
41
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
42
+
43
+
44
+ def _cfg(url='', **kwargs):
45
+ return {
46
+ 'url': url,
47
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
48
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
49
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
50
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
51
+ **kwargs
52
+ }
53
+
54
+
55
+ default_cfgs = {
56
+ # patch models (weights from official Google JAX impl)
57
+ 'vit_tiny_patch16_224': _cfg(
58
+ url='https://storage.googleapis.com/vit_models/augreg/'
59
+ 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
60
+ 'vit_tiny_patch16_384': _cfg(
61
+ url='https://storage.googleapis.com/vit_models/augreg/'
62
+ 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
63
+ input_size=(3, 384, 384), crop_pct=1.0),
64
+ 'vit_small_patch32_224': _cfg(
65
+ url='https://storage.googleapis.com/vit_models/augreg/'
66
+ 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
67
+ 'vit_small_patch32_384': _cfg(
68
+ url='https://storage.googleapis.com/vit_models/augreg/'
69
+ 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
70
+ input_size=(3, 384, 384), crop_pct=1.0),
71
+ 'vit_small_patch16_224': _cfg(
72
+ url='https://storage.googleapis.com/vit_models/augreg/'
73
+ 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
74
+ 'vit_small_patch16_384': _cfg(
75
+ url='https://storage.googleapis.com/vit_models/augreg/'
76
+ 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
77
+ input_size=(3, 384, 384), crop_pct=1.0),
78
+ 'vit_base_patch32_224': _cfg(
79
+ url='https://storage.googleapis.com/vit_models/augreg/'
80
+ 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
81
+ 'vit_base_patch32_384': _cfg(
82
+ url='https://storage.googleapis.com/vit_models/augreg/'
83
+ 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
84
+ input_size=(3, 384, 384), crop_pct=1.0),
85
+ 'vit_base_patch16_224': _cfg(
86
+ url='https://storage.googleapis.com/vit_models/augreg/'
87
+ 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
88
+ 'vit_base_patch16_384': _cfg(
89
+ url='https://storage.googleapis.com/vit_models/augreg/'
90
+ 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
91
+ input_size=(3, 384, 384), crop_pct=1.0),
92
+ 'vit_large_patch32_224': _cfg(
93
+ url='', # no official model weights for this combo, only for in21k
94
+ ),
95
+ 'vit_large_patch32_384': _cfg(
96
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
97
+ input_size=(3, 384, 384), crop_pct=1.0),
98
+ 'vit_large_patch16_224': _cfg(
99
+ url='https://storage.googleapis.com/vit_models/augreg/'
100
+ 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
101
+ 'vit_large_patch16_384': _cfg(
102
+ url='https://storage.googleapis.com/vit_models/augreg/'
103
+ 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
104
+ input_size=(3, 384, 384), crop_pct=1.0),
105
+
106
+ # patch models, imagenet21k (weights from official Google JAX impl)
107
+ 'vit_tiny_patch16_224_in21k': _cfg(
108
+ url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
109
+ num_classes=21843),
110
+ 'vit_small_patch32_224_in21k': _cfg(
111
+ url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
112
+ num_classes=21843),
113
+ 'vit_small_patch16_224_in21k': _cfg(
114
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
115
+ num_classes=21843),
116
+ 'vit_base_patch32_224_in21k': _cfg(
117
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
118
+ num_classes=21843),
119
+ 'vit_base_patch16_224_in21k': _cfg(
120
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
121
+ num_classes=21843),
122
+ 'vit_large_patch32_224_in21k': _cfg(
123
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
124
+ num_classes=21843),
125
+ 'vit_large_patch16_224_in21k': _cfg(
126
+ url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
127
+ num_classes=21843),
128
+ 'vit_huge_patch14_224_in21k': _cfg(
129
+ url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
130
+ hf_hub='timm/vit_huge_patch14_224_in21k',
131
+ num_classes=21843),
132
+
133
+ # SAM trained models (https://arxiv.org/abs/2106.01548)
134
+ 'vit_base_patch32_sam_224': _cfg(
135
+ url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
136
+ 'vit_base_patch16_sam_224': _cfg(
137
+ url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
138
+
139
+ # deit models (FB weights)
140
+ 'deit_tiny_patch16_224': _cfg(
141
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
142
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
143
+ 'deit_small_patch16_224': _cfg(
144
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
145
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
146
+ 'deit_base_patch16_224': _cfg(
147
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
148
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
149
+ 'deit_base_patch16_384': _cfg(
150
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
151
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
152
+ 'deit_tiny_distilled_patch16_224': _cfg(
153
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
154
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
155
+ 'deit_small_distilled_patch16_224': _cfg(
156
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
157
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
158
+ 'deit_base_distilled_patch16_224': _cfg(
159
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
160
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
161
+ 'deit_base_distilled_patch16_384': _cfg(
162
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
163
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
164
+ classifier=('head', 'head_dist')),
165
+
166
+ # ViT ImageNet-21K-P pretraining by MIIL
167
+ 'vit_base_patch16_224_miil_in21k': _cfg(
168
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
169
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
170
+ ),
171
+ 'vit_base_patch16_224_miil': _cfg(
172
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
173
+ '/vit_base_patch16_224_1k_miil_84_4.pth',
174
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
175
+ ),
176
+ # PaSST
177
+ 'passt_s_swa_p16_128_ap476': _cfg(
178
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt',
179
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
180
+ classifier=('head.1', 'head_dist'), num_classes=527),
181
+ 'passt_s_kd_p16_128_ap486': _cfg(
182
+ url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',
183
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
184
+ classifier=('head.1', 'head_dist'), num_classes=527),
185
+ 'passt_l_kd_p16_128_ap47': _cfg(
186
+ url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt',
187
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
188
+ classifier=('head.1', 'head_dist'), num_classes=527),
189
+ 'passt_s_swa_p16_128_ap4761': _cfg(
190
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.4761-swa.pt',
191
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
192
+ classifier=('head.1', 'head_dist'), num_classes=527),
193
+ 'passt_s_p16_128_ap472': _cfg(
194
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.472.pt',
195
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
196
+ classifier=('head.1', 'head_dist'), num_classes=527),
197
+ 'passt_s_p16_s16_128_ap468': _cfg(
198
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt',
199
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
200
+ classifier=('head.1', 'head_dist'), num_classes=527),
201
+ 'passt_s_swa_p16_s16_128_ap473': _cfg(
202
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.473-swa.pt',
203
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
204
+ classifier=('head.1', 'head_dist'), num_classes=527),
205
+ 'passt_s_swa_p16_s14_128_ap471': _cfg(
206
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.471-swa.pt',
207
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
208
+ classifier=('head.1', 'head_dist'), num_classes=527),
209
+ 'passt_s_p16_s14_128_ap469': _cfg(
210
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.469.pt',
211
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
212
+ classifier=('head.1', 'head_dist'), num_classes=527),
213
+ 'passt_s_swa_p16_s12_128_ap473': _cfg(
214
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.473-swa.pt',
215
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
216
+ classifier=('head.1', 'head_dist'), num_classes=527),
217
+ 'passt_s_p16_s12_128_ap470': _cfg(
218
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.470.pt',
219
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
220
+ classifier=('head.1', 'head_dist'), num_classes=527),
221
+ 'passt_s_swa_f128_stfthop100_p16_s10_ap473': _cfg(
222
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop100-p16-s10-ap.473-swa.pt',
223
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
224
+ classifier=('head.1', 'head_dist'), num_classes=527),
225
+ 'passt_s_swa_f128_stfthop160_p16_s10_ap473': _cfg(
226
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop160-p16-s10-ap.473-swa.pt',
227
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
228
+ classifier=('head.1', 'head_dist'), num_classes=527),
229
+ 'passt-s-f128-20sec-p16-s10-ap474-swa': _cfg(
230
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-20sec-p16-s10-ap.474-swa.pt',
231
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
232
+ classifier=('head.1', 'head_dist'), num_classes=527),
233
+ 'passt-s-f128-30sec-p16-s10-ap473-swa': _cfg(
234
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-30sec-p16-s10-ap.473-swa.pt',
235
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3000), crop_pct=1.0,
236
+ classifier=('head.1', 'head_dist'), num_classes=527),
237
+ 'openmic2008_passt_u_f128_p16_s10_ap85_swa': _cfg(
238
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85-swa.pt',
239
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
240
+ classifier=('head.1', 'head_dist'), num_classes=20),
241
+ 'openmic2008_passt_u_f128_p16_s10_ap85': _cfg(
242
+ url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85.pt',
243
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
244
+ classifier=('head.1', 'head_dist'), num_classes=20),
245
+ }
246
+
247
+
248
+ class Mlp(nn.Module):
249
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
250
+ """
251
+
252
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
253
+ super().__init__()
254
+ out_features = out_features or in_features
255
+ hidden_features = hidden_features or in_features
256
+ self.fc1 = nn.Linear(in_features, hidden_features)
257
+ self.act = act_layer()
258
+ self.fc2 = nn.Linear(hidden_features, out_features)
259
+ self.drop = nn.Dropout(drop)
260
+
261
+ def forward(self, x):
262
+ x = self.fc1(x)
263
+ x = self.act(x)
264
+ x = self.drop(x)
265
+ x = self.fc2(x)
266
+ x = self.drop(x)
267
+ return x
268
+
269
+
270
+ first_RUN = True
271
+
272
+ PLUS1_TRICK = False
273
+
274
+
275
+ class PatchEmbed(nn.Module):
276
+ """ 2D Image to Patch Embedding
277
+ """
278
+
279
+ def __init__(self, img_size=224, in_chans=1, frame_nr=1, stride=1, overlap=1, embed_dim=768, norm_layer=None):
280
+ super().__init__()
281
+ img_size = to_2tuple(img_size)
284
+ self.img_size = img_size
285
+ self.frame_nr = frame_nr
286
+ self.stride = stride
287
+ self.seq_len = int(img_size[1]) // frame_nr
288
+ self.num_patches = self.seq_len // stride
289
+ self.embed_dim = embed_dim
290
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=(int(img_size[0]), stride + overlap),
291
+ stride=stride, padding=(0, 1)) # 128 x 2 kernel
292
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
293
+
294
+ def forward(self, x):
295
+ B, C, F, T = x.shape
296
+ if not (F == self.img_size[0] and abs(T - self.img_size[1]) <= 1): # allows for a difference of 1
297
+ warnings.warn(f"Input image size ({F}*{T}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
298
+ x = self.proj(x)[:, :, :, 1:] # B embed_dim 1 T (F=1)
299
+ x = self.norm(x)
300
+ if first_RUN: print("self.norm(x)", x.size())
301
+ return x
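+
+ # Illustrative sketch (not part of the original file; values are hypothetical):
+ # with the frame-wise configuration, `PatchEmbed` maps one token per frame.
+ # >>> pe = PatchEmbed(img_size=(128, 250), in_chans=16, frame_nr=1, stride=1, embed_dim=768)
+ # >>> pe(torch.zeros(2, 16, 128, 250)).shape   # -> torch.Size([2, 768, 1, 250])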
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = attn_drop  # kept as a float for F.scaled_dot_product_attention
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         # only apply attention dropout while training
+         x = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
+                                            dropout_p=self.attn_drop if self.training else 0.,
+                                            is_causal=False, scale=self.scale)
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
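+
+ # For reference (a restatement, not a code change): in eval mode the fused
+ # call above is equivalent to the classic formulation
+ #   attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
+ #   x = attn @ v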
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class PaSST(nn.Module):
+     """
+     Based on the implementation of the Vision Transformer in the timm library.
+     See the get_model function for how the weights of pretrained ImageNet models are adapted.
+     """
+
+     def __init__(self, img_size=(128, 998),
+                  in_chans=1, num_classes=527, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
+                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
+                  act_layer=None, weight_init='',
+                  frame_patchout=300, frame_nr=1, pos_embed_length=1000):
+         """
+         Args:
+             img_size (int, tuple): input image size
+             in_chans (int): number of input channels
+             num_classes (int): number of classes for classification head
+             embed_dim (int): embedding dimension
+             depth (int): depth of transformer
+             num_heads (int): number of attention heads
+             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+             qkv_bias (bool): enable bias for qkv if True
+             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+             distilled (bool): model includes a distillation token and head as in DeiT models
+             drop_rate (float): dropout rate
+             attn_drop_rate (float): attention dropout rate
+             drop_path_rate (float): stochastic depth rate
+             embed_layer (nn.Module): patch embedding layer
+             norm_layer: (nn.Module): normalization layer
+             act_layer: (nn.Module): activation layer
+             weight_init: (str): weight init scheme
+             frame_patchout (int): number of frames to patch out
+             frame_nr (int): the second dimension of the proj-convolution kernel
+             pos_embed_length (int): length of the positional embedding
+         """
+         super().__init__()
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         self.num_tokens = 2 if distilled else 1
+         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+         act_layer = act_layer or nn.GELU
+         self.act_layer = act_layer()
+         self.in_chans = in_chans
+         self.frame_patchout = frame_patchout
+         self.pos_embed_len = pos_embed_length
+
+         # these three convolutions are different compared to the vanilla PaSST
+         self.conv_in_1 = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
+         self.conv_in_2 = nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.conv_in_3 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))  # 64 instead of 4
+         img_size = (img_size[0], pos_embed_length)  # e.g., (128, 250)
+
+         self.patch_embed = embed_layer(
+             img_size=img_size, in_chans=in_chans, frame_nr=frame_nr, stride=frame_nr, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+         # PaSST
+         # refer to https://arxiv.org/abs/2110.05069 Section 2
+         self.new_pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))  # for C and D tokens
+         self.freq_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, 1))  # | f
+         self.time_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, self.pos_embed_len))  # __ t
+         ####
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.blocks = nn.Sequential(*[
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+                 attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+
+         # Representation layer
+         if representation_size and not distilled:
+             self.num_features = representation_size
+             self.pre_logits = nn.Sequential(OrderedDict([
+                 ('fc', nn.Linear(embed_dim, representation_size)),
+                 ('act', nn.Tanh())
+             ]))
+         else:
+             self.pre_logits = nn.Identity()
+
+         self.init_weights(weight_init)
+
+     def init_weights(self, mode=''):
+         assert mode in ('jax', 'jax_nlhb', 'nlhb', ''), f"mode: {mode}"
+         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+         trunc_normal_(self.new_pos_embed, std=.02)
+         trunc_normal_(self.freq_new_pos_embed, std=.02)
+         trunc_normal_(self.time_new_pos_embed, std=.02)
+         if self.dist_token is not None:
+             trunc_normal_(self.dist_token, std=.02)
+         if mode.startswith('jax'):
+             # leave cls token as zeros to match jax impl
+             raise RuntimeError("Not supported yet")
+         else:
+             trunc_normal_(self.cls_token, std=.02)
+         self.apply(_init_vit_weights)
+
+     def _init_weights(self, m):
+         # this fn left here for compat with downstream users
+         _init_vit_weights(m)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'new_pos_embed', 'freq_new_pos_embed', 'time_new_pos_embed', 'cls_token', 'dist_token'}
+
+     def forward_features(self, x):
+         global first_RUN  # not jit friendly? use trace instead
+
+         # some 2D convolutions
+         f_dim = x.size(2)  # 128
+         x = self.act_layer(self.conv_in_1(x))
+         x = self.act_layer(self.conv_in_2(x))
+         x = self.act_layer(self.conv_in_3(x))
+         if first_RUN: print("after convs", x.size())
+         x = x.reshape(x.shape[0], (x.shape[1] * x.shape[2]) // f_dim, f_dim, x.shape[3])
+         if first_RUN: print("after reshape", x.size())
+
+         x = self.patch_embed(x)  # [b, e, f, t]
+         B_dim, E_dim, F_dim, T_dim = x.shape  # slow
+         if first_RUN: print(" patch_embed : ", x.shape)
+         # Adding Time/Freq information
+         if first_RUN: print(" self.time_new_pos_embed.shape", self.time_new_pos_embed.shape)
+         time_new_pos_embed = self.time_new_pos_embed
+         if x.shape[-1] < time_new_pos_embed.shape[-1]:
+             if self.training:
+                 toffset = torch.randint(1 + time_new_pos_embed.shape[-1] - x.shape[-1], (1,)).item()
+                 if first_RUN: print(f" CUT with randomoffset={toffset} time_new_pos_embed.shape",
+                                     time_new_pos_embed.shape)
+                 time_new_pos_embed = time_new_pos_embed[:, :, :, toffset:toffset + x.shape[-1]]
+             else:
+                 time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]]
+             if first_RUN: print(" CUT time_new_pos_embed.shape", time_new_pos_embed.shape)
+         else:
+             # the patches are longer than the available time encodings: cut x
+             x = x[:, :, :, :time_new_pos_embed.shape[-1]]
+         x = x + time_new_pos_embed
+         if first_RUN: print(" self.freq_new_pos_embed.shape", self.freq_new_pos_embed.shape)
+         x = x + self.freq_new_pos_embed
+
+         # Structured Patchout https://arxiv.org/abs/2110.05069 Section 2.2
+         if self.training and self.frame_patchout:
+             if first_RUN: print(f"X Before frame Patchout of {self.frame_patchout} ", x.size())
+             # ([1, 768, 1, 82])
+             random_indices = torch.randperm(T_dim)[:T_dim - self.frame_patchout].sort().values
+             x = x[:, :, :, random_indices]
+             if first_RUN: print("X after frame Patchout", x.size())
+
+         x = x.flatten(2).transpose(1, 2)
+
+         # Add the C/D tokens
+         if first_RUN: print(" self.new_pos_embed.shape", self.new_pos_embed.shape)
+         cls_tokens = self.cls_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, :1, :]
+         if first_RUN: print(" self.cls_tokens.shape", cls_tokens.shape)
+         if self.dist_token is None:
+             x = torch.cat((cls_tokens, x), dim=1)
+         else:
+             dist_token = self.dist_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, 1:, :]
+             if first_RUN: print(" self.dist_token.shape", dist_token.shape)
+             x = torch.cat((cls_tokens, dist_token, x), dim=1)
+
+         if first_RUN: print(" final sequence x", x.shape)
+         x = self.pos_drop(x)
+         x = self.blocks(x)
+         if first_RUN: print(f" after {len(self.blocks)} attention blocks x", x.shape)
+         x = self.norm(x)
+         return x
+
+     def forward(self, x):
+         global first_RUN
+         if first_RUN: print("x", x.size())
+         x = self.forward_features(x)
+         c, x = x[:, :2].mean(1), x[:, 2:]  # c: mean of cls/dist tokens, x: frame tokens
+         if first_RUN: print("x after forward_features", x.size())
+         first_RUN = False
+         return x
+
+     def load_model(self, path, wandb_id):
+         ckpt_path = os.path.join(path, wandb_id + ".ckpt")
+
+         pretrained_weights = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+         pretrained_weights = {k[10:]: v for k, v in pretrained_weights.items() if k[:10] == "net.model."}
+         self.load_state_dict(pretrained_weights)
+
+         print("Loaded model successfully. Wandb_id:", wandb_id)
+
+
+ def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
+     """ ViT weight initialization
+     * When called without name, head_bias, jax_impl args it behaves exactly the same
+       as the original timm init, for compatibility with prev hparam / downstream use cases (i.e. DeiT).
+     * When called w/ a valid name (module name) and jax_impl=True, will (hopefully) match the JAX impl
+     """
+     if isinstance(module, nn.Linear):
+         if name.startswith('head'):
+             nn.init.zeros_(module.weight)
+             nn.init.constant_(module.bias, head_bias)
+         elif name.startswith('pre_logits'):
+             lecun_normal_(module.weight)
+             nn.init.zeros_(module.bias)
+         else:
+             if jax_impl:
+                 nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     if 'mlp' in name:
+                         nn.init.normal_(module.bias, std=1e-6)
+                     else:
+                         nn.init.zeros_(module.bias)
+             else:
+                 trunc_normal_(module.weight, std=.02)
+                 if module.bias is not None:
+                     nn.init.zeros_(module.bias)
+     elif jax_impl and isinstance(module, nn.Conv2d):
+         # NOTE conv was left to the pytorch default in the original timm init
+         lecun_normal_(module.weight)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+     elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+         nn.init.zeros_(module.bias)
+         nn.init.ones_(module.weight)
+
+
+ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='bicubic'):
+     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+     _logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, posemb_new.shape,
+                  num_tokens)
+     ntok_new = posemb_new.shape[1]
+     if num_tokens:
+         posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+         ntok_new -= num_tokens
+     else:
+         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+     gs_old = int(math.sqrt(len(posemb_grid)))
+     if not len(gs_new):  # backwards compatibility
+         gs_new = [int(math.sqrt(ntok_new))] * 2
+     assert len(gs_new) >= 2
+     _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
+     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+     posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False)
+     posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+     posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+     return posemb
+
+
+ def adapt_image_pos_embed_to_passt(posemb, num_tokens=1, posemb_len=1000, mode='bicubic'):
+     # Decompose the 2D grid of image position embeddings into separate frequency and time
+     # embeddings when loading from state_dict. Adapted from
+     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+     if num_tokens:
+         posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+     else:
+         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+     gs_old = int(math.sqrt(len(posemb_grid)))
+     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+     posemb_grid = F.interpolate(posemb_grid, size=(1, posemb_len), mode=mode, align_corners=False)
+
+     freq_new_pos_embed = posemb_grid.mean(dim=3, keepdim=True)
+     time_new_pos_embed = posemb_grid.mean(dim=2, keepdim=True)
+     _logger.info('New Position cls/dstl embedding %s', posemb_tok.shape)
+     _logger.info('New FREQ Position embedding %s', freq_new_pos_embed.shape)
+     _logger.info('New TIME Position embedding %s', time_new_pos_embed.shape)
+     return posemb_tok, freq_new_pos_embed, time_new_pos_embed
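+
+ # Shape walkthrough (illustrative, assuming a DeiT-style checkpoint with a
+ # 24x24 patch grid and two prefix tokens): posemb [1, 578, 768] splits into
+ # posemb_tok [1, 2, 768] and a grid that is interpolated to [1, 768, 1, 250];
+ # averaging over the time axis gives freq_new_pos_embed [1, 768, 1, 1], and
+ # averaging over the frequency axis gives time_new_pos_embed [1, 768, 1, 250].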
+
+
+ def checkpoint_filter_fn(state_dict, model):
+     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+     out_dict = {}
+     if 'model' in state_dict:
+         # For deit models
+         state_dict = state_dict['model']
+     state_dict = {k: v for k, v in state_dict.items()}
+     if "time_new_pos_embed" not in state_dict:
+         # we are working with an ImageNet model
+         _logger.info("Adapting pos embedding from ImageNet pretrained model to PaSST.")
+         v = state_dict.pop("pos_embed")
+         new_pos_embed, freq_new_pos_embed, time_new_pos_embed = adapt_image_pos_embed_to_passt(
+             v, getattr(model, 'num_tokens', 1), model.pos_embed_len)
+         state_dict["new_pos_embed"] = new_pos_embed
+         state_dict["freq_new_pos_embed"] = freq_new_pos_embed
+         state_dict["time_new_pos_embed"] = time_new_pos_embed
+
+     for k, v in state_dict.items():
+         if 'patch_embed.proj.weight' in k:
+             embed_dim, C, H, W = v.shape
+             v = adapt_input_conv(model.in_chans, v, input_conv_name=k)
+             k1, k2 = model.patch_embed.proj.kernel_size  # e.g., 128, 2
+
+             # reshape the patchify kernel to the frame-wise kernel size
+             assert H * W == k1 * k2, "Error in the kernel size of the patch embedding"
+
+             v = v.reshape(embed_dim, model.in_chans, k1, k2)  # [embed_dim, in_chans, k1, k2]
+
+         out_dict[k] = v
+     return out_dict
+
+
+ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
+     default_cfg = default_cfg or default_cfgs[variant]
+     if kwargs.get('features_only', None):
+         raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+     # NOTE this extra code to support handling of repr size for in21k pretrained models
+     default_num_classes = default_cfg['num_classes']
+     num_classes = kwargs.get('num_classes', default_num_classes)
+     repr_size = kwargs.pop('representation_size', None)
+     if repr_size is not None and num_classes != default_num_classes:
+         # Remove representation layer if fine-tuning. This may not always be the desired action,
+         # but it is better than doing nothing by default for fine-tuning. Perhaps a better interface?
+         _logger.warning("Removing representation layer for fine-tuning.")
+         repr_size = None
+
+     model = build_model_with_cfg(
+         PaSST, variant, pretrained,
+         default_cfg=default_cfg,
+         representation_size=repr_size,
+         pretrained_filter_fn=checkpoint_filter_fn,
+         pretrained_custom_load='npz' in default_cfg['url'],
+         **kwargs)
+     return model
+
+
+ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
+     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
+     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+     NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
+     """
+     model_kwargs = dict(
+         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
+     model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
+     return model
+
+
+ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+     """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+     ImageNet-1k weights from https://github.com/facebookresearch/deit.
+     """
+     model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+     model = _create_vision_transformer(
+         'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=476 SWA \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (10, 10):
+         warnings.warn(
+             f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_swa_p16_128_ap476', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (10, 10):
+         warnings.warn(
+             f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print(
+         "\n\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768,
+                         depth=7, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (10, 10):
+         warnings.warn(
+             f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_l_kd_p16_128_ap47', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=4761 SWA \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (10, 10):
+         warnings.warn(
+             f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_swa_p16_128_ap4761', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_p16_128_ap472(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=472 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (10, 10):
+         warnings.warn(
+             f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_p16_128_ap472', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=470 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (12, 12):
+         warnings.warn(
+             f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_p16_s12_128_ap470', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
+     print("\n\n Loading PaSST trained on AudioSet with 20-second time encodings and an STFT hop of 160 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     model = _create_vision_transformer(
+         'passt-s-f128-20sec-p16-s10-ap474-swa', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
+     print("\n\n Loading PaSST trained on AudioSet with 30-second time encodings and an STFT hop of 160 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     model = _create_vision_transformer(
+         'passt-s-f128-30sec-p16-s10-ap473-swa', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=473 SWA \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (12, 12):
+         warnings.warn(
+             f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_swa_p16_s12_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=469 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (14, 14):
+         warnings.warn(
+             f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_p16_s14_128_ap469', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=471 SWA \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (14, 14):
+         warnings.warn(
+             f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_swa_p16_s14_128_ap471', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=473 SWA \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (16, 16):
+         warnings.warn(
+             f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_swa_p16_s16_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
+     """ PaSST pre-trained on AudioSet
+     """
+     print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=468 \n\n")
+     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+     if model_kwargs.get("stride") != (16, 16):
+         warnings.warn(
+             f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+     model = _create_vision_transformer(
+         'passt_s_p16_s16_128_ap468', pretrained=pretrained, distilled=True, **model_kwargs)
+     return model
+
+
+ def fix_embedding_layer(model, embed="default"):
+     # "overlap" and "am_keepconv" require the PatchEmbed variants from the original PaSST repository
+     if embed == "default":
+         return model
+     if embed == "overlap":
+         model.patch_embed = PatchEmbedAdaptiveMean(replace=model.patch_embed)
+     if embed == "am_keepconv":
+         model.patch_embed = PatchEmbedAdaptiveMeanKeepConv(replace=model.patch_embed)
+     return model
+
+
+ def lighten_model(model, cut_depth=0):
+     if cut_depth == 0:
+         return model
+     if cut_depth:
+         if cut_depth < 0:
+             print(f"\n Reducing model depth by keeping every {-cut_depth}-th intermediate layer \n\n")
+         else:
+             print(f"\n Reducing model depth by {cut_depth} \n\n")
+         if len(model.blocks) < cut_depth + 2:
+             raise ValueError(f"cut_depth for a ViT with {len(model.blocks)} "
+                              f"layers should be between 1 and {len(model.blocks) - 2}")
+         print(f"\n Before cutting it was {len(model.blocks)} \n\n")
+
+         old_blocks = list(model.blocks.children())
+         if cut_depth < 0:
+             print(f"cut_depth={cut_depth}")
+             old_blocks = [old_blocks[0]] + old_blocks[1:-1:-cut_depth] + [old_blocks[-1]]
+         else:
+             old_blocks = [old_blocks[0]] + old_blocks[cut_depth + 1:]
+         model.blocks = nn.Sequential(*old_blocks)
+         print(f"\n After cutting it is {len(model.blocks)} \n\n")
+     return model
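+
+ # Sketch (hypothetical values): cutting 5 blocks from a depth-12 ViT keeps
+ # block 0 plus blocks 6..11, i.e. the 7-block configuration used by PaSST-L.
+ # >>> model = lighten_model(model, cut_depth=5)   # 12 blocks -> 7 blocks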
+
+
+ def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, in_channels=1,
+               input_fdim=128, input_tdim=998, frame_patchout=300, pos_embed_length=1000
+               ):
+     """
+     :param arch: base ViT or DeiT architecture name
+     :param pretrained: use a pretrained model
+     :param n_classes: number of classes
+     :param in_channels: number of input channels: 1 for mono
+     :param input_fdim: the expected number of input frequency bins
+     :param input_tdim: the expected number of input time frames
+     :param frame_patchout: the number of frames to be removed from the input during training
+     :param pos_embed_length: the length of the time positional embedding
+     :return: the constructed PaSST model
+     """
+     model_func = None
+     input_size = (input_fdim, input_tdim)
+     if arch == "passt_deit_bd_p16_384":  # base deit
+         model_func = deit_base_distilled_patch16_384
+     elif arch == "passt_s_kd_p16_128_ap486":  # pretrained
+         model_func = passt_s_kd_p16_128_ap486
+     elif arch == "passt_l_kd_p16_128_ap47":  # pretrained passt-L
+         model_func = passt_l_kd_p16_128_ap47
+     elif arch == "passt_s_swa_p16_128_ap476":  # pretrained
+         model_func = passt_s_swa_p16_128_ap476
+     elif arch == "passt_s_swa_p16_128_ap4761":
+         model_func = passt_s_swa_p16_128_ap4761
+     elif arch == "passt_s_p16_128_ap472":
+         model_func = passt_s_p16_128_ap472
+     elif arch == "passt_s_p16_s16_128_ap468":
+         model_func = passt_s_p16_s16_128_ap468
+     elif arch == "passt_s_swa_p16_s16_128_ap473":
+         model_func = passt_s_swa_p16_s16_128_ap473
+     elif arch == "passt_s_swa_p16_s14_128_ap471":
+         model_func = passt_s_swa_p16_s14_128_ap471
+     elif arch == "passt_s_p16_s14_128_ap469":
+         model_func = passt_s_p16_s14_128_ap469
+     elif arch == "passt_s_swa_p16_s12_128_ap473":
+         model_func = passt_s_swa_p16_s12_128_ap473
+     elif arch == "passt_s_p16_s12_128_ap470":
+         model_func = passt_s_p16_s12_128_ap470
+     elif arch == "passt_s_f128_20sec_p16_s10_ap474":
+         model_func = passt_s_f128_20sec_p16_s10_ap474_swa
+     elif arch == "passt_s_f128_30sec_p16_s10_ap473":
+         model_func = passt_s_f128_30sec_p16_s10_ap473_swa
+
+     if model_func is None:
+         raise RuntimeError(f"Unknown model {arch}")
+     model = model_func(pretrained=pretrained, num_classes=n_classes, in_chans=in_channels,
+                        img_size=input_size, frame_patchout=frame_patchout, pos_embed_length=pos_embed_length)
+     model = fix_embedding_layer(model)
+     model = lighten_model(model)
+     return model
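+
+ # Usage sketch (illustrative values, random weights): build the frame-level
+ # model and push a log-mel batch through it. With a 998-frame input the conv
+ # stem downsamples time by 4, so the output is roughly [batch, 250, 768].
+ # >>> model = get_model(arch="passt_deit_bd_p16_384", pretrained=False,
+ # ...                   pos_embed_length=250, frame_patchout=0, in_channels=16)
+ # >>> model.eval()
+ # >>> tokens = model(torch.zeros(1, 1, 128, 998))  # -> torch.Size([1, 250, 768])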
+
+
+ class EnsembelerModel(nn.Module):
+     def __init__(self, models):
+         super(EnsembelerModel, self).__init__()
+         self.models = nn.ModuleList(models)
+
+     def forward(self, x):
+         # ModuleList can act as an iterable, or be indexed using ints;
+         # average the first output of each ensemble member
+         all_out = None
+         for i, m in enumerate(self.models):
+             out, _ = m(x)
+             all_out = out if all_out is None else all_out + out
+         all_out = all_out / len(self.models)
+         return all_out, all_out
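+
+ # Hypothetical usage sketch: average the outputs of two wrapped models whose
+ # forward returns a (logits, embedding) tuple, as assumed by `out, _ = m(x)`.
+ # >>> ensemble = EnsembelerModel([model_a, model_b])
+ # >>> logits, _ = ensemble(batch)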
models/frame_passt/fpasst_wrapper.py ADDED
@@ -0,0 +1,86 @@
+ from models.frame_passt.fpasst import get_model
+ from models.frame_passt.preprocess import AugmentMelSTFT
+ from models.transformer_wrapper import BaseModelWrapper
+
+
+ class FPaSSTWrapper(BaseModelWrapper):
+     def __init__(self):
+         super().__init__()
+         self.mel = AugmentMelSTFT(
+             n_mels=128,
+             sr=16_000,
+             win_length=400,
+             hopsize=160,
+             n_fft=512,
+             freqm=0,
+             timem=0,
+             htk=False,
+             fmin=0.0,
+             fmax=None,
+             norm=1,
+             fmin_aug_range=10,
+             fmax_aug_range=2000,
+             fast_norm=True,
+             preamp=True,
+         )
+         self.fpasst = get_model(
+             arch="passt_deit_bd_p16_384",
+             n_classes=527,
+             pos_embed_length=250,
+             frame_patchout=0,
+             in_channels=16
+         )
+
+     def mel_forward(self, x):
+         return self.mel(x)
+
+     def forward(self, x):
+         return self.fpasst(x)
+
+     def separate_params(self):
+         # one parameter group per transformer block; the conv stem, tokens and
+         # positional embeddings join the first block, the final norm the last
+         pt_params = [[] for _ in range(12)]
+         stem_keys = ['cls_token', 'dist_token', 'new_pos_embed', 'freq_new_pos_embed', 'time_new_pos_embed',
+                      'conv_in_1.weight', 'conv_in_1.bias', 'conv_in_2.weight', 'conv_in_2.bias',
+                      'conv_in_3.weight', 'conv_in_3.bias', 'patch_embed.proj.weight', 'patch_embed.proj.bias']
+         for k, p in self.fpasst.named_parameters():
+             if k in stem_keys:
+                 pt_params[0].append(p)
+             elif k.startswith('blocks.'):
+                 block_idx = int(k.split('.')[1])
+                 pt_params[block_idx].append(p)
+             elif k in ['norm.weight', 'norm.bias']:
+                 pt_params[11].append(p)
+             else:
+                 raise ValueError(f"Check separate params for frame-passt! Unexpected key: {k}")
+         return list(reversed(pt_params))
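+
+ # Sketch (assumed trainer setup, not part of this file's API): the groups come
+ # back ordered from the last block down to the stem, so a caller can attach
+ # geometrically decaying learning rates for layer-wise LR decay:
+ # >>> groups = FPaSSTWrapper().separate_params()
+ # >>> opt = torch.optim.AdamW(
+ # ...     [{"params": ps, "lr": 1e-4 * (0.5 ** i)} for i, ps in enumerate(groups)])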
models/frame_passt/preprocess.py ADDED
@@ -0,0 +1,147 @@
+ import torch
+ import torch.nn as nn
+ import torchaudio
+
+ sz_float = 4  # size of a float
+ epsilon = 10e-8  # fudge factor for normalization
+
+
+ class AugmentMelSTFT(nn.Module):
+     def __init__(
+         self,
+         n_mels=128,
+         sr=32000,
+         win_length=None,
+         hopsize=320,
+         n_fft=1024,
+         freqm=0,
+         timem=0,
+         htk=False,
+         fmin=0.0,
+         fmax=None,
+         norm=1,
+         fmin_aug_range=1,
+         fmax_aug_range=1,
+         fast_norm=False,
+         preamp=True,
+         padding="center",
+         periodic_window=True,
+     ):
+         torch.nn.Module.__init__(self)
+         # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
+         # Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast
+
+         if win_length is None:
+             win_length = n_fft
+
+         if isinstance(win_length, (list, tuple)):
+             assert isinstance(n_fft, (list, tuple))
+             assert len(win_length) == len(n_fft)
+         else:
+             win_length = [win_length]
+             n_fft = [n_fft]
+
+         self.win_length = win_length
+         self.n_mels = n_mels
+         self.n_fft = n_fft
+         self.sr = sr
+         self.htk = htk
+         self.fmin = fmin
+         if fmax is None:
+             fmax = sr // 2 - fmax_aug_range // 2
+         self.fmax = fmax
+         self.norm = norm
+         self.hopsize = hopsize
+         self.preamp = preamp
+         for win_l in self.win_length:
+             self.register_buffer(
+                 f"window_{win_l}",
+                 torch.hann_window(win_l, periodic=periodic_window),
+                 persistent=False,
+             )
+         assert (
+             fmin_aug_range >= 1
+         ), f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
+         assert (
+             fmax_aug_range >= 1
+         ), f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
+         self.fmin_aug_range = fmin_aug_range
+         self.fmax_aug_range = fmax_aug_range
+
+         self.register_buffer(
+             "preemphasis_coefficient", torch.as_tensor([[[-0.97, 1]]]), persistent=False
+         )
+         if freqm == 0:
+             self.freqm = torch.nn.Identity()
+         else:
+             self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=False)
+         if timem == 0:
+             self.timem = torch.nn.Identity()
+         else:
+             self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=False)
+         self.fast_norm = fast_norm
+         self.padding = padding
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.iden = nn.Identity()
+
+     def forward(self, x):
+         if self.preamp:
+             x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient)
+             x = x.squeeze(1)
+
+         fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
+         fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
+
+         # don't augment eval data
+         if not self.training:
+             fmin = self.fmin
+             fmax = self.fmax
+
+         mels = []
+         for n_fft, win_length in zip(self.n_fft, self.win_length):
+             x_temp = x
+             if self.padding == "same":
+                 pad = win_length - self.hopsize
+                 self.iden(x_temp)  # identity op, kept as a debugging/tracing hook
+                 x_temp = torch.nn.functional.pad(x_temp, (pad // 2, pad // 2), mode="reflect")
+                 self.iden(x_temp)  # identity op, kept as a debugging/tracing hook
+
+             x_temp = torch.stft(
+                 x_temp,
+                 n_fft,
+                 hop_length=self.hopsize,
+                 win_length=win_length,
+                 center=self.padding == "center",
+                 normalized=False,
+                 window=getattr(self, f"window_{win_length}"),
+                 return_complex=True
+             )
+             x_temp = torch.view_as_real(x_temp)
+             x_temp = (x_temp ** 2).sum(dim=-1)  # power magnitude
+
+             mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, n_fft, self.sr,
+                                                                      fmin, fmax, vtln_low=100.0, vtln_high=-500.,
+                                                                      vtln_warp_factor=1.0)
+             mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
+                                         device=x.device)
+
+             with torch.cuda.amp.autocast(enabled=False):
+                 x_temp = torch.matmul(mel_basis, x_temp)
+
+             x_temp = torch.log(torch.clip(x_temp, min=1e-7))
+
+             mels.append(x_temp)
+
+         mels = torch.stack(mels, dim=1)
+
+         if self.training:
+             mels = self.freqm(mels)
+             mels = self.timem(mels)
+         if self.fast_norm:
+             mels = (mels + 4.5) / 5.0  # fast normalization
+
+         return mels
+
+     def extra_repr(self):
+         return "winsize={}, hopsize={}".format(self.win_length, self.hopsize)
models/frame_passt/vit_helpers.py ADDED
@@ -0,0 +1,399 @@
+ """
+ Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ Credit to @leo19941227 for removing the timm dependencies here: https://github.com/s3prl/passt_hear21/blob/48a0dc1b824641ca59884ced53f5b86053fed141/hear21passt/models/helpers/vit_helpers.py
+
+ """
+ import math
+ import logging
+ import warnings
+ from copy import deepcopy
+
+ import torch
+ from torch import nn
+
+ from timm.models._hub import download_cached_file
+
+
+ # Global variables for rarely used pretrained checkpoint download progress and hash check.
+ # Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
+ _DOWNLOAD_PROGRESS = True
+ _CHECK_HASH = False
+
+
+ _logger = logging.getLogger(__name__)
+
+
+ def adapt_input_conv(in_chans, conv_weight, input_conv_name="(name not given)"):
+     conv_type = conv_weight.dtype
+     conv_weight = (
+         conv_weight.float()
+     )  # Some weights are in torch.half, ensure it's float for sum on CPU
+     O, I, J, K = conv_weight.shape
+     if in_chans == 1:
+         print(f"adapt_input_conv: Converted from {I} to 1 channel")
+         if I > 3:
+             assert conv_weight.shape[1] % 3 == 0
+             # For models with space2depth stems
+             conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+             conv_weight = conv_weight.sum(dim=2, keepdim=False)
+         else:
+             conv_weight = conv_weight.sum(dim=1, keepdim=True)
+     elif in_chans != 3:
+         if I != 3:
+             # loading a model pretrained on AudioSet for the downstream task
+             if I == in_chans:
+                 print(f"adapt_input_conv: Loading pretrained weights for {input_conv_name}, "
+                       f"assuming the same input-conv and proj-conv configuration (1:1).")
+         else:
+             print(f"adapt_input_conv: Converted input conv {input_conv_name} weights from 3 to {in_chans} channel(s)")
+             # NOTE this strategy should be better than random init, but there could be other combinations of
+             # the original RGB input layer weights that'd work better for specific cases.
+             repeat = int(math.ceil(in_chans / 3))
+             conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+             conv_weight *= 3 / float(in_chans)
+     conv_weight = conv_weight.to(conv_type)
+     return conv_weight
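+
+ # Two illustrative conversions (shapes only, random weights):
+ # >>> w = torch.randn(768, 3, 16, 16)    # RGB patchify kernel
+ # >>> adapt_input_conv(1, w).shape       # sum channels -> torch.Size([768, 1, 16, 16])
+ # >>> adapt_input_conv(16, w).shape      # tile + rescale -> torch.Size([768, 16, 16, 16])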
+
+
+ def load_pretrained(
+     model,
+     default_cfg=None,
+     num_classes=1000,
+     in_chans=3,
+     filter_fn=None,
+     strict=True,
+     progress=False,
+ ):
+     """Load pretrained checkpoint
+
+     Args:
+         model (nn.Module): PyTorch model module
+         default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
+         num_classes (int): num_classes for model
+         in_chans (int): in_chans for model
+         filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
+         strict (bool): strict load of checkpoint
+         progress (bool): enable progress bar for weight download
+     """
+     default_cfg = default_cfg or getattr(model, "default_cfg", None) or {}
+     pretrained_url = default_cfg.get("url", None)
+
+     if not pretrained_url:
+         _logger.warning(
+             "No pretrained weights exist for this model. Using random initialization."
+         )
+         return
+
+     _logger.info(f"Loading pretrained weights from url ({pretrained_url})")
+     pretrained_loc = download_cached_file(
+         pretrained_url,
+         check_hash=_CHECK_HASH,
+         progress=_DOWNLOAD_PROGRESS,
+     )
+
+     state_dict = torch.load(pretrained_loc, map_location="cpu")
+
+     if filter_fn is not None:
+         # for backwards compat with filter fns that take one arg, try one first, then two
+         try:
+             state_dict = filter_fn(state_dict)
+         except TypeError:
+             state_dict = filter_fn(state_dict, model)
+
+     input_convs = default_cfg.get("first_conv", None)
+     if input_convs is not None and in_chans != 3:
+         if isinstance(input_convs, str):
+             input_convs = (input_convs,)
+         for input_conv_name in input_convs:
+             weight_name = input_conv_name + ".weight"
+             try:
+                 state_dict[weight_name] = adapt_input_conv(
+                     in_chans, state_dict[weight_name], input_conv_name
+                 )
+             except (NotImplementedError, KeyError):
+                 if weight_name in state_dict:
+                     del state_dict[weight_name]
+                 strict = False
+                 _logger.warning(
+                     f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer."
+                 )
+
+     classifiers = default_cfg.get("classifier", None)
+     label_offset = default_cfg.get("label_offset", 0)
+     if classifiers is not None:
+         if isinstance(classifiers, str):
+             classifiers = (classifiers,)
+         if num_classes != default_cfg["num_classes"]:
+             for classifier_name in classifiers:
+                 # completely discard fully connected if model num_classes doesn't match pretrained weights
+                 del state_dict[classifier_name + ".weight"]
+                 del state_dict[classifier_name + ".bias"]
+                 strict = False
+         elif label_offset > 0:
+             for classifier_name in classifiers:
+                 # special case for pretrained weights with an extra background class
+                 classifier_weight = state_dict[classifier_name + ".weight"]
+                 state_dict[classifier_name + ".weight"] = classifier_weight[
+                     label_offset:
+                 ]
+                 classifier_bias = state_dict[classifier_name + ".bias"]
+                 state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:]
+
+     model.load_state_dict(state_dict, strict=strict)
+
+
+ def overlay_external_default_cfg(default_cfg, kwargs):
+     """Overlay 'external_default_cfg' in kwargs on top of default_cfg arg."""
+     external_default_cfg = kwargs.pop("external_default_cfg", None)
+     if external_default_cfg:
+         default_cfg.pop("url", None)  # url should come from external cfg
+         default_cfg.pop("hf_hub", None)  # hf hub id should come from external cfg
+         default_cfg.update(external_default_cfg)
+
+
+ def filter_kwargs(kwargs, names):
+     if not kwargs or not names:
+         return
+     for n in names:
+         kwargs.pop(n, None)
+
+
+ def set_default_kwargs(kwargs, names, default_cfg):
+     for n in names:
+         # for legacy reasons, model __init__ args use img_size + in_chans as separate args while
+         # default_cfg has one input_size=(C, H, W) entry
+         if n == "img_size":
+             input_size = default_cfg.get("input_size", None)
+             if input_size is not None:
+                 assert len(input_size) == 3
+                 kwargs.setdefault(n, input_size[-2:])
+         elif n == "in_chans":
+             input_size = default_cfg.get("input_size", None)
+             if input_size is not None:
+                 assert len(input_size) == 3
+                 kwargs.setdefault(n, input_size[0])
+         else:
+             default_val = default_cfg.get(n, None)
+             if default_val is not None:
+                 kwargs.setdefault(n, default_cfg[n])
+
+
+ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
+     """Update the default_cfg and kwargs before passing to model
+
+     FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
+     could/should be replaced by an improved configuration mechanism
+
+     Args:
+         default_cfg: input default_cfg (updated in-place)
+         kwargs: keyword args passed to model build fn (updated in-place)
+         kwargs_filter: keyword arg keys that must be removed before model __init__
+     """
+     # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+     overlay_external_default_cfg(default_cfg, kwargs)
+     # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+     default_kwarg_names = ("num_classes", "global_pool", "in_chans")
+     if default_cfg.get("fixed_input_size", False):
+         # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
+         default_kwarg_names += ("img_size",)
+     set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
+     # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+     filter_kwargs(kwargs, names=kwargs_filter)
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+     This is the same as the DropConnect impl created for EfficientNet, etc. networks; however,
+     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... The layer and
+     argument names here use 'drop path' rather than mixing DropConnect as a layer name with
+     'survival rate' as the argument.
+     """
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (
+         x.ndim - 1
+     )  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
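+
+ # Behaviour sketch: during training each sample's residual branch is either
+ # zeroed or scaled by 1/keep_prob, so the expected value matches eval mode.
+ # >>> x = torch.ones(4, 3)
+ # >>> drop_path(x, drop_prob=0.5, training=True)   # each row is all 0.0 or all 2.0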
+
+
+ from torch.nn.init import _calculate_fan_in_and_fan_out
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+             "The distribution of values may be incorrect.",
+             stacklevel=2,
+         )
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.0))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+     if mode == "fan_in":
+         denom = fan_in
+     elif mode == "fan_out":
+         denom = fan_out
+     elif mode == "fan_avg":
+         denom = (fan_in + fan_out) / 2
+     else:
+         raise ValueError(f"invalid mode {mode}")
+
+     variance = scale / denom
+
+     if distribution == "truncated_normal":
+         # constant is stddev of standard normal truncated to (-2, 2)
+         trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+     elif distribution == "normal":
+         tensor.normal_(std=math.sqrt(variance))
+     elif distribution == "uniform":
+         bound = math.sqrt(3 * variance)
+         tensor.uniform_(-bound, bound)
+     else:
+         raise ValueError(f"invalid distribution {distribution}")
+
+
+ def lecun_normal_(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+ def build_model_with_cfg(
+     model_cls,
+     variant: str,
+     pretrained: bool,
+     default_cfg: dict,
+     model_cfg=None,
+     feature_cfg=None,
+     pretrained_strict: bool = True,
+     pretrained_filter_fn=None,
+     pretrained_custom_load=False,
+     kwargs_filter=None,
+     **kwargs,
+ ):
+     """Build model with specified default_cfg and optional model_cfg
+
+     This helper fn aids in the construction of a model including:
+       * handling default_cfg and associated pretrained weight loading
+       * passing through optional model_cfg for models with config based arch spec
+       * features_only model adaptation
+       * pruning config / model adaptation
+
+     Args:
+         model_cls (nn.Module): model class
+         variant (str): model variant name
+         pretrained (bool): load pretrained weights
+         default_cfg (dict): model's default pretrained/task config
+         model_cfg (Optional[Dict]): model's architecture config
+         feature_cfg (Optional[Dict]): feature extraction adapter config
+         pretrained_strict (bool): load pretrained weights strictly
+         pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
+         pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
+         kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+         **kwargs: model args passed through to model __init__
+     """
+     pruned = kwargs.pop("pruned", False)
+     features = False
+     feature_cfg = feature_cfg or {}
+     default_cfg = deepcopy(default_cfg) if default_cfg else {}
+     update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
+     default_cfg.setdefault("architecture", variant)
+
+     # Setup for feature extraction wrapper done at end of this fn
+     if kwargs.pop("features_only", False):
+         features = True
+         feature_cfg.setdefault("out_indices", (0, 1, 2, 3, 4))
+         if "out_indices" in kwargs:
+             feature_cfg["out_indices"] = kwargs.pop("out_indices")
+
+     # Build the model
+     model = (
+         model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
+     )
+     model.default_cfg = default_cfg
+
+     # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
+     num_classes_pretrained = (
+         0
+         if features
+         else getattr(model, "num_classes", kwargs.get("num_classes", 1000))
+     )
+     if pretrained:
+         assert not pretrained_custom_load, "URL should not contain npz for PaSST models"
+         load_pretrained(
+             model,
+             num_classes=num_classes_pretrained,
+             in_chans=kwargs.get("in_chans", 3),
+             filter_fn=pretrained_filter_fn,
+             strict=pretrained_strict,
+         )
+     return model
models/m2d/M2D_wrapper.py ADDED
@@ -0,0 +1,52 @@
+ from models.m2d.portable_m2d import PortableM2D as M2D
+ from models.transformer_wrapper import BaseModelWrapper
+
+
+ class M2DWrapper(BaseModelWrapper):
+     def __init__(self) -> None:
+         super().__init__()
+         self.m2d = M2D()
+
+     def mel_forward(self, x):
+         return self.m2d.to_normalized_feature(x)
+
+     def forward(self, spec):
+         return self.m2d.forward_mel(spec)
+
+     def separate_params(self):
+         pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
+         for k, p in self.named_parameters():
+             if any(['cls_token' in k,
+                     'pos_embed' in k,
+                     'norm_stats' in k,
+                     'patch_embed' in k]):
+                 pt_params[0].append(p)
+             elif 'blocks.0.' in k:
+                 pt_params[0].append(p)
+             elif 'blocks.1.' in k:
+                 pt_params[1].append(p)
+             elif 'blocks.2.' in k:
+                 pt_params[2].append(p)
+             elif 'blocks.3.' in k:
+                 pt_params[3].append(p)
+             elif 'blocks.4.' in k:
+                 pt_params[4].append(p)
+             elif 'blocks.5.' in k:
+                 pt_params[5].append(p)
+             elif 'blocks.6.' in k:
+                 pt_params[6].append(p)
+             elif 'blocks.7.' in k:
+                 pt_params[7].append(p)
+             elif 'blocks.8.' in k:
+                 pt_params[8].append(p)
+             elif 'blocks.9.' in k:
+                 pt_params[9].append(p)
+             elif 'blocks.10.' in k:
+                 pt_params[10].append(p)
+             elif 'blocks.11.' in k:
+                 pt_params[11].append(p)
+             elif 'backbone.norm.weight' in k or 'backbone.norm.bias' in k:
+                 pt_params[11].append(p)
+             else:
+                 raise ValueError(f"Check separate params for M2D! Unknown key: {k}")
+         return list(reversed(pt_params))
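For context, a minimal sketch of how `separate_params` can drive layer-wise learning-rate decay; the base learning rate and decay factor are illustrative, not values from this repository:

import torch
from models.m2d.M2D_wrapper import M2DWrapper

wrapper = M2DWrapper()
param_groups = wrapper.separate_params()  # ordered from last block down to patch embedding
base_lr, decay = 1e-4, 0.75
optimizer = torch.optim.AdamW(
    [{"params": params, "lr": base_lr * (decay ** i)}   # deeper groups get smaller LR
     for i, params in enumerate(param_groups)]
)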
models/m2d/portable_m2d.py ADDED
@@ -0,0 +1,410 @@
+ """Masked Modeling Duo (M2D) Portable Runtime.
+
+ All you need is:
+     pip install timm einops nnAudio
+ """
+
+ import logging
+ from functools import partial
+ from pathlib import Path
+
+ import nnAudio.features
+ import numpy as np
+ import timm
+ import torch
+ from einops import rearrange
+ from timm.models.layers import trunc_normal_
+
+
+ class Config:
+     weight_file = ''
+     feature_d = 768 * 5
+     norm_type = 'all'  # fixed: the original assigned the built-in `all` instead of a string
+     pooling_type = 'mean'
+     model = ''
+     input_size = [80, 208]
+     patch_size = [16, 16]
+     sr = '16k'
+     flat_features = False
+
+
+ def expand_size(sz):
+     if isinstance(sz, int):
+         return [sz, sz]
+     return sz
+
+
+ class PatchEmbed(torch.nn.Module):
+     """2D Image to Patch Embedding -- borrowed from https://pypi.org/project/timm/0.4.12/"""
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+         super().__init__()
+         img_size = expand_size(img_size)
+         patch_size = expand_size(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+         self.num_patches = self.grid_size[0] * self.grid_size[1]
+         self.flatten = flatten
+
+         self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+         self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()
+
+     def forward(self, x):
+         x = self.proj(x)
+         if self.flatten:
+             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+         x = self.norm(x)
+         return x
+
+
+ class LocalViT(timm.models.vision_transformer.VisionTransformer):
+     """Vision Transformer for M2D Audio"""
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         # Workaround for PatchEmbed to avoid an unintended assertion failure, e.g., AssertionError: Input image width (102) doesn't match model (608).
+         self.patch_embed = PatchEmbed(self.patch_embed.img_size, self.patch_embed.patch_size,
+                                       self.patch_embed.proj.in_channels, self.patch_embed.proj.out_channels)
+         self.norm_stats = torch.nn.Parameter(torch.tensor([-7.1, 4.2]), requires_grad=False)
+         # We do not use the default head
+         del self.head
+
+     def patch_size(self):
+         return np.array(self.patch_embed.patch_size)
+
+     def grid_size(self):
+         # Workaround for a compatibility issue (timm 0.4.5 fails with: return self.patch_embed.grid_size)
+         img_size = np.array(self.patch_embed.img_size)
+         patch_size = self.patch_size()
+         grid_size = img_size // patch_size
+         return grid_size
+
+     def forward_encoder(self, x):
+         x = self.patch_embed(x)
+
+         # add pos embed w/o cls token
+         pos_embed = self.pos_embed[:, 1:, :]
+         if x.shape[1] < pos_embed.shape[1]:  # shorten pos_embed for a short input
+             dims = pos_embed.shape[-1]
+             fbins = self.grid_size()[0]
+             frames = x.shape[1] // fbins
+             pos_embed = pos_embed.reshape(1, fbins, -1, dims)[:, :, :frames, :].reshape(1, fbins * frames, dims)
+         x = x + pos_embed
+
+         # append cls token
+         cls_token = self.cls_token + self.pos_embed[:, :1, :]
+         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # apply Transformer blocks
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x)
+
+         return x
+
+
+ def parse_sizes_by_name(name):
+     # Parse parameters. "m2d_vit_base-80x1001p16x16p16k" -> input size: 80x1001, patch size: 16x16, sr: 16k
+     model_cls = name.split('-')[0]
+     params = name.split('-')[1]
+     params = params.split('p')[:3]
+     input_str, patch_str, sr = params[0], params[1], params[2] if len(params) > 2 else '16k'
+     input_size = [int(a) for a in input_str.split('x')]
+     patch_size = [int(a) for a in patch_str.split('x')]
+     return input_size, patch_size, sr, model_cls
+
+
+ def drop_non_model_weights(model, checkpoint, filename):
+     model_keys = [n for n, p in model.named_parameters()]
+     new_ckpt, dropped = {}, []
+     for k in checkpoint:
+         if k not in model_keys:
+             dropped.append(k)
+             continue
+         new_ckpt[k] = checkpoint[k]
+     n_org = len(checkpoint.keys())
+     n_cur = len(new_ckpt.keys())
+     print(
+         f' using {n_cur} parameters, while dropped {n_org - n_cur} out of {n_org} parameters from {Path(filename).parent / Path(filename).name}'
+         if n_org > n_cur else f' using {n_cur} parameters from {Path(filename).parent / Path(filename).name}')
+     print(' (dropped:', dropped[:5], ')' if len(dropped) < 5 else '...)')
+     return new_ckpt
+
+
+ def load_evar_head_parameters(checkpoint, head_norm, head):
+     # Load the weights of the task head trained in the EVAR fine-tuning.
+     if 'module.head.norm.running_mean' in checkpoint:
+         head_norm.load_state_dict({to_k: checkpoint[k] for to_k, k in {
+             'running_mean': 'module.head.norm.running_mean', 'running_var': 'module.head.norm.running_var'}.items()})
+         head.load_state_dict({to_k: checkpoint[k] for to_k, k in {
+             'weight': 'module.head.mlp.mlp.0.weight', 'bias': 'module.head.mlp.mlp.0.bias'}.items()})
+     else:
+         print(' Not an EVAR checkpoint for loading head weights.')
+
+
+ def reformat_ckpt_keys(checkpoint):
+     # In case: checkpoint['model']
+     checkpoint = checkpoint['model'] if 'model' in checkpoint else checkpoint
+     # Checkpoints saved in an EVAR fine-tuning have a prefix of "module.ar.runtime.backbone"; the following removes it.
+     new_ckpt = {}
+     for k in checkpoint:
+         new_k = k.replace('module.ar.runtime.backbone.', '')  # remove the prefix
+         new_ckpt[new_k] = checkpoint[k]
+     return new_ckpt
+
+
+ def make_it_CLAP(model, checkpoint):
+     # Add projectors if needed
+     if 'audio_proj.0.weight' in checkpoint.keys():
+         proj_hidden_dim = embed_dim = checkpoint['audio_proj.0.weight'].shape[1]
+         model.audio_proj = torch.nn.Sequential(
+             torch.nn.Linear(embed_dim, proj_hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(proj_hidden_dim, embed_dim),
+         )
+         if 'text_proj.weight' in checkpoint.keys():
+             dim = checkpoint['text_proj.weight'].shape
+             model.text_proj = torch.nn.Linear(dim[1], dim[0])
+         else:
+             model.text_proj = torch.nn.Identity()
+
+
+ def get_backbone(args, weight_file):
+     name = Path(weight_file).parent.name if weight_file is not None \
+         else "m2d_clap_vit_base-80x1001p16x16-240128_AS-FT_enconly"
+     args.input_size, args.patch_size, args.sr, args.beats = parse_sizes_by_name(name)
+
+     # Create a ViT.
+     model = LocalViT(
+         in_chans=1, img_size=args.input_size, patch_size=args.patch_size, embed_dim=768, depth=12, num_heads=12,
+         mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6))
+
+     if weight_file is None:
+         args.mean, args.std = -7.1, 4.2
+         model.eval()
+         return model, None
+
+     # Load checkpoint.
+     checkpoint = torch.load(weight_file, map_location='cpu')
+     checkpoint = reformat_ckpt_keys(checkpoint)
+     # Set normalization statistics for backward compatibility. The [-7.1, 4.2] is for 2022 models.
+     if 'norm_stats' not in checkpoint:
+         checkpoint['norm_stats'] = torch.tensor([-7.1, 4.2])
+         print(' using default norm_stats:', checkpoint['norm_stats'])
+
+     # Modify the model if it should be a M2D-CLAP.
+     make_it_CLAP(model, checkpoint)
+
+     # Load weights.
+     dropped = drop_non_model_weights(model, checkpoint, weight_file)
+     msg = model.load_state_dict(dropped)
+     print(msg)
+     logging.info(msg)
+
+     # Make normalization statistics for the model easy to use in the downstream task.
+     args.mean, args.std = model.state_dict()['norm_stats'].to('cpu').numpy()
+
+     model.eval()
+     return model, checkpoint
+
+
+ def get_to_melspec(cfg):
+     if cfg.sr == '16k':
+         cfg.sample_rate, cfg.n_fft, cfg.window_size, cfg.hop_size = 16000, 400, 400, 160
+         cfg.n_mels, cfg.f_min, cfg.f_max = 80, 50, 8000
+     elif cfg.sr == '32k':
+         cfg.sample_rate, cfg.n_fft, cfg.window_size, cfg.hop_size = 32000, 800, 800, 320
+         cfg.n_mels, cfg.f_min, cfg.f_max = 80, 50, 16000
+     else:
+         assert False, f'Unknown sample rate: {cfg.sr}'
+
+     to_spec = nnAudio.features.MelSpectrogram(
+         sr=cfg.sample_rate,
+         n_fft=cfg.n_fft,
+         win_length=cfg.window_size,
+         hop_length=cfg.hop_size,
+         n_mels=cfg.n_mels,
+         fmin=cfg.f_min,
+         fmax=cfg.f_max,
+         center=True,
+         power=2,
+         verbose=False,
+     )
+     logging.info(f'Runtime MelSpectrogram({cfg.sample_rate}, {cfg.n_fft}, {cfg.window_size}, {cfg.hop_size}, '
+                  + f'{cfg.n_mels}, {cfg.f_min}, {cfg.f_max}):')
+     logging.info(to_spec)
+     return to_spec
+
+
+ def get_timestamps(cfg, batch_audio, x):  # Returns timestamps in milliseconds.
+     audio_len = len(batch_audio[0])
+     sec = audio_len / cfg.sample_rate
+     x_len = len(x[0])
+     step = sec / x_len * 1000  # sec -> ms
+     ts = torch.tensor([step * i for i in range(x_len)]).unsqueeze(0)
+     ts = ts.repeat(len(batch_audio), 1)
+     return ts
+
+
+ class PortableM2D(torch.nn.Module):
+     def __init__(self, weight_file=None, num_classes=None, freeze_embed=False, flat_features=None):
+         super().__init__()
+         self.cfg = Config()
+         self.cfg.weight_file = weight_file
+         self.cfg.freeze_embed = freeze_embed
+         self.cfg.flat_features = self.cfg.flat_features if flat_features is None else flat_features
+
+         # Create backbone model.
+         self.backbone, checkpoint = get_backbone(self.cfg, self.cfg.weight_file)
+         # Finalize feature dimension.
+         d = self.backbone.pos_embed.shape[-1]
+         if num_classes is not None and 'module.head.mlp.mlp.0.weight' in checkpoint and \
+                 checkpoint['module.head.mlp.mlp.0.weight'].shape[-1] == d:
+             self.cfg.flat_features = True
+         n_stack_feature = 1 if self.cfg.flat_features else (self.cfg.input_size[0] // self.cfg.patch_size[0])
+         self.cfg.feature_d = d * n_stack_feature  # 768 if flat_features else 768*5=3840
+         # Create head.
+         if num_classes is not None:
+             self.head_norm = torch.nn.BatchNorm1d(self.cfg.feature_d, affine=False)
+             self.head = torch.nn.Linear(self.cfg.feature_d, num_classes)
+             trunc_normal_(self.head.weight, std=2e-5)
+             load_evar_head_parameters(checkpoint, self.head_norm, self.head)
+         # Option: freeze patch embedding ([2211.09359] How to Fine-Tune Vision Models with SGD)
+         if self.cfg.freeze_embed:
+             self.backbone.patch_embed.requires_grad_(False)  # inlined; the non-portable original called models_mae.set_requires_grad
+             logging.info(' ** Freeze patch_embed **')
+             logging.info(self.backbone.patch_embed)
+
+         logging.info(f'Model input size: {self.cfg.input_size}')
+         logging.info(f'Using weights: {self.cfg.weight_file}')
+         logging.info(f'Feature dimension: {self.cfg.feature_d}')
+         logging.info(f'Norm stats: {self.cfg.mean}, {self.cfg.std}')
+
+         self.to_spec = get_to_melspec(self.cfg)
+         self.eval()
+
+     def to_log_mel_spec(self, batch_audio):
+         x = self.to_spec(batch_audio)
+         x = (x + torch.finfo().eps).log()
+         x = x.unsqueeze(1)
+         return x
+
+     def normalize_batch(self, x):
+         x = (x - self.cfg.mean) / self.cfg.std
+         return x
+
+     def to_normalized_feature(self, batch_audio):
+         x = self.to_log_mel_spec(batch_audio)
+         x = self.normalize_batch(x)
+         return x
+
+     def encode_lms(self, x, average_per_time_frame=False):
+         patch_fbins = self.backbone.grid_size()[0]
+         unit_frames = self.cfg.input_size[1]
+         patch_frames = self.backbone.patch_size()[1]
+         embed_d = self.backbone.patch_embed.proj.out_channels
+         n_chunk = (x.shape[-1] + unit_frames - 1) // unit_frames
+         pad_frames = (patch_frames - (x.shape[-1] % unit_frames % patch_frames)) % patch_frames
+         if pad_frames > 0:
+             x = torch.nn.functional.pad(x, (0, pad_frames))
+
+         embeddings = []
+         if self.cfg.flat_features:
+             # flatten all patch embeddings
+             for i in range(n_chunk):
+                 emb = self.backbone.forward_encoder(x[..., i * unit_frames:(i + 1) * unit_frames])
+                 emb = emb[..., 1:, :]
+                 if average_per_time_frame:
+                     emb = rearrange(emb, 'b (f t) d -> b t d f', f=patch_fbins, d=embed_d).mean(-1)
+                 embeddings.append(emb)
+         else:
+             # stack embeddings along time frame
+             for i in range(n_chunk):
+                 emb = self.backbone.forward_encoder(x[..., i * unit_frames:(i + 1) * unit_frames])
+                 emb = emb[..., 1:, :]
+                 emb = rearrange(emb, 'b (f t) d -> b t (f d)', f=patch_fbins, d=embed_d)
+                 embeddings.append(emb)
+         # concatenate embedding chunks in the time axis
+         x = torch.cat(embeddings, axis=-2)
+         return x
+
+     def encode(self, batch_audio, average_per_time_frame=False):
+         x = self.to_normalized_feature(batch_audio)
+         return self.encode_lms(x, average_per_time_frame=average_per_time_frame)
+
+     def forward(self, batch_audio, average_per_time_frame=False):
+         x = self.encode(batch_audio, average_per_time_frame=average_per_time_frame)
+         if hasattr(self, 'head'):
+             x = x.mean(1)  # B, D
+             x = self.head_norm(x.unsqueeze(-1)).squeeze(-1)
+             x = self.head(x)
+         return x
+
+     def forward_mel(self, batch_mel, average_per_time_frame=False):
+         x = self.encode_lms(batch_mel, average_per_time_frame=average_per_time_frame)
+         if hasattr(self, 'head'):
+             x = x.mean(1)  # B, D
+             x = self.head_norm(x.unsqueeze(-1)).squeeze(-1)
+             x = self.head(x)
+         return x
+
+     def get_scene_embeddings(self, batch_audio):
+         x = self.encode(batch_audio)
+         x = torch.mean(x, dim=1)
+         return x
+
+     def get_timestamp_embeddings(self, batch_audio):
+         x = self.encode(batch_audio, average_per_time_frame=True)
+         ts = get_timestamps(self.cfg, batch_audio, x)
+         return x, ts
+
+     def forward_frames(self, batch_audio):
+         x, ts = self.get_timestamp_embeddings(batch_audio)
+         if hasattr(self, 'head'):
+             x = self.head_norm(x.transpose(-1, -2)).transpose(-2, -1)
+             x = self.head(x)
+         return x, ts
+
+     def encode_clap_audio(self, batch_audio):
+         audio_embeddings = self.forward(batch_audio)
+         audio_embeddings = audio_embeddings.mean(dim=-2)
+         audio_embeddings = self.backbone.audio_proj(audio_embeddings)
+         return audio_embeddings
+
+     def encode_clap_text(self, batch_text, truncate=False):
+         if not hasattr(self, 'text_encoder'):
+             self.text_encoder = GTETextEncoder()
+         text_embeddings = self.text_encoder(batch_text, truncate=truncate)
+         text_embeddings = self.backbone.text_proj(text_embeddings)
+         text_embeddings = text_embeddings.detach().cpu().to(torch.float)
+         return text_embeddings
+
+
+ # For the CLAP models
+
+ class GTETextEncoder:
+     def __init__(self, clip_weight="thenlper/gte-base"):
+         from transformers import AutoTokenizer, AutoModel
+         import os
+         os.environ["TOKENIZERS_PARALLELISM"] = "true"  # To suppress warnings.
+
+         self.tokenizer = AutoTokenizer.from_pretrained(clip_weight)
+         self.model = AutoModel.from_pretrained(clip_weight)
+
+     def __call__(self, texts, truncate=True, max_length=512):
+         def average_pool(last_hidden_states, attention_mask):
+             last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+             return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+         with torch.no_grad():
+             device = next(self.model.parameters()).device
+             batch_dict = self.tokenizer(texts, max_length=max_length, padding=True, truncation=truncate,
+                                         return_tensors='pt')
+             batch_dict['input_ids'] = batch_dict['input_ids'].to(device)
+             batch_dict['token_type_ids'] = batch_dict['token_type_ids'].to(device)
+             batch_dict['attention_mask'] = batch_dict['attention_mask'].to(device)
+             outputs = self.model.to(device)(**batch_dict)
+             embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+         return embeddings
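A short usage sketch of the runtime above; the audio is random and no checkpoint is loaded, purely to illustrate the expected shapes:

import torch
from models.m2d.portable_m2d import PortableM2D

model = PortableM2D()  # no weight_file: randomly initialized encoder with default norm stats
wav = torch.rand(2, 16000 * 10) - 0.5  # batch of two 10 s clips at 16 kHz
with torch.no_grad():
    frame_emb, ts = model.get_timestamp_embeddings(wav)  # (batch, n_frames, cfg.feature_d), ts in ms
    clip_emb = model.get_scene_embeddings(wav)           # (batch, cfg.feature_d)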
models/prediction_wrapper.py ADDED
@@ -0,0 +1,213 @@
+ import os
+
+ import torch
+ import torch.nn as nn
+ from torch.hub import download_url_to_file
+
+ from config import RESOURCES_FOLDER, CHECKPOINT_URLS
+ from models.seq_models import BidirectionalLSTM, BidirectionalGRU
+
+
+ class PredictionsWrapper(nn.Module):
+     """
+     A wrapper module that adds an optional sequence model and classification heads on top of a transformer.
+     It implements equations (1), (2), and (3) in the paper.
+
+     Args:
+         base_model (BaseModelWrapper): The base model (transformer) providing sequence embeddings.
+         checkpoint (str, optional): Checkpoint name for loading pre-trained weights. Default is None.
+         n_classes_strong (int): Number of classes for strong predictions. Default is 447.
+         n_classes_weak (int, optional): Number of classes for weak predictions. Default is None,
+             which sets it equal to n_classes_strong.
+         embed_dim (int, optional): Embedding dimension of the base model output. Default is 768.
+         seq_len (int, optional): Desired sequence length. Default is 250 (40 ms resolution).
+         seq_model_type (str, optional): Type of sequence model to use.
+             Default is None, which means no additional sequence model is used.
+         head_type (str, optional): Type of classification head. Choices are ["linear", "attention", None].
+             Default is "linear". None means that sequence embeddings are returned.
+         rnn_layers (int, optional): Number of RNN layers if seq_model_type is "rnn". Default is 2.
+         rnn_type (str, optional): Type of RNN to use. Choices are ["BiGRU", "BiLSTM"]. Default is "BiGRU".
+         rnn_dim (int, optional): Dimension of RNN hidden state if seq_model_type is "rnn". Default is 2048.
+         rnn_dropout (float, optional): Dropout rate for RNN layers. Default is 0.0.
+     """
+
+     def __init__(self,
+                  base_model,
+                  checkpoint=None,
+                  n_classes_strong=447,
+                  n_classes_weak=None,
+                  embed_dim=768,
+                  seq_len=250,
+                  seq_model_type=None,
+                  head_type="linear",
+                  rnn_layers=2,
+                  rnn_type="BiGRU",
+                  rnn_dim=2048,
+                  rnn_dropout=0.0
+                  ):
+         super(PredictionsWrapper, self).__init__()
+         self.model = base_model
+         self.seq_len = seq_len
+         self.embed_dim = embed_dim
+         self.n_classes_strong = n_classes_strong
+         self.n_classes_weak = n_classes_weak if n_classes_weak is not None else n_classes_strong
+         self.seq_model_type = seq_model_type
+         self.head_type = head_type
+
+         if self.seq_model_type == "rnn":
+             if rnn_type == "BiGRU":
+                 self.seq_model = BidirectionalGRU(
+                     n_in=self.embed_dim,
+                     n_hidden=rnn_dim,
+                     dropout=rnn_dropout,
+                     num_layers=rnn_layers
+                 )
+             elif rnn_type == "BiLSTM":
+                 self.seq_model = BidirectionalLSTM(
+                     nIn=self.embed_dim,
+                     nHidden=rnn_dim,
+                     nOut=rnn_dim * 2,
+                     dropout=rnn_dropout,
+                     num_layers=rnn_layers
+                 )
+             num_features = rnn_dim * 2
+         elif self.seq_model_type is None:
+             self.seq_model = nn.Identity()
+             # no additional sequence model
+             num_features = self.embed_dim
+         else:
+             raise ValueError(f"Unknown seq_model_type: {self.seq_model_type}")
+
+         if self.head_type == "attention":
+             assert self.n_classes_strong == self.n_classes_weak, "head_type=='attention' requires number of strong and " \
+                                                                  "weak classes to be the same!"
+
+         if self.head_type is not None:
+             self.strong_head = nn.Linear(num_features, self.n_classes_strong)
+             self.weak_head = nn.Linear(num_features, self.n_classes_weak)
+         if checkpoint is not None:
+             print("Loading pretrained checkpoint: ", checkpoint)
+             self.load_checkpoint(checkpoint)
+
+     def load_checkpoint(self, checkpoint):
+         ckpt_file = os.path.join(RESOURCES_FOLDER, checkpoint + ".pt")
+         if not os.path.exists(ckpt_file):
+             download_url_to_file(CHECKPOINT_URLS[checkpoint], ckpt_file)
+         state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
+
+         # compatibility with the uniform wrapper structure we introduced for the public repo
+         if 'fpasst' in checkpoint:
+             state_dict = {("model.fpasst." + k[len("model."):] if k.startswith("model.")
+                            else k): v for k, v in state_dict.items()}
+         elif 'M2D' in checkpoint:
+             state_dict = {("model.m2d." + k[len("model."):] if not k.startswith("model.m2d.") and k.startswith("model.")
+                            else k): v for k, v in state_dict.items()}
+         elif 'BEATs' in checkpoint:
+             state_dict = {("model.beats." + k[len("model.model."):] if k.startswith("model.model")
+                            else k): v for k, v in state_dict.items()}
+         elif 'ASIT' in checkpoint:
+             state_dict = {("model.asit." + k[len("model."):] if k.startswith("model.")
+                            else k): v for k, v in state_dict.items()}
+
+         n_classes_weak_in_sd = state_dict['weak_head.bias'].shape[0] if 'weak_head.bias' in state_dict else -1
+         n_classes_strong_in_sd = state_dict['strong_head.bias'].shape[0] if 'strong_head.bias' in state_dict else -1
+         seq_model_in_sd = any(['seq_model.' in key for key in state_dict.keys()])
+         keys_to_remove = []
+         strict = True
+         expected_missing = 0
+         if self.head_type is None:
+             # remove all keys related to the head
+             keys_to_remove.append('weak_head.bias')
+             keys_to_remove.append('weak_head.weight')
+             keys_to_remove.append('strong_head.bias')
+             keys_to_remove.append('strong_head.weight')
+         elif self.seq_model_type is not None and not seq_model_in_sd:
+             # we want to train a sequence model (e.g., rnn) on top of a
+             # pre-trained transformer (e.g., AS weak pretrained)
+             keys_to_remove.append('weak_head.bias')
+             keys_to_remove.append('weak_head.weight')
+             keys_to_remove.append('strong_head.bias')
+             keys_to_remove.append('strong_head.weight')
+             num_seq_model_keys = len([key for key in self.seq_model.state_dict()])
+             expected_missing = len(keys_to_remove) + num_seq_model_keys
+             strict = False
+         else:
+             # head type is not None
+             if n_classes_weak_in_sd != self.n_classes_weak:
+                 # remove weak head from sd
+                 keys_to_remove.append('weak_head.bias')
+                 keys_to_remove.append('weak_head.weight')
+                 strict = False
+             if n_classes_strong_in_sd != self.n_classes_strong:
+                 # remove strong head from sd
+                 keys_to_remove.append('strong_head.bias')
+                 keys_to_remove.append('strong_head.weight')
+                 strict = False
+             expected_missing = len(keys_to_remove)
+
+         # allow missing mel parameters for compatibility
+         num_mel_keys = len([key for key in self.state_dict() if 'mel_transform' in key])
+         if num_mel_keys > 0:
+             expected_missing += num_mel_keys
+             strict = False
+
+         state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove}
+         missing, unexpected = self.load_state_dict(state_dict, strict=strict)
+         assert len(missing) == expected_missing
+         assert len(unexpected) == 0
+
+     def separate_params(self):
+         if hasattr(self.model, "separate_params"):
+             return self.model.separate_params()
+         else:
+             raise NotImplementedError("The base model has no 'separate_params' method!")
+
+     def has_separate_params(self):
+         return hasattr(self.model, "separate_params")
+
+     def mel_forward(self, x):
+         return self.model.mel_forward(x)
+
+     def forward(self, x):
+         # base model is expected to output a sequence (see Eq. (1) in the paper)
+         # (batch size x sequence length x embedding dimension)
+         x = self.model(x)
+
+         # ATST: x.shape: batch size x 250 x 768
+         # PaSST: x.shape: batch size x 250 x 768
+         # ASiT: x.shape: batch size x 497 x 768
+         # M2D: x.shape: batch size x 62 x 3840
+         # BEATs: x.shape: batch size x 496 x 768
+
+         assert len(x.shape) == 3
+
+         if x.size(-2) > self.seq_len:
+             x = torch.nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), self.seq_len).transpose(1, 2)
+         elif x.size(-2) < self.seq_len:
+             x = torch.nn.functional.interpolate(x.transpose(1, 2), size=self.seq_len,
+                                                 mode='linear').transpose(1, 2)
+
+         # Eq. (2) in the paper
+         # for teachers this is an RNN, for students it is nn.Identity
+         x = self.seq_model(x)
+
+         if self.head_type == "attention":
+             # attention head to obtain weak from strong predictions
+             # this is typically used for the DESED task, which requires both
+             # weak and strong predictions
+             strong = torch.sigmoid(self.strong_head(x))
+             sof = torch.softmax(self.weak_head(x), dim=-1)
+             sof = torch.clamp(sof, min=1e-7, max=1)
+             weak = (strong * sof).sum(1) / sof.sum(1)
+             return strong.transpose(1, 2), weak
+         elif self.head_type == "linear":
+             # simple linear layers as head (see Eq. (3) in the paper)
+             # on AudioSet strong, only strong predictions are used
+             # on AudioSet weak, only weak predictions are used
+             # why both? because we tried to simultaneously train on AudioSet weak and strong (less successful)
+             strong = self.strong_head(x)
+             weak = self.weak_head(x.mean(dim=1))
+             return strong.transpose(1, 2), weak
+         else:
+             # no head means the sequence is returned instead of strong and weak predictions
+             return x
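A hedged sketch of wiring a backbone into the wrapper; `embed_dim=3840` follows the M2D shape comment in `forward`, and no pretrained checkpoint name is passed since valid names live in `CHECKPOINT_URLS`:

import torch
from models.m2d.M2D_wrapper import M2DWrapper
from models.prediction_wrapper import PredictionsWrapper

base = M2DWrapper()
model = PredictionsWrapper(base, embed_dim=3840, n_classes_strong=447, head_type="linear")
model.eval()

wav = torch.rand(1, 16000 * 10) - 0.5
with torch.no_grad():
    mel = model.mel_forward(wav)   # waveform -> normalized log-mel
    strong, weak = model(mel)      # strong: (1, 447, 250), weak: (1, 447)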
models/seq_models.py ADDED
@@ -0,0 +1,40 @@
+ import torch.nn as nn
+
+
+ class BidirectionalGRU(nn.Module):
+     def __init__(self, n_in, n_hidden, dropout=0, num_layers=1):
+         super(BidirectionalGRU, self).__init__()
+         self.rnn = nn.GRU(
+             n_in,
+             n_hidden,
+             bidirectional=True,
+             dropout=dropout,
+             batch_first=True,
+             num_layers=num_layers,
+         )
+
+     def forward(self, input_feat):
+         recurrent, _ = self.rnn(input_feat)
+         return recurrent
+
+
+ class BidirectionalLSTM(nn.Module):
+     def __init__(self, nIn, nHidden, nOut, dropout=0, num_layers=1):
+         super(BidirectionalLSTM, self).__init__()
+         self.rnn = nn.LSTM(
+             nIn,
+             nHidden,
+             bidirectional=True,
+             batch_first=True,
+             dropout=dropout,
+             num_layers=num_layers,
+         )
+         self.embedding = nn.Linear(nHidden * 2, nOut)
+
+     def forward(self, input_feat):
+         recurrent, _ = self.rnn(input_feat)
+         b, T, h = recurrent.size()
+         t_rec = recurrent.contiguous().view(b * T, h)
+         output = self.embedding(t_rec)
+         output = output.view(b, T, -1)
+         return output
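For orientation, a quick shape check of the GRU variant; the dimensions are illustrative:

import torch
from models.seq_models import BidirectionalGRU

rnn = BidirectionalGRU(n_in=768, n_hidden=256, num_layers=2)
x = torch.randn(4, 250, 768)       # (batch, time, embed_dim)
y = rnn(x)
assert y.shape == (4, 250, 512)    # bidirectional -> 2 * n_hidden features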
models/transformer_wrapper.py ADDED
@@ -0,0 +1,19 @@
+ from abc import ABC, abstractmethod
+ import torch.nn as nn
+
+
+ class BaseModelWrapper(ABC, nn.Module):
+     @abstractmethod
+     def mel_forward(self, x):
+         """Process input waveform to mel spectrogram."""
+         pass
+
+     @abstractmethod
+     def forward(self, x):
+         """Extract embedding sequence from mel spectrogram."""
+         pass
+
+     @abstractmethod
+     def separate_params(self):
+         """Separate model parameters into predefined groups for layer-wise learning rate decay."""
+         pass
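To make the contract concrete, a minimal hypothetical subclass; the `backbone` object and its `to_mel`/`encode` methods are stand-ins, not names from this repository:

import torch.nn as nn
from models.transformer_wrapper import BaseModelWrapper

class MyWrapper(BaseModelWrapper):
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone  # hypothetical model exposing .to_mel() and .encode()

    def mel_forward(self, x):
        return self.backbone.to_mel(x)      # waveform -> mel spectrogram

    def forward(self, spec):
        return self.backbone.encode(spec)   # mel -> (batch, seq_len, embed_dim)

    def separate_params(self):
        # one group per transformer block would go here; a single group as a stub
        return [list(self.parameters())]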
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ numpy<2
+ librosa
+ pandas
+ timm
+ nnAudio
+ av>=10.0.0
+ h5py>=3.8.0
+ jsonpickle>=3.0.1
+ hf_transfer>=0.1.4
+ hf-fastup>=0.0.5
+ datasets>=2.15.0
+ pytorch-lightning>=2.0.0
+ wandb
+ transformers
+ intervaltree
+ more-itertools
+ torch
+ torchvision
+ torchaudio
+ einops
resources/README.md ADDED
@@ -0,0 +1 @@
+ In this folder, we place all files that are automatically downloaded (such as model checkpoints).
resources/best_model_BEATs.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eca6caa468d341767eec3ba0985ec6c0776bf8a15eaeeaf25952cdd2cb5d4613
+ size 361966733
resources/eval_durations.csv ADDED
The diff for this file is too large to render.
resources/labelvocabulary.csv ADDED
@@ -0,0 +1,201 @@
+ idx,label
+ 0,Accelerating_and_revving_and_vroom
+ 1,Accordion
+ 2,Acoustic_guitar
+ 3,Aircraft
+ 4,Alarm
+ 5,Animal
+ 6,Applause
+ 7,Bark
+ 8,Bass_drum
+ 9,Bass_guitar
+ 10,Bathtub_(filling_or_washing)
+ 11,Bell
+ 12,Bicycle
+ 13,Bicycle_bell
+ 14,Bird
+ 15,Bird_vocalization_and_bird_call_and_bird_song
+ 16,Boat_and_Water_vehicle
+ 17,Boiling
+ 18,Boom
+ 19,Bowed_string_instrument
+ 20,Brass_instrument
+ 21,Breathing
+ 22,Burping_and_eructation
+ 23,Bus
+ 24,Buzz
+ 25,Camera
+ 26,Car
+ 27,Car_passing_by
+ 28,Cat
+ 29,Chatter
+ 30,Cheering
+ 31,Chewing_and_mastication
+ 32,Chicken_and_rooster
+ 33,Child_speech_and_kid_speaking
+ 34,Chime
+ 35,Chink_and_clink
+ 36,Chirp_and_tweet
+ 37,Chuckle_and_chortle
+ 38,Church_bell
+ 39,Clapping
+ 40,Clock
+ 41,Coin_(dropping)
+ 42,Computer_keyboard
+ 43,Conversation
+ 44,Cough
+ 45,Cowbell
+ 46,Crack
+ 47,Crackle
+ 48,Crash_cymbal
+ 49,Cricket
+ 50,Crow
+ 51,Crowd
+ 52,Crumpling_and_crinkling
+ 53,Crushing
+ 54,Crying_and_sobbing
+ 55,Cupboard_open_or_close
+ 56,Cutlery_and_silverware
+ 57,Cymbal
+ 58,Dishes_and_pots_and_pans
+ 59,Dog
+ 60,Domestic_animals_and_pets
+ 61,Domestic_sounds_and_home_sounds
+ 62,Door
+ 63,Doorbell
+ 64,Drawer_open_or_close
+ 65,Drill
+ 66,Drip
+ 67,Drum
+ 68,Drum_kit
+ 69,Electric_guitar
+ 70,Engine
+ 71,Engine_starting
+ 72,Explosion
+ 73,Fart
+ 74,Female_singing
+ 75,Female_speech_and_woman_speaking
+ 76,Fill_(with_liquid)
+ 77,Finger_snapping
+ 78,Fire
+ 79,Fireworks
+ 80,Fixed-wing_aircraft_and_airplane
+ 81,Fowl
+ 82,Frog
+ 83,Frying_(food)
+ 84,Gasp
+ 85,Giggle
+ 86,Glass
+ 87,Glockenspiel
+ 88,Gong
+ 89,Growling
+ 90,Guitar
+ 91,Gull_and_seagull
+ 92,Gunshot_and_gunfire
+ 93,Gurgling
+ 94,Hammer
+ 95,Hands
+ 96,Harmonica
+ 97,Harp
+ 98,Hi-hat
+ 99,Hiss
+ 100,Human_group_actions
+ 101,Human_voice
+ 102,Idling
+ 103,Insect
+ 104,Keyboard_(musical)
+ 105,Keys_jangling
+ 106,Knock
+ 107,Laughter
+ 108,Liquid
+ 109,Livestock_and_farm_animals_and_working_animals
+ 110,Male_singing
+ 111,Male_speech_and_man_speaking
+ 112,Mallet_percussion
+ 113,Marimba_and_xylophone
+ 114,Mechanical_fan
+ 115,Mechanisms
+ 116,Meow
+ 117,Microwave_oven
+ 118,Motor_vehicle_(road)
+ 119,Motorcycle
+ 120,Music
+ 121,Musical_instrument
+ 122,Ocean
+ 123,Organ
+ 124,Packing_tape_and_duct_tape
+ 125,Percussion
+ 126,Piano
+ 127,Plucked_string_instrument
+ 128,Pour
+ 129,Power_tool
+ 130,Printer
+ 131,Purr
+ 132,Race_car_and_auto_racing
+ 133,Rail_transport
+ 134,Rain
+ 135,Raindrop
+ 136,Ratchet_and_pawl
+ 137,Rattle
+ 138,Rattle_(instrument)
+ 139,Respiratory_sounds
+ 140,Ringtone
+ 141,Run
+ 142,Sawing
+ 143,Scissors
+ 144,Scratching_(performance_technique)
+ 145,Screaming
+ 146,Screech
+ 147,Shatter
+ 148,Shout
+ 149,Sigh
+ 150,Singing
+ 151,Sink_(filling_or_washing)
+ 152,Siren
+ 153,Skateboard
+ 154,Slam
+ 155,Sliding_door
+ 156,Snare_drum
+ 157,Sneeze
+ 158,Speech
+ 159,Speech_synthesizer
+ 160,Splash_and_splatter
+ 161,Squeak
+ 162,Stream
+ 163,Strum
+ 164,Subway_and_metro_and_underground
+ 165,Tabla
+ 166,Tambourine
+ 167,Tap
+ 168,Tearing
+ 169,Telephone
+ 170,Thump_and_thud
+ 171,Thunder
+ 172,Thunderstorm
+ 173,Tick
+ 174,Tick-tock
+ 175,Toilet_flush
+ 176,Tools
+ 177,Traffic_noise_and_roadway_noise
+ 178,Train
+ 179,Trickle_and_dribble
+ 180,Truck
+ 181,Trumpet
+ 182,Typewriter
+ 183,Typing
+ 184,Vehicle
+ 185,Vehicle_horn_and_car_horn_and_honking
+ 186,Walk_and_footsteps
+ 187,Water
+ 188,Water_tap_and_faucet
+ 189,Waves_and_surf
+ 190,Whispering
+ 191,Whoosh_and_swoosh_and_swish
+ 192,Wild_animals
+ 193,Wind
+ 194,Wind_chime
+ 195,Wind_instrument_and_woodwind_instrument
+ 196,Wood
+ 197,Writing
+ 198,Yell
+ 199,Zipper_(clothing)
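This vocabulary maps prediction indices to label names; a small sketch of decoding an index to a name, assuming pandas (which is listed in requirements.txt):

import pandas as pd

vocab = pd.read_csv("resources/labelvocabulary.csv")
idx_to_label = dict(zip(vocab["idx"], vocab["label"]))
print(idx_to_label[59])  # -> "Dog"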