Spaces:
Sleeping
Sleeping
gradio app
Browse files- .gitignore +5 -0
- LICENSE +21 -0
- README.md +320 -13
- app.py +71 -0
- config.py +28 -0
- data_util/audioset_classes.py +1393 -0
- data_util/audioset_strong.py +329 -0
- data_util/dcase2016task2.py +280 -0
- data_util/transforms.py +195 -0
- ex_audioset_strong.py +504 -0
- ex_dcase2016task2.py +517 -0
- helpers/augment.py +225 -0
- helpers/decode.py +72 -0
- helpers/encode.py +230 -0
- helpers/score.py +384 -0
- helpers/utils.py +12 -0
- images/downstream_task_results.png +0 -0
- inference.py +126 -0
- models/asit/ASIT_wrapper.py +60 -0
- models/asit/data_transformations.py +29 -0
- models/asit/utils.py +540 -0
- models/asit/vision_transformer.py +316 -0
- models/atstframe/ATSTF_wrapper.py +105 -0
- models/atstframe/audio_transformer.py +253 -0
- models/atstframe/transformer.py +112 -0
- models/beats/BEATs.py +183 -0
- models/beats/BEATs_wrapper.py +56 -0
- models/beats/Tokenizers.py +172 -0
- models/beats/backbone.py +783 -0
- models/beats/modules.py +218 -0
- models/beats/quantizer.py +215 -0
- models/frame_mn/Frame_MN_wrapper.py +75 -0
- models/frame_mn/block_types.py +189 -0
- models/frame_mn/model.py +356 -0
- models/frame_mn/utils.py +93 -0
- models/frame_passt/fpasst.py +963 -0
- models/frame_passt/fpasst_wrapper.py +86 -0
- models/frame_passt/preprocess.py +147 -0
- models/frame_passt/vit_helpers.py +399 -0
- models/m2d/M2D_wrapper.py +52 -0
- models/m2d/portable_m2d.py +410 -0
- models/prediction_wrapper.py +213 -0
- models/seq_models.py +40 -0
- models/transformer_wrapper.py +19 -0
- requirements.txt +17 -0
- resources/README.md +1 -0
- resources/best_model_BEATs.pth +3 -0
- resources/eval_durations.csv +0 -0
- resources/labelvocabulary.csv +89 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
__init__.pyc
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Florian Schmid
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,13 +1,320 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Effective Pre-Training of Audio Transformers for Sound Event Detection
|
| 2 |
+
|
| 3 |
+
In this repository, we publish pre-trained models and code for the ICASSP'25 paper: [**Effective Pre-Training of Audio Transformers for Sound Event Detection**](https://arxiv.org/abs/2409.09546).
|
| 4 |
+
|
| 5 |
+
In this paper, we propose a pre-training pipeline for audio spectrogram transformers for frame-level sound event detection tasks. On top of common pre-training steps, we add a meticulously designed training routine on AudioSet frame-level annotations. For five transformers, we show that this additional pre-training step leads to substantial performance improvements on frame-level downstream tasks. We release all model checkpoints and hope that they will help researchers improve tasks that require high-quality frame-level representations.
|
| 6 |
+
|
| 7 |
+
This repository includes:
|
| 8 |
+
* All pre-trained checkpoints and model files (see [here](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1))
|
| 9 |
+
* A script that demonstrates how the pre-trained checkpoints can be loaded and used for inference (see [here](https://github.com/fschmid56/PretrainedSED/blob/main/inference.py))
|
| 10 |
+
* Add a table outlining the external checkpoints used in this work (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#model-checkpoints))
|
| 11 |
+
* Evaluation routine on the AudioSet frame-level annotations (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#run-audioset-strong-evaluation))
|
| 12 |
+
* The AudioSet Strong training routine (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#audioset-strong-pre-training))
|
| 13 |
+
* The ensemble logits for the AudioSet Strong dataset (see [here](https://github.com/fschmid56/PretrainedSED?tab=readme-ov-file#download-ensemble-pseudo-labels))
|
| 14 |
+
* A file demonstrating how the pre-trained transformers can be fine-tuned on a downstream task (see [here](ex_dcase2016task2.py))
|
| 15 |
+
* **New:** added two low-complexity SED models ('frame_mn10' with 3.83M parameters and 'frame_mn06' with 1.62M parameters)
|
| 16 |
+
|
| 17 |
+
## Setting up Environment
|
| 18 |
+
|
| 19 |
+
1. If needed, create a new environment with python 3.9 and activate it:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
conda create -n ptsed python=3.9 cython
|
| 23 |
+
conda activate ptsed
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
2. Install pytorch build that suits your system. For example:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
| 30 |
+
# or for cuda >= 12.1
|
| 31 |
+
pip3 install torch torchvision torchaudio
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
3. Install the requirements:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
pip3 install -r requirements.txt
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
4. Install package for mp3 decoding:
|
| 41 |
+
|
| 42 |
+
``` bash
|
| 43 |
+
CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Inference
|
| 47 |
+
|
| 48 |
+
The script [inference.py](inference.py) demonstrates how to load a pre-trained model and run sound event detection on an audio file
|
| 49 |
+
of arbitrary length.
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
python inference.py --cuda --model_name="BEATs" --audio_file="test_files/752547__iscence__milan_metro_coming_in_station.wav"
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
The argument ```model_name``` specifies the transformer used for inference, and the corresponding pre-trained model checkpoint
|
| 56 |
+
is automatically downloaded and placed in the folder [resources](resources).
|
| 57 |
+
|
| 58 |
+
The argument ```audio_file``` specifies the path to a single audio file. There is one [example file](test_files/752547__iscence__milan_metro_coming_in_station.wav) included.
|
| 59 |
+
More example files can be downloaded from the [GitHub release](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1).
|
| 60 |
+
|
| 61 |
+
**Low-complexity** inference with customized MobileNet:
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
python inference.py --cuda --model_name="frame_mn06" --audio_file="test_files/752547__iscence__milan_metro_coming_in_station.wav"
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Model Checkpoints
|
| 68 |
+
|
| 69 |
+
The following is a list of checkpoints that we have created and worked with in our paper. For external checkpoints, we provide the download link. "Checkpoint Name" refers to the respective names in our [GitHub release](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1). **All model checkpoints** are automatically downloaded by running the code, or can be manually downloaded from the [GitHub release](https://github.com/fschmid56/PretrainedSED/releases/tag/v0.0.1).
|
| 70 |
+
|
| 71 |
+
| Model | Pre-Training | Checkpoint Name | External Download Link | Reference |
|
| 72 |
+
|----------------------|--------------|--------------------|---------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------|
|
| 73 |
+
| BEATs | SSL | BEATs_ssl.pt | [here](https://1drv.ms/u/s!AqeByhGUtINrgcpxJUNDxg4eU0r-vA?e=qezPJ5) | [[1]](https://arxiv.org/pdf/2212.09058) |
|
| 74 |
+
| BEATs | Weak | BEATs_weak.pt | [here](https://1drv.ms/u/s!AqeByhGUtINrgcpke6_lRSZEKD5j2Q?e=A3FpOf) | [[1]](https://arxiv.org/pdf/2212.09058) |
|
| 75 |
+
| BEATs | Strong | BEATs_strong_1.pt | ours | [[1]](https://arxiv.org/pdf/2212.09058) |
|
| 76 |
+
| ATST-Frame | SSL | ATST-F_ssl.pt | [here](https://drive.google.com/file/d/1bGJSZWlAIIJ6GL5Id5dW0PTB72DL-QDQ/view?usp=sharing) | [[2]](https://arxiv.org/pdf/2306.04186) |
|
| 77 |
+
| ATST-Frame | Weak | ATST-F_weak.pt | [here](https://drive.google.com/file/d/1_xb0_n3UNbUG_pH1vLHTviLfsaSfCzxz/view?usp=drive_link) | [[2]](https://arxiv.org/pdf/2306.04186) |
|
| 78 |
+
| ATST-Frame | Strong | ATST-F_strong_1.pt | ours | [[2]](https://arxiv.org/pdf/2306.04186) |
|
| 79 |
+
| fPaSST | SSL | fpasst_im.pt | [here](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth) | [[3]](https://arxiv.org/pdf/2110.05069), [[4]](https://arxiv.org/pdf/2407.12997) |
|
| 80 |
+
| fPaSST | Weak | fpasst_weak.pt | ours | [[3]](https://arxiv.org/pdf/2110.05069), [[4]](https://arxiv.org/pdf/2407.12997) |
|
| 81 |
+
| fPaSST | Strong | fpasst_strong_1.pt | ours | [[3]](https://arxiv.org/pdf/2110.05069), [[4]](https://arxiv.org/pdf/2407.12997) |
|
| 82 |
+
| ASiT | SSL | ASIT_ssl.pt | [here](https://drive.google.com/file/d/11eaOU40jonpYZ3u_XI-XUSSWclv8qeR7/view?usp=drive_link) | [[5]](https://arxiv.org/pdf/2211.13189) |
|
| 83 |
+
| ASiT | Weak | ASIT_weak.pt | ours | [[5]](https://arxiv.org/pdf/2211.13189) |
|
| 84 |
+
| ASiT | Strong | ASIT_strong_1.pt | ours | [[5]](https://arxiv.org/pdf/2211.13189) |
|
| 85 |
+
| M2D | SSL | M2D_ssl.pt | [here](https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_clap_vit_base-80x1001p16x16-240128_AS-FT_enconly.zip) | [[6]](https://arxiv.org/pdf/2406.02032) |
|
| 86 |
+
| M2D | Weak | M2D_weak.pt | [here](https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_clap_vit_base-80x1001p16x16-240128_AS-FT_enconly.zip) | [[6]](https://arxiv.org/pdf/2406.02032) |
|
| 87 |
+
| M2D | Strong | M2D_strong_1.pt | ours | [[6]](https://arxiv.org/pdf/2406.02032) |
|
| 88 |
+
| Customized MobileNet | Strong | frame_mn06.pt | ours | **NEW** |
|
| 89 |
+
| Customized MobileNet | Strong | frame_mn10.pt | ours | **NEW** |
|
| 90 |
+
|
| 91 |
+
## AudioSet Strong pre-training
|
| 92 |
+
|
| 93 |
+
### Prepare Dataset
|
| 94 |
+
|
| 95 |
+
1. Follow the steps described [here](https://github.com/kkoutini/PaSST/tree/main/audioset#experiments-on-audioset) to obtain AudioSet, encoded as mp3 files and packed into HDF5 format.
|
| 96 |
+
|
| 97 |
+
You will end up with a directory containing three HDF5 files:
|
| 98 |
+
* balanced_train_segments_mp3.hdf
|
| 99 |
+
* unbalanced_train_segments_mp3.hdf
|
| 100 |
+
* eval_segments_mp3.hdf
|
| 101 |
+
|
| 102 |
+
2. We use the [Huggingface datasets](https://huggingface.co/docs/datasets/index) API for fast and memory-efficient loading of the dataset. The [hf_dataset_gen/audioset_strong.py](hf_dataset_gen/audioset_strong.py) file takes the dataset from Step 1 and converts it into a Huggingface dataset.
|
| 103 |
+
|
| 104 |
+
Adapt the paths in [hf_dataset_gen/audioset_strong.py](hf_dataset_gen/audioset_strong.py) marked as TODOs (2x: hdf5 path and target path for HF dataset).
|
| 105 |
+
|
| 106 |
+
3. Create the Hunggingface dataset:
|
| 107 |
+
```
|
| 108 |
+
cd hf_dataset_gen
|
| 109 |
+
python audioset_strong.py
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
4. The path to the dataset is specified via an environment variable. When you access the dataset for training or evaluation,
|
| 113 |
+
set the environment variable. For example, in our case, the Huggingface dataset path is set to:
|
| 114 |
+
|
| 115 |
+
```/share/hel/datasets/HF_datasets/local/audioset_strong```
|
| 116 |
+
|
| 117 |
+
And therefore we set the following environment variable:
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
export HF_DATASETS_CACHE=/share/hel/datasets/HF_datasets/cache/
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Download ensemble pseudo labels
|
| 124 |
+
|
| 125 |
+
If you want to train on AudioSet Strong using Knowledge Distillation as described in the paper, you will have to download the
|
| 126 |
+
ensemble logits from [Zenodo](https://zenodo.org/records/14626113). The HDF5 file contains filenames (Youtube IDs) matched with the corresponding ensembled logits. The corresponding keys are "filenames" and "strong_logits". Ensemble Logits for one file are of shape 447 x 250 (number of classes x timeframes at 40 ms resolution). Ensemble Logits are stored in float16 format to save space.
|
| 127 |
+
|
| 128 |
+
Check out [this code piece](https://github.com/fschmid56/PretrainedSED/blob/f62e9fb1566254766396cce0343a2de4156d3015/data_util/transforms.py#L37) if you want to learn how pseudo labels are loaded.
|
| 129 |
+
|
| 130 |
+
For training, the pseudo-label file can simply be set via command line: ```--pseudo_labels_file=<location>```
|
| 131 |
+
|
| 132 |
+
### Run AudioSet Strong training
|
| 133 |
+
|
| 134 |
+
Example: Train ATST-F, pretrained on AudioSet weak, with an RNN on top, use the balanced sampler and set wavmix augmentation to probability of 1.0.
|
| 135 |
+
|
| 136 |
+
```
|
| 137 |
+
python ex_audioset_strong.py --model_name=ATST-F --seq_model_type=rnn --use_balanced_sampler --pretrained=weak --wavmix_p=1.0
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
Check out the results: https://api.wandb.ai/links/cp_tobi/tphswm5k
|
| 141 |
+
|
| 142 |
+
Example: Train ATST-F using Knowledge Distillation.
|
| 143 |
+
|
| 144 |
+
```
|
| 145 |
+
python ex_audioset_strong.py --model_name=ATST-F --pretrained=weak --n_epochs=120 --wavmix_p=0.5 --freq_warp_p=0 --filter_augment_p=0 --mixstyle_p=0 --max_lr=1e-4 --distillation_loss_weight=0.9 --pseudo_labels_file=<path_to_pseudo_label_file_from_Zenodo>
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Check out the results: https://api.wandb.ai/links/cp_tobi/2eh4cz80
|
| 149 |
+
|
| 150 |
+
### Run AudioSet Strong evaluation
|
| 151 |
+
|
| 152 |
+
Evaluate the AudioSet Strong pre-trained checkpoint of ATST-F:
|
| 153 |
+
|
| 154 |
+
```
|
| 155 |
+
python ex_audioset_strong.py --model_name=ATST-F --pretrained=strong --evaluate
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
If everything is set up correctly, this should give a `val/psds1_macro_averaged` of around 46.
|
| 159 |
+
|
| 160 |
+
## Fine-Tuning on Downstream Task
|
| 161 |
+
|
| 162 |
+
We demonstrate how pre-trained transformers can be fine-tuned for the downstream Sound Event Detection task by using our transformers on [DCASE 2016 Task 2](https://dcase.community/challenge2016/task-sound-event-detection-in-synthetic-audio-results). This task focuses on detecting office sounds and is part of the [HEAR benchmark](https://hearbenchmark.com/hear-tasks.html).
|
| 163 |
+
|
| 164 |
+
### Obtain DCASE 2016 Task 2 Dataset in HEAR format
|
| 165 |
+
|
| 166 |
+
Follow the instructions on the [HEAR website](https://hearbenchmark.com/hear-tasks.html) to download the dataset in 16 kHz sampling rate. After completing the setup, your file tree should look similar to this:
|
| 167 |
+
```
|
| 168 |
+
hear_datasets/tasks/dcase2016_task2-hear2021-full/
|
| 169 |
+
├── 16000
|
| 170 |
+
├── 48000
|
| 171 |
+
├── labelvocabulary.csv
|
| 172 |
+
├── task_metadata.json
|
| 173 |
+
├── test.json
|
| 174 |
+
├── train.json
|
| 175 |
+
└── valid.json
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
The ```16000``` folder contains audio files sampled at 16 kHz.
|
| 179 |
+
|
| 180 |
+
### Run Fine-Tuning
|
| 181 |
+
|
| 182 |
+
The main script for fine-tuning is [ex_dcase2016task2.py](ex_dcase2016task2.py).
|
| 183 |
+
|
| 184 |
+
To fine-tune the full ATST-F model, pre-trained on AudioSet Strong, with a layer-wise learning rate decay of 0.95, use the following command:
|
| 185 |
+
|
| 186 |
+
```
|
| 187 |
+
python ex_dcase2016task2.py --task_path=hear_datasets/tasks/dcase2016_task2-hear2021-full --model_name=ATST-F --pretrained=strong --lr_decay=0.95
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
To train only the linear prediction head on top of the frozen BEATs transformer, also pre-trained on AudioSet Strong, use this command:
|
| 191 |
+
|
| 192 |
+
```
|
| 193 |
+
python ex_dcase2016task2.py --task_path=hear_datasets/tasks/dcase2016_task2-hear2021-full --model_name=BEATs --pretrained=strong --transformer_frozen --max_lr=2e-1 --mixup_p=0 --wavmix_p=0 --no_adamw --weight_decay=0 --n_epochs=500
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
## Results & Ablation Studies
|
| 197 |
+
|
| 198 |
+
This section presents the main results reported [in the paper](https://arxiv.org/pdf/2409.09546), along with additional ablation studies, including teacher model performances, comparisons of different sequence models, and evaluations using the DESED baseline system setup. The additional ablation studies have been requested by ICASSP`25 reviewers.
|
| 199 |
+
|
| 200 |
+
* All results represent averages over three independent runs.
|
| 201 |
+
* For AudioSet Strong, we employ the threshold-independent PSDS1 [7] metric to ensure fine-grained temporal evaluation.
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
### Student Model Performances on AudioSet Strong (*from paper*)
|
| 205 |
+
|
| 206 |
+
* For the *Li et al. [2]* row, we reproduced their AudioSet Strong [training pipeline](https://github.com/Audio-WestlakeU/audiossl).
|
| 207 |
+
* Alongside the **Proposed Pipeline**, we include ablation studies for three settings: no KD, no RNN in teacher models, and no pre-training on AudioSet Weak (no Step 2).
|
| 208 |
+
|
| 209 |
+
| | **ATST-F** | **BEATs** | **fPaSST** | **M2D** | **ASiT** |
|
| 210 |
+
|-----------------------|------------|-----------|------------|----------|----------|
|
| 211 |
+
| **Li et al. [2]** | 40.9 | 36.5 | 38.7 | 36.9 | 37.0 |
|
| 212 |
+
| **Proposed Pipeline** | **45.8** | **46.5** | **45.4** | **46.3** | **46.2** |
|
| 213 |
+
| **-- without KD** | 41.8 | 44.1 | 40.7 | 41.1 | 40.9 |
|
| 214 |
+
| **-- without RNN** | 45.7 | 45.8 | 45.3 | 46.0 | 46.1 |
|
| 215 |
+
| **-- without Step 2** | 45.7 | 46.3 | 45.2 | 44.9 | **46.2** |
|
| 216 |
+
|
| 217 |
+
**Conclusions:**
|
| 218 |
+
* The significant performance gap to [2] stems mainly from our three design choices (KD, RNNs, Step 2), but also improvements in training on AudioSet Strong, including balanced sampling and aggressive data augmentation.
|
| 219 |
+
* Knowledge Distillation (KD) has the most substantial impact, underlining the effectiveness of the ensemble-KD approach.
|
| 220 |
+
* RNNs in teacher models and pre-training on AudioSet Weak offer modest improvements but are justified due to their low additional cost. Notably, they do not increase student model complexity, and AudioSet Weak checkpoints are publicly available for most transformers.
|
| 221 |
+
|
| 222 |
+
### Teacher Model Performances on AudioSet Strong (*additional results*)
|
| 223 |
+
|
| 224 |
+
* The table below shows teacher model results for each transformer.
|
| 225 |
+
* Column **Avg. Ind.** represents the average performance across all single models in the row.
|
| 226 |
+
* Column **Ensemble** represents the performance of the ensemble consisting of all models in the respective row.
|
| 227 |
+
|
| 228 |
+
| | **ATST-F** | **BEATs** | **fPaSST** | **M2D** | **ASiT** | **Avg. Ind.** | **Ensemble** |
|
| 229 |
+
|-------------------------------|------------|-----------|------------|----------|----------|---------------|--------------|
|
| 230 |
+
| **Proposed Teacher Pipeline** | 43.3 | **45.8** | **43.3** | **44.1** | **43.3** | **44.9** | **47.1** |
|
| 231 |
+
| **-- without RNN** | 41.8 | 44.1 | 40.7 | 41.1 | 40.9 | 41.7 | 46.2 |
|
| 232 |
+
| **-- without Step 2** | **43.5** | 34.4 | 40.9 | 43.8 | 43.2 | 41.2 | 46.5 |
|
| 233 |
+
|
| 234 |
+
**Conclusions:**
|
| 235 |
+
* *Ensemble Performance*: The *Ensemble* column reflects the teacher ensemble performances utilized for Knowledge Distillation (KD) in table above.
|
| 236 |
+
* *Impact of RNNs and Step 2*: Incorporating RNNs and Step 2 (AudioSet Weak pre-training) notably enhances single-model teacher performance, with the exception of ATST-F without Step 2.
|
| 237 |
+
* *Benefits of Ensembling*: While individual model performances show considerable variability (Avg. Ind.), ensembling stabilizes and elevates overall performance, as evidenced by the smaller differences in the *Ensemble* column.
|
| 238 |
+
* *BEATs-Specific Insights*: BEATs excels in the *Proposed Teacher Pipeline* and *without RNN* settings but underperforms in the *without Step 2* configuration. This discrepancy may be attributed to its unique SSL pre-training routine and longer sequence length (resulting from more tokens being extracted from the input).
|
| 239 |
+
|
| 240 |
+
### Teacher Model Performances with different Sequence Models (*additional results*)
|
| 241 |
+
|
| 242 |
+
* The use of an additional sequence model on top of the AudioSet Weak pre-trained transformers stems from our hypothesis that adding capacity specifically for temporally-strong predictions can enhance performance.
|
| 243 |
+
* The table below shows teacher model performances for various sequence models added on top of the transformers before training on AudioSet Strong. The paper uses BiGRUs (RNN) as they deliver the best performance.
|
| 244 |
+
* We investigated 4 different sequence models:
|
| 245 |
+
* RNN: BiGRUs
|
| 246 |
+
* Attention: Multi-Head Self-Attention with rotary position embeddings
|
| 247 |
+
* Transformer (TF): Transformer Encoder blocks with rotary position embeddings
|
| 248 |
+
* [MAMBA](https://arxiv.org/abs/2312.00752): Implementation from [mambapy](https://github.com/alxndrTL/mamba.py)
|
| 249 |
+
* We varied the inner dimension (*dim*) and the number of layers (\<Model Type\>:\<#layers\>; e.g., TF:2 means two Transformer layers were added on top of the pre-trained transformer).
|
| 250 |
+
* **Setup Notes**:
|
| 251 |
+
* Ablations were performed using **ATST-F** due to its computational efficiency.
|
| 252 |
+
* Performance without a sequence model was **41.8 PSDS1**.
|
| 253 |
+
* Removing the top Transformer layers, which may overfit to AudioSet Weak labels, decreased performance.
|
| 254 |
+
* For MAMBA, only a single layer was feasible due to memory constraints.
|
| 255 |
+
|
| 256 |
+
| | RNN:1 | RNN:2 | RNN:3 | TF:1 | TF:2 | TF:3 | ATT:1 | ATT:2 | ATT:3 | MAMBA:1 |
|
| 257 |
+
|:-------------|:-----:|:---------:|:-----:|:---------:|:-----:|:---------:|:-----:|:---------:|:-----:|:---------:|
|
| 258 |
+
| **dim=256** | 8.72 | 3.76 | 3.10 | 34.25 | 34.62 | 34.05 | 40.08 | 39.70 | 39.55 | 40.27 |
|
| 259 |
+
| **dim=512** | 40.62 | 7.26 | 0.12 | 40.41 | 41.11 | 40.30 | 41.78 | 41.91 | 41.95 | 41.25 |
|
| 260 |
+
| **dim=1024** | 42.74 | 42.75 | 43.00 | 42.69 | 42.22 | 42.20 | 42.44 | **42.45** | 42.08 | **41.97** |
|
| 261 |
+
| **dim=2048** | 43.41 | **43.43** | 42.66 | **42.90** | 38.94 | **42.90** | 41.58 | 41.59 | 41.42 | 41.72 |
|
| 262 |
+
|
| 263 |
+
**Conclusions:**
|
| 264 |
+
* *Best model type*: The highest performance was achieved with 2 BiGRU layers, followed by Transformer, Self-Attention, and MAMBA. All sequence models improved performance compared to using no additional sequence model, though MAMBA's gains were marginal.
|
| 265 |
+
* *Inner Dimension*: Larger inner dimensions consistently led to better performance across all sequence models. Significant improvements required dimensions ≥1024, while smaller dimensions (e.g., 256) often degraded performance, with severe failures for BiGRU. We believe that large inner dimensions are essential due to the high number of classes (447) in AudioSet Strong.
|
| 266 |
+
* *Number of layers*: Performance was relatively insensitive to the number of layers for most sequence models, with optimal results often achieved with just 1–2 layers.
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
### Downstream Task Performances (*from paper*)
|
| 270 |
+
|
| 271 |
+
* Three frame-level downstream tasks:
|
| 272 |
+
* DCASE 2023 Task 4: Domestic Environment Sound Event Detection (*DESED*), metric: PSDS 1
|
| 273 |
+
* DCASE 2016 Task 2 (*DC16-T2*), metric: onset F-measure
|
| 274 |
+
* MAESTRO 5hr (*MAESTRO*), metric: onset F-measure
|
| 275 |
+
* For DESED, we followed a simplified setup in line with [2], excluding unsupervised data (no mean teacher approach) and an additional CRNN component from the [DCASE 2023 Task 4 baseline system](https://github.com/DCASE-REPO/DESED_task/tree/master/recipes/dcase2023_task4_baseline). While state-of-the-art approaches such as [4] and [8] leverage advanced techniques (e.g., multi-stage/multi-iteration training, sophisticated data augmentation, and interpolation consistency training), we deliberately avoided these complexities, as the focus is on a precise evaluation of pre-training quality.
|
| 276 |
+
|
| 277 |
+

|
| 278 |
+
|
| 279 |
+
**Conclusions:**
|
| 280 |
+
* *In-Domain Tasks*: The pipeline demonstrates strong, consistent improvements for all transformers on *DESED* and *DC16-T2*, showcasing its effectiveness for in-domain tasks.
|
| 281 |
+
* *Out-of-Domain Task*: Results on *MAESTRO* (piano pitch prediction) are inconclusive. This limitation suggests that the proposed pre-training strategy yields substantial gains only when audio and labels are similar to the AudioSet ontology.
|
| 282 |
+
* *Simplified DESED Setup*: Despite the simplified setup (no CRNN, no unsupervised data), performance remains comparable to the [DCASE 2023 Task 4 baseline system](https://github.com/DCASE-REPO/DESED_task/tree/master/recipes/dcase2023_task4_baseline).
|
| 283 |
+
|
| 284 |
+
#### DESED Baseline Setup (*additional results*)
|
| 285 |
+
|
| 286 |
+
To complement the simplified DESED setup presented earlier, we provide results for the [DCASE 2023 Task 4 baseline system](https://github.com/DCASE-REPO/DESED_task/tree/master/recipes/dcase2023_task4_baseline) setup for ATST-F and BEATs in the table below. Note that hyperparameters were not extensively tuned, and the data setup may differ slightly from the original baseline.
|
| 287 |
+
|
| 288 |
+
| **Model** | **Checkpoint** | **Notes** | **Performance** |
|
| 289 |
+
|-----------|------------------|---------------------|-----------------|
|
| 290 |
+
| ATST-F | Step 1 (SSL) | | 42.7 |
|
| 291 |
+
| ATST-F | Step 2 (AS weak) | | 47.1 |
|
| 292 |
+
| ATST-F | Full Pipeline | | 50.4 |
|
| 293 |
+
| ATST-F | Full Pipeline | dropped 2 TF layers | **51.1** |
|
| 294 |
+
| BEATs | Step 1 (SSL) | | 39.7 |
|
| 295 |
+
| BEATs | Step 2 (AS weak) | | 48.1 |
|
| 296 |
+
| BEATs | Full Pipeline | | 48.6 |
|
| 297 |
+
| BEATs | Full Pipeline | dropped 2 TF layers | **51.1** |
|
| 298 |
+
|
| 299 |
+
**Conclusions**:
|
| 300 |
+
* The *Full Pipeline* substantially improves performance over *Step 1 (SSL)* and *Step 2 (AS Weak)* for both ATST-F and BEATs.
|
| 301 |
+
* Dropping the last two Transformer layers notably enhances results, suggesting that the final layers may focus on AudioSet Strong label-specific features, while earlier layers provide more general, transferable embeddings that benefit the DESED task. We will conduct further experiments to find out whether dropping Transformer layers is generalizable to other tasks or specific to the DESED task.
|
| 302 |
+
|
| 303 |
+
# References
|
| 304 |
+
|
| 305 |
+
[1] S. Chen, Y. Wu, C. Wang, S. Liu, D. Tompkins, Z. Chen, W. Che, X. Yu, and F. Wei, “BEATs: Audio pre-training with acoustic tokenizers,” in Proceedings of the International Conference on Machine Learning (ICML), 2023.
|
| 306 |
+
|
| 307 |
+
[2] X. Li, N. Shao, and X. Li, “Self-supervised audio teacher-student transformer for both clip-level and frame-level tasks,” Transactions on Audio, Speech, and Language Processing, vol. 32, pp. 1336–1351, 2024.
|
| 308 |
+
|
| 309 |
+
[3] K. Koutini, J. Schl¨uter, H. Eghbal-zadeh, and G. Widmer, “Efficient training of audio transformers with patchout,” in Proceedings of the Interspeech Conference, 2022.
|
| 310 |
+
|
| 311 |
+
[4] F. Schmid, P. Primus, T. Morocutti, J. Greif, and G. Widmer, “Multi-iteration multi-stage fine-tuning of transformers for sound event detection with heterogeneous datasets,” in Workshop on Detection and Classification of Acoustic Scenes and Events (DCASE), 2024.
|
| 312 |
+
|
| 313 |
+
[5] S. Atito, M. Awais, W. Wang, M. D. Plumbley, and J. Kittler, “ASiT: Local-global audio spectrogram vision transformer for event classification,” IEEE ACM Trans. Audio Speech Lang. Process., vol. 32, pp. 3684–3693, 2024.
|
| 314 |
+
|
| 315 |
+
[6] D. Niizumi, D. Takeuchi, Y. Ohishi, N. Harada, M. Yasuda, S. Tsubaki, and K. Imoto, “M2D-CLAP: masked modeling duo meets CLAP for learning general-purpose audio-language representation,” in Proceedings of the Interspeech Conference, 2024.
|
| 316 |
+
|
| 317 |
+
[7] J. Ebbers, R. Haeb-Umbach, and R. Serizel, “Threshold independent evaluation of sound event detection scores,” in Proceedings of the International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2022.
|
| 318 |
+
|
| 319 |
+
[8] N. Shao, X. Li, and X. Li, “Fine-tune the pretrained ATST model for sound event detection,” in Proceedings of the International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2024
|
| 320 |
+
|
app.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from models.atstframe.ATSTF_wrapper import ATSTWrapper
|
| 4 |
+
from models.beats.BEATs_wrapper import BEATsWrapper
|
| 5 |
+
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
|
| 6 |
+
from models.m2d.M2D_wrapper import M2DWrapper
|
| 7 |
+
from models.asit.ASIT_wrapper import ASiTWrapper
|
| 8 |
+
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
|
| 9 |
+
from models.prediction_wrapper import PredictionsWrapper
|
| 10 |
+
from models.frame_mn.utils import NAME_TO_WIDTH
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
class TransformerClassifier(nn.Module):
|
| 16 |
+
def __init__(self, model, n_classes):
|
| 17 |
+
super(TransformerClassifier, self).__init__()
|
| 18 |
+
self.model = model
|
| 19 |
+
self.linear = nn.Linear(model.embed_dim, n_classes)
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
mel = self.model.mel_forward(x)
|
| 23 |
+
features = self.model(mel).squeeze(1)
|
| 24 |
+
return self.linear(features)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_model(model_name):
|
| 28 |
+
if model_name == "BEATs":
|
| 29 |
+
beats = BEATsWrapper()
|
| 30 |
+
model = PredictionsWrapper(beats, checkpoint=None, head_type=None, seq_len=1)
|
| 31 |
+
elif model_name == "ATST-F":
|
| 32 |
+
atst = ATSTWrapper()
|
| 33 |
+
model = PredictionsWrapper(atst, checkpoint=None, head_type=None, seq_len=1)
|
| 34 |
+
elif model_name == "fpasst":
|
| 35 |
+
fpasst = FPaSSTWrapper()
|
| 36 |
+
model = PredictionsWrapper(fpasst, checkpoint=None, head_type=None, seq_len=1)
|
| 37 |
+
elif model_name == "M2D":
|
| 38 |
+
m2d = M2DWrapper()
|
| 39 |
+
model = PredictionsWrapper(m2d, checkpoint=None, head_type=None, seq_len=1,
|
| 40 |
+
embed_dim=m2d.m2d.cfg.feature_d)
|
| 41 |
+
elif model_name == "ASIT":
|
| 42 |
+
asit = ASiTWrapper()
|
| 43 |
+
model = PredictionsWrapper(asit, checkpoint=None, head_type=None, seq_len=1)
|
| 44 |
+
elif model_name.startswith("frame_mn"):
|
| 45 |
+
width = NAME_TO_WIDTH(model_name)
|
| 46 |
+
frame_mn = FrameMNWrapper(width)
|
| 47 |
+
embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
|
| 48 |
+
model = PredictionsWrapper(frame_mn, checkpoint=None, head_type=None, seq_len=1, embed_dim=embed_dim)
|
| 49 |
+
else:
|
| 50 |
+
raise NotImplementedError(f"Model {model_name} not (yet) implemented")
|
| 51 |
+
main_model = TransformerClassifier(model, n_classes=88)
|
| 52 |
+
# main_model.compile()
|
| 53 |
+
main_model.load_state_dict(torch.load(f"resources/best_model_{model_name}.pth", map_location='cpu'))
|
| 54 |
+
print(main_model)
|
| 55 |
+
main_model.eval()
|
| 56 |
+
return main_model
|
| 57 |
+
|
| 58 |
+
model = get_model("BEATs")
|
| 59 |
+
label_mapping = pd.read_csv("resources/labelvocabulary.csv", header=None, index_col=0).to_dict()[1]
|
| 60 |
+
def apply_sepia(input_audio):
|
| 61 |
+
# Apply sepia effect to the audio
|
| 62 |
+
waveform = torch.from_numpy(input_audio[1]).float() # Convert to tensor
|
| 63 |
+
output = model(waveform.unsqueeze(0))
|
| 64 |
+
output = output.detach().cpu().numpy()
|
| 65 |
+
output = np.argmax(output, axis=1)
|
| 66 |
+
return int(label_mapping[str(output.item())])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
demo = gr.Interface(apply_sepia, gr.Audio(max_length=4,), "number",title="NSynth Pitch Classification",)
|
| 71 |
+
demo.launch()
|
config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
RESOURCES_FOLDER = "resources"
|
| 2 |
+
GITHUB_RELEASE_URL = "https://github.com/fschmid56/PretrainedSED/releases/download/v0.0.1/"
|
| 3 |
+
|
| 4 |
+
# checkpoints
|
| 5 |
+
CHECKPOINT_URLS = {}
|
| 6 |
+
|
| 7 |
+
# strong
|
| 8 |
+
CHECKPOINT_URLS['BEATs_strong_1'] = GITHUB_RELEASE_URL + "BEATs_strong_1.pt"
|
| 9 |
+
CHECKPOINT_URLS['ATST-F_strong_1'] = GITHUB_RELEASE_URL + "ATST-F_strong_1.pt"
|
| 10 |
+
CHECKPOINT_URLS['ASIT_strong_1'] = GITHUB_RELEASE_URL + "ASIT_strong_1.pt"
|
| 11 |
+
CHECKPOINT_URLS['fpasst_strong_1'] = GITHUB_RELEASE_URL + "fpasst_strong_1.pt"
|
| 12 |
+
CHECKPOINT_URLS['M2D_strong_1'] = GITHUB_RELEASE_URL + "M2D_strong_1.pt"
|
| 13 |
+
for width in ['06', '10']:
|
| 14 |
+
CHECKPOINT_URLS[f'frame_mn{width}_strong_1'] = GITHUB_RELEASE_URL + f'frame_mn{width}_strong_1.pt'
|
| 15 |
+
|
| 16 |
+
# weak
|
| 17 |
+
CHECKPOINT_URLS['BEATs_weak'] = GITHUB_RELEASE_URL + "BEATs_weak.pt"
|
| 18 |
+
CHECKPOINT_URLS['ATST-F_weak'] = GITHUB_RELEASE_URL + "ATST-F_weak.pt"
|
| 19 |
+
CHECKPOINT_URLS['ASIT_weak'] = GITHUB_RELEASE_URL + "ASIT_weak.pt"
|
| 20 |
+
CHECKPOINT_URLS['fpasst_weak'] = GITHUB_RELEASE_URL + "fpasst_weak.pt"
|
| 21 |
+
CHECKPOINT_URLS['M2D_weak'] = GITHUB_RELEASE_URL + "M2D_weak.pt"
|
| 22 |
+
|
| 23 |
+
# ssl
|
| 24 |
+
CHECKPOINT_URLS['BEATs_ssl'] = GITHUB_RELEASE_URL + "BEATs_ssl.pt"
|
| 25 |
+
CHECKPOINT_URLS['ATST-F_ssl'] = GITHUB_RELEASE_URL + "ATST-F_ssl.pt"
|
| 26 |
+
CHECKPOINT_URLS['ASIT_ssl'] = GITHUB_RELEASE_URL + "ASIT_ssl.pt"
|
| 27 |
+
CHECKPOINT_URLS['fpasst_ssl'] = GITHUB_RELEASE_URL + "fpasst_ssl.pt"
|
| 28 |
+
CHECKPOINT_URLS['M2D_ssl'] = GITHUB_RELEASE_URL + "M2D_ssl.pt"
|
data_util/audioset_classes.py
ADDED
|
@@ -0,0 +1,1393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
as_strong_train_classes = ['Accelerating, revving, vroom',
|
| 2 |
+
'Air brake',
|
| 3 |
+
'Air conditioning',
|
| 4 |
+
'Air horn, truck horn',
|
| 5 |
+
'Aircraft',
|
| 6 |
+
'Aircraft engine',
|
| 7 |
+
'Alarm',
|
| 8 |
+
'Alarm clock',
|
| 9 |
+
'Alert',
|
| 10 |
+
'Ambulance (siren)',
|
| 11 |
+
'Animal',
|
| 12 |
+
'Applause',
|
| 13 |
+
'Arrow',
|
| 14 |
+
'Artillery fire',
|
| 15 |
+
'Audio logo',
|
| 16 |
+
'Babbling',
|
| 17 |
+
'Baby cry, infant cry',
|
| 18 |
+
'Baby laughter',
|
| 19 |
+
'Background noise',
|
| 20 |
+
'Bang',
|
| 21 |
+
'Bark',
|
| 22 |
+
'Basketball bounce',
|
| 23 |
+
'Bathroom sounds',
|
| 24 |
+
'Bathtub (filling or washing)',
|
| 25 |
+
'Battle cry',
|
| 26 |
+
'Bee, wasp, etc.',
|
| 27 |
+
'Beep, bleep',
|
| 28 |
+
'Bell',
|
| 29 |
+
'Bellow',
|
| 30 |
+
'Belly laugh',
|
| 31 |
+
'Bicycle bell',
|
| 32 |
+
'Bicycle, tricycle',
|
| 33 |
+
'Bird',
|
| 34 |
+
'Bird flight, flapping wings',
|
| 35 |
+
'Bird vocalization, bird call, bird song',
|
| 36 |
+
'Biting',
|
| 37 |
+
'Bleat',
|
| 38 |
+
'Blender, food processor',
|
| 39 |
+
'Boat, Water vehicle',
|
| 40 |
+
'Boiling',
|
| 41 |
+
'Boing',
|
| 42 |
+
'Booing',
|
| 43 |
+
'Boom',
|
| 44 |
+
'Bouncing',
|
| 45 |
+
'Bow-wow',
|
| 46 |
+
'Breaking',
|
| 47 |
+
'Breathing',
|
| 48 |
+
'Brief tone',
|
| 49 |
+
'Burping, eructation',
|
| 50 |
+
'Burst, pop',
|
| 51 |
+
'Bus',
|
| 52 |
+
'Busy signal',
|
| 53 |
+
'Buzz',
|
| 54 |
+
'Buzzer',
|
| 55 |
+
'Cacophony',
|
| 56 |
+
'Camera',
|
| 57 |
+
'Canidae, wild dogs, wolves',
|
| 58 |
+
'Cap gun',
|
| 59 |
+
'Car',
|
| 60 |
+
'Car alarm',
|
| 61 |
+
'Car passing by',
|
| 62 |
+
'Carbon monoxide detector, CO detector',
|
| 63 |
+
'Cart',
|
| 64 |
+
'Cash register',
|
| 65 |
+
'Cat',
|
| 66 |
+
'Caterwaul',
|
| 67 |
+
'Cattle, bovinae',
|
| 68 |
+
'Caw',
|
| 69 |
+
'Cellphone buzz, vibrating alert',
|
| 70 |
+
'Chain',
|
| 71 |
+
'Chainsaw',
|
| 72 |
+
'Change ringing (campanology)',
|
| 73 |
+
'Channel, environment and background',
|
| 74 |
+
'Chant',
|
| 75 |
+
'Cheering',
|
| 76 |
+
'Chewing, mastication',
|
| 77 |
+
'Chicken, rooster',
|
| 78 |
+
'Child singing',
|
| 79 |
+
'Child speech, kid speaking',
|
| 80 |
+
'Children playing',
|
| 81 |
+
'Children shouting',
|
| 82 |
+
'Chime',
|
| 83 |
+
'Chipmunk',
|
| 84 |
+
'Chirp tone',
|
| 85 |
+
'Chirp, tweet',
|
| 86 |
+
'Choir',
|
| 87 |
+
'Chop',
|
| 88 |
+
'Chopping (food)',
|
| 89 |
+
'Chorus effect',
|
| 90 |
+
'Chuckle, chortle',
|
| 91 |
+
'Church bell',
|
| 92 |
+
'Civil defense siren',
|
| 93 |
+
'Clang',
|
| 94 |
+
'Clapping',
|
| 95 |
+
'Clatter',
|
| 96 |
+
'Clickety-clack',
|
| 97 |
+
'Clicking',
|
| 98 |
+
'Clip-clop',
|
| 99 |
+
'Clock',
|
| 100 |
+
'Cluck',
|
| 101 |
+
'Clunk',
|
| 102 |
+
'Coin (dropping)',
|
| 103 |
+
'Computer keyboard',
|
| 104 |
+
'Conversation',
|
| 105 |
+
'Coo',
|
| 106 |
+
'Cough',
|
| 107 |
+
'Cowbell',
|
| 108 |
+
'Crack',
|
| 109 |
+
'Crackle',
|
| 110 |
+
'Creak',
|
| 111 |
+
'Cricket',
|
| 112 |
+
'Croak',
|
| 113 |
+
'Crockery breaking and smashing',
|
| 114 |
+
'Crow',
|
| 115 |
+
'Crowd',
|
| 116 |
+
'Crowing, cock-a-doodle-doo',
|
| 117 |
+
'Crumpling, crinkling',
|
| 118 |
+
'Crunch',
|
| 119 |
+
'Crushing',
|
| 120 |
+
'Crying, sobbing',
|
| 121 |
+
'Cupboard open or close',
|
| 122 |
+
'Cutlery, silverware',
|
| 123 |
+
'Deformable shell',
|
| 124 |
+
"Dental drill, dentist's drill",
|
| 125 |
+
'Dial tone',
|
| 126 |
+
'Digestive',
|
| 127 |
+
'Ding',
|
| 128 |
+
'Ding-dong',
|
| 129 |
+
'Dishes, pots, and pans',
|
| 130 |
+
'Distortion',
|
| 131 |
+
'Dog',
|
| 132 |
+
'Domestic animals, pets',
|
| 133 |
+
'Dong, bong',
|
| 134 |
+
'Donkey, ass',
|
| 135 |
+
'Door',
|
| 136 |
+
'Doorbell',
|
| 137 |
+
'Drawer open or close',
|
| 138 |
+
'Drill',
|
| 139 |
+
'Drip',
|
| 140 |
+
'Duck call (hunting tool)',
|
| 141 |
+
'Ducks, geese, waterfowl',
|
| 142 |
+
'Echo',
|
| 143 |
+
'Effects unit',
|
| 144 |
+
'Electric rotor drone, quadcopter',
|
| 145 |
+
'Electric shaver, electric razor',
|
| 146 |
+
'Electric toothbrush',
|
| 147 |
+
'Electronic tuner',
|
| 148 |
+
'Emergency vehicle',
|
| 149 |
+
'Engine',
|
| 150 |
+
'Engine knocking',
|
| 151 |
+
'Engine starting',
|
| 152 |
+
'Environmental noise',
|
| 153 |
+
'Error signal',
|
| 154 |
+
'Eruption',
|
| 155 |
+
'Explosion',
|
| 156 |
+
'Fart',
|
| 157 |
+
'Female singing',
|
| 158 |
+
'Female speech, woman speaking',
|
| 159 |
+
'Filing (rasp)',
|
| 160 |
+
'Fill (with liquid)',
|
| 161 |
+
'Finger snapping',
|
| 162 |
+
'Fire',
|
| 163 |
+
'Fire alarm',
|
| 164 |
+
'Fire engine, fire truck (siren)',
|
| 165 |
+
'Firecracker',
|
| 166 |
+
'Fireworks',
|
| 167 |
+
'Fixed-wing aircraft, airplane',
|
| 168 |
+
'Fizz',
|
| 169 |
+
'Flap',
|
| 170 |
+
'Fly, housefly',
|
| 171 |
+
'Foghorn',
|
| 172 |
+
'Fowl',
|
| 173 |
+
'Frog',
|
| 174 |
+
'Frying (food)',
|
| 175 |
+
'Fusillade',
|
| 176 |
+
'Gargling',
|
| 177 |
+
'Gasp',
|
| 178 |
+
'Gears',
|
| 179 |
+
'Generic impact sounds',
|
| 180 |
+
'Giggle',
|
| 181 |
+
'Glass',
|
| 182 |
+
'Glass chink, clink',
|
| 183 |
+
'Glass shatter',
|
| 184 |
+
'Goat',
|
| 185 |
+
'Gobble',
|
| 186 |
+
'Grind',
|
| 187 |
+
'Groan',
|
| 188 |
+
'Growling',
|
| 189 |
+
'Grunt',
|
| 190 |
+
'Gull, seagull',
|
| 191 |
+
'Gunshot, gunfire',
|
| 192 |
+
'Gurgling, bubbling',
|
| 193 |
+
'Gush',
|
| 194 |
+
'Hair dryer',
|
| 195 |
+
'Hammer',
|
| 196 |
+
'Hands',
|
| 197 |
+
'Heart sounds, heartbeat',
|
| 198 |
+
'Heavy engine (low frequency)',
|
| 199 |
+
'Helicopter',
|
| 200 |
+
'Hiccup',
|
| 201 |
+
'Hiss',
|
| 202 |
+
'Honk',
|
| 203 |
+
'Hoot',
|
| 204 |
+
'Horse',
|
| 205 |
+
'Howl',
|
| 206 |
+
'Howl (wind)',
|
| 207 |
+
'Hubbub, speech noise, speech babble',
|
| 208 |
+
'Hum',
|
| 209 |
+
'Human group actions',
|
| 210 |
+
'Human locomotion',
|
| 211 |
+
'Human sounds',
|
| 212 |
+
'Human voice',
|
| 213 |
+
'Humming',
|
| 214 |
+
'Ice cream truck, ice cream van',
|
| 215 |
+
'Idling',
|
| 216 |
+
'Insect',
|
| 217 |
+
'Inside, large room or hall',
|
| 218 |
+
'Inside, public space',
|
| 219 |
+
'Inside, small room',
|
| 220 |
+
'Jackhammer',
|
| 221 |
+
'Jet engine',
|
| 222 |
+
'Jingle bell',
|
| 223 |
+
'Jingle, tinkle',
|
| 224 |
+
'Kettle whistle',
|
| 225 |
+
'Keypress tone',
|
| 226 |
+
'Keys jangling',
|
| 227 |
+
'Kitchen and dining room sounds',
|
| 228 |
+
'Knife',
|
| 229 |
+
'Knock',
|
| 230 |
+
'Laughter',
|
| 231 |
+
'Lawn mower',
|
| 232 |
+
'Light engine (high frequency)',
|
| 233 |
+
'Liquid',
|
| 234 |
+
'Livestock, farm animals, working animals',
|
| 235 |
+
'Lock',
|
| 236 |
+
'Machine gun',
|
| 237 |
+
'Mains hum',
|
| 238 |
+
'Male singing',
|
| 239 |
+
'Male speech, man speaking',
|
| 240 |
+
'Mantra',
|
| 241 |
+
'Mechanical bell',
|
| 242 |
+
'Mechanical fan',
|
| 243 |
+
'Mechanisms',
|
| 244 |
+
'Medium engine (mid frequency)',
|
| 245 |
+
'Meow',
|
| 246 |
+
'Microphone',
|
| 247 |
+
'Microwave oven',
|
| 248 |
+
'Moo',
|
| 249 |
+
'Mosquito',
|
| 250 |
+
'Motor vehicle (road)',
|
| 251 |
+
'Motorboat, speedboat',
|
| 252 |
+
'Motorcycle',
|
| 253 |
+
'Mouse',
|
| 254 |
+
'Music',
|
| 255 |
+
'Narration, monologue',
|
| 256 |
+
'Neigh, whinny',
|
| 257 |
+
'Noise',
|
| 258 |
+
'Non-motorized land vehicle',
|
| 259 |
+
'Ocean',
|
| 260 |
+
'Oink',
|
| 261 |
+
'Other sourceless',
|
| 262 |
+
'Outside, urban or manmade',
|
| 263 |
+
'Owl',
|
| 264 |
+
'Packing tape, duct tape',
|
| 265 |
+
'Pant',
|
| 266 |
+
'Pant (dog)',
|
| 267 |
+
'Paper rustling',
|
| 268 |
+
'Patter',
|
| 269 |
+
'Pig',
|
| 270 |
+
'Pigeon, dove',
|
| 271 |
+
'Ping',
|
| 272 |
+
'Plop',
|
| 273 |
+
'Police car (siren)',
|
| 274 |
+
'Pour',
|
| 275 |
+
'Power saw, circular saw, table saw',
|
| 276 |
+
'Power tool',
|
| 277 |
+
'Power windows, electric windows',
|
| 278 |
+
'Printer',
|
| 279 |
+
'Propeller, airscrew',
|
| 280 |
+
'Puff',
|
| 281 |
+
'Pulleys',
|
| 282 |
+
'Pulse',
|
| 283 |
+
'Pump (liquid)',
|
| 284 |
+
'Purr',
|
| 285 |
+
'Quack',
|
| 286 |
+
'Race car, auto racing',
|
| 287 |
+
'Radio',
|
| 288 |
+
'Rail transport',
|
| 289 |
+
'Railroad car, train wagon',
|
| 290 |
+
'Rain',
|
| 291 |
+
'Rain on surface',
|
| 292 |
+
'Raindrop',
|
| 293 |
+
'Rapping',
|
| 294 |
+
'Ratchet, pawl',
|
| 295 |
+
'Rattle',
|
| 296 |
+
'Refrigerator',
|
| 297 |
+
'Respiratory sounds',
|
| 298 |
+
'Reverberation',
|
| 299 |
+
'Reversing beeps',
|
| 300 |
+
'Ringing tone, ringback tone',
|
| 301 |
+
'Ringtone',
|
| 302 |
+
'Roar',
|
| 303 |
+
'Roaring cats (lions, tigers)',
|
| 304 |
+
'Rodents, rats, mice',
|
| 305 |
+
'Roll',
|
| 306 |
+
'Rowboat, canoe, kayak',
|
| 307 |
+
'Rub',
|
| 308 |
+
'Rumble',
|
| 309 |
+
'Run',
|
| 310 |
+
'Rustle',
|
| 311 |
+
'Sailboat, sailing ship',
|
| 312 |
+
'Sanding',
|
| 313 |
+
'Sawing',
|
| 314 |
+
'Scissors',
|
| 315 |
+
'Scrape',
|
| 316 |
+
'Scratch',
|
| 317 |
+
'Screaming',
|
| 318 |
+
'Screech',
|
| 319 |
+
'Sewing machine',
|
| 320 |
+
'Sheep',
|
| 321 |
+
'Ship',
|
| 322 |
+
'Shout',
|
| 323 |
+
'Shower',
|
| 324 |
+
'Shuffle',
|
| 325 |
+
'Shuffling cards',
|
| 326 |
+
'Sigh',
|
| 327 |
+
'Sine wave',
|
| 328 |
+
'Singing',
|
| 329 |
+
'Single-lens reflex camera',
|
| 330 |
+
'Sink (filling or washing)',
|
| 331 |
+
'Siren',
|
| 332 |
+
'Sizzle',
|
| 333 |
+
'Skateboard',
|
| 334 |
+
'Slam',
|
| 335 |
+
'Slap, smack',
|
| 336 |
+
'Sliding door',
|
| 337 |
+
'Slosh',
|
| 338 |
+
'Slurp, drinking straw',
|
| 339 |
+
'Smash, crash',
|
| 340 |
+
'Smoke detector, smoke alarm',
|
| 341 |
+
'Snake',
|
| 342 |
+
'Snap',
|
| 343 |
+
'Sneeze',
|
| 344 |
+
'Snicker',
|
| 345 |
+
'Sniff',
|
| 346 |
+
'Snoring',
|
| 347 |
+
'Snort',
|
| 348 |
+
'Snort (horse)',
|
| 349 |
+
'Sonar',
|
| 350 |
+
'Sonic boom',
|
| 351 |
+
'Sound effect',
|
| 352 |
+
'Sound equipment',
|
| 353 |
+
'Sound reproduction',
|
| 354 |
+
'Speech',
|
| 355 |
+
'Speech synthesizer',
|
| 356 |
+
'Splash, splatter',
|
| 357 |
+
'Splinter',
|
| 358 |
+
'Spray',
|
| 359 |
+
'Squawk',
|
| 360 |
+
'Squeak',
|
| 361 |
+
'Squeal',
|
| 362 |
+
'Squish',
|
| 363 |
+
'Stairs',
|
| 364 |
+
'Static',
|
| 365 |
+
'Steam',
|
| 366 |
+
'Steam whistle',
|
| 367 |
+
'Stir',
|
| 368 |
+
'Stomach rumble',
|
| 369 |
+
'Stomp, stamp',
|
| 370 |
+
'Stream, river',
|
| 371 |
+
'Subway, metro, underground',
|
| 372 |
+
'Surface contact',
|
| 373 |
+
'Sweeping',
|
| 374 |
+
'Synthetic singing',
|
| 375 |
+
'Tap',
|
| 376 |
+
'Tap dance',
|
| 377 |
+
'Tape hiss',
|
| 378 |
+
'Tearing',
|
| 379 |
+
'Telephone',
|
| 380 |
+
'Telephone bell ringing',
|
| 381 |
+
'Telephone dialing, DTMF',
|
| 382 |
+
'Television',
|
| 383 |
+
'Throat clearing',
|
| 384 |
+
'Thump, thud',
|
| 385 |
+
'Thunder',
|
| 386 |
+
'Thunderstorm',
|
| 387 |
+
'Thunk',
|
| 388 |
+
'Tick',
|
| 389 |
+
'Tick-tock',
|
| 390 |
+
'Tire squeal, skidding',
|
| 391 |
+
'Toilet flush',
|
| 392 |
+
'Tools',
|
| 393 |
+
'Toothbrush',
|
| 394 |
+
'Traffic noise, roadway noise',
|
| 395 |
+
'Train',
|
| 396 |
+
'Train horn',
|
| 397 |
+
'Train wheels squealing',
|
| 398 |
+
'Train whistle',
|
| 399 |
+
'Trickle, dribble',
|
| 400 |
+
'Truck',
|
| 401 |
+
'Tuning fork',
|
| 402 |
+
'Turkey',
|
| 403 |
+
'Typewriter',
|
| 404 |
+
'Typing',
|
| 405 |
+
'Unknown sound',
|
| 406 |
+
'Vacuum cleaner',
|
| 407 |
+
'Vehicle',
|
| 408 |
+
'Vehicle horn, car horn, honking, toot',
|
| 409 |
+
'Velcro, hook and loop fastener',
|
| 410 |
+
'Video game sound',
|
| 411 |
+
'Wail, moan',
|
| 412 |
+
'Walk, footsteps',
|
| 413 |
+
'Washing machine',
|
| 414 |
+
'Water',
|
| 415 |
+
'Water tap, faucet',
|
| 416 |
+
'Waterfall',
|
| 417 |
+
'Waves, surf',
|
| 418 |
+
'Whack, thwack',
|
| 419 |
+
'Whale vocalization',
|
| 420 |
+
'Wheeze',
|
| 421 |
+
'Whimper',
|
| 422 |
+
'Whimper (dog)',
|
| 423 |
+
'Whip',
|
| 424 |
+
'Whir',
|
| 425 |
+
'Whispering',
|
| 426 |
+
'Whistle',
|
| 427 |
+
'Whistling',
|
| 428 |
+
'White noise, pink noise',
|
| 429 |
+
'Whoop',
|
| 430 |
+
'Whoosh, swoosh, swish',
|
| 431 |
+
'Wild animals',
|
| 432 |
+
'Wildfire',
|
| 433 |
+
'Wind',
|
| 434 |
+
'Wind chime',
|
| 435 |
+
'Wind noise (microphone)',
|
| 436 |
+
'Windscreen wiper, windshield wiper',
|
| 437 |
+
'Wobble',
|
| 438 |
+
'Wolf-whistling',
|
| 439 |
+
'Wood',
|
| 440 |
+
'Writing',
|
| 441 |
+
'Yak',
|
| 442 |
+
'Yawn',
|
| 443 |
+
'Yell',
|
| 444 |
+
'Yip',
|
| 445 |
+
'Yodeling',
|
| 446 |
+
'Zing',
|
| 447 |
+
'Zipper (clothing)']
|
| 448 |
+
|
| 449 |
+
as_strong_eval_classes = ['Accelerating, revving, vroom',
|
| 450 |
+
'Air brake',
|
| 451 |
+
'Air conditioning',
|
| 452 |
+
'Air horn, truck horn',
|
| 453 |
+
'Aircraft',
|
| 454 |
+
'Aircraft engine',
|
| 455 |
+
'Alarm',
|
| 456 |
+
'Alarm clock',
|
| 457 |
+
'Ambulance (siren)',
|
| 458 |
+
'Animal',
|
| 459 |
+
'Applause',
|
| 460 |
+
'Arrow',
|
| 461 |
+
'Artillery fire',
|
| 462 |
+
'Audio logo',
|
| 463 |
+
'Babbling',
|
| 464 |
+
'Baby cry, infant cry',
|
| 465 |
+
'Baby laughter',
|
| 466 |
+
'Background noise',
|
| 467 |
+
'Bang',
|
| 468 |
+
'Bark',
|
| 469 |
+
'Basketball bounce',
|
| 470 |
+
'Bathtub (filling or washing)',
|
| 471 |
+
'Battle cry',
|
| 472 |
+
'Bee, wasp, etc.',
|
| 473 |
+
'Beep, bleep',
|
| 474 |
+
'Bell',
|
| 475 |
+
'Bellow',
|
| 476 |
+
'Belly laugh',
|
| 477 |
+
'Bicycle bell',
|
| 478 |
+
'Bicycle, tricycle',
|
| 479 |
+
'Bird',
|
| 480 |
+
'Bird flight, flapping wings',
|
| 481 |
+
'Bird vocalization, bird call, bird song',
|
| 482 |
+
'Biting',
|
| 483 |
+
'Bleat',
|
| 484 |
+
'Blender, food processor',
|
| 485 |
+
'Boat, Water vehicle',
|
| 486 |
+
'Boiling',
|
| 487 |
+
'Boing',
|
| 488 |
+
'Boom',
|
| 489 |
+
'Bouncing',
|
| 490 |
+
'Bow-wow',
|
| 491 |
+
'Breaking',
|
| 492 |
+
'Breathing',
|
| 493 |
+
'Brief tone',
|
| 494 |
+
'Burping, eructation',
|
| 495 |
+
'Burst, pop',
|
| 496 |
+
'Bus',
|
| 497 |
+
'Busy signal',
|
| 498 |
+
'Buzz',
|
| 499 |
+
'Buzzer',
|
| 500 |
+
'Cacophony',
|
| 501 |
+
'Camera',
|
| 502 |
+
'Canidae, wild dogs, wolves',
|
| 503 |
+
'Cap gun',
|
| 504 |
+
'Car',
|
| 505 |
+
'Car alarm',
|
| 506 |
+
'Car passing by',
|
| 507 |
+
'Cart',
|
| 508 |
+
'Cash register',
|
| 509 |
+
'Cat',
|
| 510 |
+
'Caterwaul',
|
| 511 |
+
'Cattle, bovinae',
|
| 512 |
+
'Caw',
|
| 513 |
+
'Cellphone buzz, vibrating alert',
|
| 514 |
+
'Chainsaw',
|
| 515 |
+
'Change ringing (campanology)',
|
| 516 |
+
'Chant',
|
| 517 |
+
'Cheering',
|
| 518 |
+
'Chewing, mastication',
|
| 519 |
+
'Chicken, rooster',
|
| 520 |
+
'Child singing',
|
| 521 |
+
'Child speech, kid speaking',
|
| 522 |
+
'Children playing',
|
| 523 |
+
'Children shouting',
|
| 524 |
+
'Chime',
|
| 525 |
+
'Chipmunk',
|
| 526 |
+
'Chirp tone',
|
| 527 |
+
'Chirp, tweet',
|
| 528 |
+
'Choir',
|
| 529 |
+
'Chop',
|
| 530 |
+
'Chopping (food)',
|
| 531 |
+
'Chorus effect',
|
| 532 |
+
'Chuckle, chortle',
|
| 533 |
+
'Church bell',
|
| 534 |
+
'Civil defense siren',
|
| 535 |
+
'Clang',
|
| 536 |
+
'Clapping',
|
| 537 |
+
'Clatter',
|
| 538 |
+
'Clickety-clack',
|
| 539 |
+
'Clicking',
|
| 540 |
+
'Clip-clop',
|
| 541 |
+
'Clock',
|
| 542 |
+
'Cluck',
|
| 543 |
+
'Coin (dropping)',
|
| 544 |
+
'Computer keyboard',
|
| 545 |
+
'Conversation',
|
| 546 |
+
'Coo',
|
| 547 |
+
'Cough',
|
| 548 |
+
'Cowbell',
|
| 549 |
+
'Crack',
|
| 550 |
+
'Crackle',
|
| 551 |
+
'Creak',
|
| 552 |
+
'Cricket',
|
| 553 |
+
'Croak',
|
| 554 |
+
'Crockery breaking and smashing',
|
| 555 |
+
'Crow',
|
| 556 |
+
'Crowd',
|
| 557 |
+
'Crowing, cock-a-doodle-doo',
|
| 558 |
+
'Crumpling, crinkling',
|
| 559 |
+
'Crunch',
|
| 560 |
+
'Crushing',
|
| 561 |
+
'Crying, sobbing',
|
| 562 |
+
'Cupboard open or close',
|
| 563 |
+
'Cutlery, silverware',
|
| 564 |
+
"Dental drill, dentist's drill",
|
| 565 |
+
'Dial tone',
|
| 566 |
+
'Ding',
|
| 567 |
+
'Ding-dong',
|
| 568 |
+
'Dishes, pots, and pans',
|
| 569 |
+
'Distortion',
|
| 570 |
+
'Dog',
|
| 571 |
+
'Domestic animals, pets',
|
| 572 |
+
'Door',
|
| 573 |
+
'Doorbell',
|
| 574 |
+
'Drawer open or close',
|
| 575 |
+
'Drill',
|
| 576 |
+
'Drip',
|
| 577 |
+
'Ducks, geese, waterfowl',
|
| 578 |
+
'Echo',
|
| 579 |
+
'Effects unit',
|
| 580 |
+
'Electric rotor drone, quadcopter',
|
| 581 |
+
'Electric shaver, electric razor',
|
| 582 |
+
'Electric toothbrush',
|
| 583 |
+
'Electronic tuner',
|
| 584 |
+
'Emergency vehicle',
|
| 585 |
+
'Engine',
|
| 586 |
+
'Engine knocking',
|
| 587 |
+
'Engine starting',
|
| 588 |
+
'Environmental noise',
|
| 589 |
+
'Eruption',
|
| 590 |
+
'Explosion',
|
| 591 |
+
'Fart',
|
| 592 |
+
'Female singing',
|
| 593 |
+
'Female speech, woman speaking',
|
| 594 |
+
'Filing (rasp)',
|
| 595 |
+
'Fill (with liquid)',
|
| 596 |
+
'Finger snapping',
|
| 597 |
+
'Fire',
|
| 598 |
+
'Fire alarm',
|
| 599 |
+
'Fire engine, fire truck (siren)',
|
| 600 |
+
'Firecracker',
|
| 601 |
+
'Fireworks',
|
| 602 |
+
'Fixed-wing aircraft, airplane',
|
| 603 |
+
'Flap',
|
| 604 |
+
'Fly, housefly',
|
| 605 |
+
'Foghorn',
|
| 606 |
+
'Fowl',
|
| 607 |
+
'Frog',
|
| 608 |
+
'Frying (food)',
|
| 609 |
+
'Fusillade',
|
| 610 |
+
'Gargling',
|
| 611 |
+
'Gasp',
|
| 612 |
+
'Gears',
|
| 613 |
+
'Generic impact sounds',
|
| 614 |
+
'Giggle',
|
| 615 |
+
'Glass',
|
| 616 |
+
'Glass chink, clink',
|
| 617 |
+
'Glass shatter',
|
| 618 |
+
'Goat',
|
| 619 |
+
'Gobble',
|
| 620 |
+
'Groan',
|
| 621 |
+
'Growling',
|
| 622 |
+
'Grunt',
|
| 623 |
+
'Gunshot, gunfire',
|
| 624 |
+
'Gurgling, bubbling',
|
| 625 |
+
'Gush',
|
| 626 |
+
'Hair dryer',
|
| 627 |
+
'Hammer',
|
| 628 |
+
'Hands',
|
| 629 |
+
'Heart murmur',
|
| 630 |
+
'Heart sounds, heartbeat',
|
| 631 |
+
'Heavy engine (low frequency)',
|
| 632 |
+
'Helicopter',
|
| 633 |
+
'Hiccup',
|
| 634 |
+
'Hiss',
|
| 635 |
+
'Honk',
|
| 636 |
+
'Hoot',
|
| 637 |
+
'Horse',
|
| 638 |
+
'Howl',
|
| 639 |
+
'Howl (wind)',
|
| 640 |
+
'Hubbub, speech noise, speech babble',
|
| 641 |
+
'Hum',
|
| 642 |
+
'Human sounds',
|
| 643 |
+
'Human voice',
|
| 644 |
+
'Humming',
|
| 645 |
+
'Ice cream truck, ice cream van',
|
| 646 |
+
'Idling',
|
| 647 |
+
'Insect',
|
| 648 |
+
'Inside, large room or hall',
|
| 649 |
+
'Inside, public space',
|
| 650 |
+
'Inside, small room',
|
| 651 |
+
'Jackhammer',
|
| 652 |
+
'Jet engine',
|
| 653 |
+
'Jingle bell',
|
| 654 |
+
'Jingle, tinkle',
|
| 655 |
+
'Keys jangling',
|
| 656 |
+
'Kitchen and dining room sounds',
|
| 657 |
+
'Knock',
|
| 658 |
+
'Laughter',
|
| 659 |
+
'Lawn mower',
|
| 660 |
+
'Light engine (high frequency)',
|
| 661 |
+
'Liquid',
|
| 662 |
+
'Livestock, farm animals, working animals',
|
| 663 |
+
'Machine gun',
|
| 664 |
+
'Mains hum',
|
| 665 |
+
'Male singing',
|
| 666 |
+
'Male speech, man speaking',
|
| 667 |
+
'Mantra',
|
| 668 |
+
'Mechanical fan',
|
| 669 |
+
'Mechanisms',
|
| 670 |
+
'Medium engine (mid frequency)',
|
| 671 |
+
'Meow',
|
| 672 |
+
'Microwave oven',
|
| 673 |
+
'Moo',
|
| 674 |
+
'Mosquito',
|
| 675 |
+
'Motor vehicle (road)',
|
| 676 |
+
'Motorboat, speedboat',
|
| 677 |
+
'Motorcycle',
|
| 678 |
+
'Mouse',
|
| 679 |
+
'Music',
|
| 680 |
+
'Narration, monologue',
|
| 681 |
+
'Neigh, whinny',
|
| 682 |
+
'Noise',
|
| 683 |
+
'Non-motorized land vehicle',
|
| 684 |
+
'Ocean',
|
| 685 |
+
'Oink',
|
| 686 |
+
'Outside, rural or natural',
|
| 687 |
+
'Outside, urban or manmade',
|
| 688 |
+
'Owl',
|
| 689 |
+
'Packing tape, duct tape',
|
| 690 |
+
'Pant',
|
| 691 |
+
'Pant (dog)',
|
| 692 |
+
'Paper rustling',
|
| 693 |
+
'Patter',
|
| 694 |
+
'Pig',
|
| 695 |
+
'Pigeon, dove',
|
| 696 |
+
'Ping',
|
| 697 |
+
'Plop',
|
| 698 |
+
'Police car (siren)',
|
| 699 |
+
'Pour',
|
| 700 |
+
'Power saw, circular saw, table saw',
|
| 701 |
+
'Power tool',
|
| 702 |
+
'Power windows, electric windows',
|
| 703 |
+
'Printer',
|
| 704 |
+
'Propeller, airscrew',
|
| 705 |
+
'Pulleys',
|
| 706 |
+
'Pulse',
|
| 707 |
+
'Pump (liquid)',
|
| 708 |
+
'Purr',
|
| 709 |
+
'Quack',
|
| 710 |
+
'Race car, auto racing',
|
| 711 |
+
'Radio',
|
| 712 |
+
'Rail transport',
|
| 713 |
+
'Railroad car, train wagon',
|
| 714 |
+
'Rain',
|
| 715 |
+
'Rain on surface',
|
| 716 |
+
'Raindrop',
|
| 717 |
+
'Rapping',
|
| 718 |
+
'Ratchet, pawl',
|
| 719 |
+
'Rattle',
|
| 720 |
+
'Respiratory sounds',
|
| 721 |
+
'Reverberation',
|
| 722 |
+
'Reversing beeps',
|
| 723 |
+
'Ringing tone, ringback tone',
|
| 724 |
+
'Ringtone',
|
| 725 |
+
'Roar',
|
| 726 |
+
'Roaring cats (lions, tigers)',
|
| 727 |
+
'Rodents, rats, mice',
|
| 728 |
+
'Roll',
|
| 729 |
+
'Rowboat, canoe, kayak',
|
| 730 |
+
'Rub',
|
| 731 |
+
'Rumble',
|
| 732 |
+
'Run',
|
| 733 |
+
'Rustle',
|
| 734 |
+
'Sailboat, sailing ship',
|
| 735 |
+
'Sanding',
|
| 736 |
+
'Sawing',
|
| 737 |
+
'Scissors',
|
| 738 |
+
'Scrape',
|
| 739 |
+
'Scratch',
|
| 740 |
+
'Screaming',
|
| 741 |
+
'Sewing machine',
|
| 742 |
+
'Sheep',
|
| 743 |
+
'Ship',
|
| 744 |
+
'Shout',
|
| 745 |
+
'Shower',
|
| 746 |
+
'Shuffle',
|
| 747 |
+
'Shuffling cards',
|
| 748 |
+
'Sigh',
|
| 749 |
+
'Silence',
|
| 750 |
+
'Sine wave',
|
| 751 |
+
'Singing',
|
| 752 |
+
'Single-lens reflex camera',
|
| 753 |
+
'Sink (filling or washing)',
|
| 754 |
+
'Siren',
|
| 755 |
+
'Sizzle',
|
| 756 |
+
'Skateboard',
|
| 757 |
+
'Slam',
|
| 758 |
+
'Slap, smack',
|
| 759 |
+
'Sliding door',
|
| 760 |
+
'Slosh',
|
| 761 |
+
'Smash, crash',
|
| 762 |
+
'Smoke detector, smoke alarm',
|
| 763 |
+
'Snake',
|
| 764 |
+
'Sneeze',
|
| 765 |
+
'Snicker',
|
| 766 |
+
'Sniff',
|
| 767 |
+
'Snoring',
|
| 768 |
+
'Snort',
|
| 769 |
+
'Snort (horse)',
|
| 770 |
+
'Sonar',
|
| 771 |
+
'Sound effect',
|
| 772 |
+
'Sound equipment',
|
| 773 |
+
'Source-ambiguous sounds',
|
| 774 |
+
'Specific impact sounds',
|
| 775 |
+
'Speech',
|
| 776 |
+
'Speech synthesizer',
|
| 777 |
+
'Splash, splatter',
|
| 778 |
+
'Splinter',
|
| 779 |
+
'Spray',
|
| 780 |
+
'Squawk',
|
| 781 |
+
'Squeak',
|
| 782 |
+
'Squeal',
|
| 783 |
+
'Squish',
|
| 784 |
+
'Stairs',
|
| 785 |
+
'Static',
|
| 786 |
+
'Steam',
|
| 787 |
+
'Steam whistle',
|
| 788 |
+
'Stir',
|
| 789 |
+
'Stomach rumble',
|
| 790 |
+
'Stomp, stamp',
|
| 791 |
+
'Stream, river',
|
| 792 |
+
'Studio recording',
|
| 793 |
+
'Subway, metro, underground',
|
| 794 |
+
'Surface contact',
|
| 795 |
+
'Synthetic singing',
|
| 796 |
+
'Tap',
|
| 797 |
+
'Tap dance',
|
| 798 |
+
'Tearing',
|
| 799 |
+
'Telephone',
|
| 800 |
+
'Telephone bell ringing',
|
| 801 |
+
'Telephone dialing, DTMF',
|
| 802 |
+
'Television',
|
| 803 |
+
'Throat clearing',
|
| 804 |
+
'Throbbing',
|
| 805 |
+
'Thump, thud',
|
| 806 |
+
'Thunder',
|
| 807 |
+
'Thunderstorm',
|
| 808 |
+
'Thunk',
|
| 809 |
+
'Tick',
|
| 810 |
+
'Tick-tock',
|
| 811 |
+
'Tire squeal, skidding',
|
| 812 |
+
'Toilet flush',
|
| 813 |
+
'Tools',
|
| 814 |
+
'Toothbrush',
|
| 815 |
+
'Traffic noise, roadway noise',
|
| 816 |
+
'Train',
|
| 817 |
+
'Train horn',
|
| 818 |
+
'Train wheels squealing',
|
| 819 |
+
'Train whistle',
|
| 820 |
+
'Trickle, dribble',
|
| 821 |
+
'Truck',
|
| 822 |
+
'Tuning fork',
|
| 823 |
+
'Turkey',
|
| 824 |
+
'Typewriter',
|
| 825 |
+
'Typing',
|
| 826 |
+
'Unknown sound',
|
| 827 |
+
'Unmodified field recording',
|
| 828 |
+
'Vacuum cleaner',
|
| 829 |
+
'Vehicle',
|
| 830 |
+
'Vehicle horn, car horn, honking, toot',
|
| 831 |
+
'Velcro, hook and loop fastener',
|
| 832 |
+
'Vibration',
|
| 833 |
+
'Video game sound',
|
| 834 |
+
'Wail, moan',
|
| 835 |
+
'Walk, footsteps',
|
| 836 |
+
'Washing machine',
|
| 837 |
+
'Water',
|
| 838 |
+
'Water tap, faucet',
|
| 839 |
+
'Waterfall',
|
| 840 |
+
'Waves, surf',
|
| 841 |
+
'Whack, thwack',
|
| 842 |
+
'Whale vocalization',
|
| 843 |
+
'Wheeze',
|
| 844 |
+
'Whimper',
|
| 845 |
+
'Whimper (dog)',
|
| 846 |
+
'Whip',
|
| 847 |
+
'Whir',
|
| 848 |
+
'Whispering',
|
| 849 |
+
'Whistle',
|
| 850 |
+
'Whistling',
|
| 851 |
+
'White noise, pink noise',
|
| 852 |
+
'Whoop',
|
| 853 |
+
'Whoosh, swoosh, swish',
|
| 854 |
+
'Wild animals',
|
| 855 |
+
'Wind',
|
| 856 |
+
'Wind chime',
|
| 857 |
+
'Wind noise (microphone)',
|
| 858 |
+
'Wood',
|
| 859 |
+
'Writing',
|
| 860 |
+
'Yawn',
|
| 861 |
+
'Yell',
|
| 862 |
+
'Yip',
|
| 863 |
+
'Yodeling',
|
| 864 |
+
'Zipper (clothing)']
|
| 865 |
+
|
| 866 |
+
as_weak_classes = ['A capella',
|
| 867 |
+
'Accelerating, revving, vroom',
|
| 868 |
+
'Accordion',
|
| 869 |
+
'Acoustic guitar',
|
| 870 |
+
'Afrobeat',
|
| 871 |
+
'Air brake',
|
| 872 |
+
'Air conditioning',
|
| 873 |
+
'Air horn, truck horn',
|
| 874 |
+
'Aircraft',
|
| 875 |
+
'Aircraft engine',
|
| 876 |
+
'Alarm',
|
| 877 |
+
'Alarm clock',
|
| 878 |
+
'Ambient music',
|
| 879 |
+
'Ambulance (siren)',
|
| 880 |
+
'Angry music',
|
| 881 |
+
'Animal',
|
| 882 |
+
'Applause',
|
| 883 |
+
'Arrow',
|
| 884 |
+
'Artillery fire',
|
| 885 |
+
'Babbling',
|
| 886 |
+
'Baby cry, infant cry',
|
| 887 |
+
'Baby laughter',
|
| 888 |
+
'Background music',
|
| 889 |
+
'Bagpipes',
|
| 890 |
+
'Bang',
|
| 891 |
+
'Banjo',
|
| 892 |
+
'Bark',
|
| 893 |
+
'Basketball bounce',
|
| 894 |
+
'Bass drum',
|
| 895 |
+
'Bass guitar',
|
| 896 |
+
'Bathtub (filling or washing)',
|
| 897 |
+
'Battle cry',
|
| 898 |
+
'Beatboxing',
|
| 899 |
+
'Bee, wasp, etc.',
|
| 900 |
+
'Beep, bleep',
|
| 901 |
+
'Bell',
|
| 902 |
+
'Bellow',
|
| 903 |
+
'Belly laugh',
|
| 904 |
+
'Bicycle',
|
| 905 |
+
'Bicycle bell',
|
| 906 |
+
'Bird',
|
| 907 |
+
'Bird flight, flapping wings',
|
| 908 |
+
'Bird vocalization, bird call, bird song',
|
| 909 |
+
'Biting',
|
| 910 |
+
'Bleat',
|
| 911 |
+
'Blender',
|
| 912 |
+
'Bluegrass',
|
| 913 |
+
'Blues',
|
| 914 |
+
'Boat, Water vehicle',
|
| 915 |
+
'Boiling',
|
| 916 |
+
'Boing',
|
| 917 |
+
'Boom',
|
| 918 |
+
'Bouncing',
|
| 919 |
+
'Bow-wow',
|
| 920 |
+
'Bowed string instrument',
|
| 921 |
+
'Brass instrument',
|
| 922 |
+
'Breaking',
|
| 923 |
+
'Breathing',
|
| 924 |
+
'Burping, eructation',
|
| 925 |
+
'Burst, pop',
|
| 926 |
+
'Bus',
|
| 927 |
+
'Busy signal',
|
| 928 |
+
'Buzz',
|
| 929 |
+
'Buzzer',
|
| 930 |
+
'Cacophony',
|
| 931 |
+
'Camera',
|
| 932 |
+
'Canidae, dogs, wolves',
|
| 933 |
+
'Cap gun',
|
| 934 |
+
'Car',
|
| 935 |
+
'Car alarm',
|
| 936 |
+
'Car passing by',
|
| 937 |
+
'Carnatic music',
|
| 938 |
+
'Cash register',
|
| 939 |
+
'Cat',
|
| 940 |
+
'Caterwaul',
|
| 941 |
+
'Cattle, bovinae',
|
| 942 |
+
'Caw',
|
| 943 |
+
'Cello',
|
| 944 |
+
'Chainsaw',
|
| 945 |
+
'Change ringing (campanology)',
|
| 946 |
+
'Chant',
|
| 947 |
+
'Chatter',
|
| 948 |
+
'Cheering',
|
| 949 |
+
'Chewing, mastication',
|
| 950 |
+
'Chicken, rooster',
|
| 951 |
+
'Child singing',
|
| 952 |
+
'Child speech, kid speaking',
|
| 953 |
+
'Children playing',
|
| 954 |
+
'Children shouting',
|
| 955 |
+
'Chime',
|
| 956 |
+
'Chink, clink',
|
| 957 |
+
'Chirp tone',
|
| 958 |
+
'Chirp, tweet',
|
| 959 |
+
'Choir',
|
| 960 |
+
'Chop',
|
| 961 |
+
'Chopping (food)',
|
| 962 |
+
'Chorus effect',
|
| 963 |
+
'Christian music',
|
| 964 |
+
'Christmas music',
|
| 965 |
+
'Chuckle, chortle',
|
| 966 |
+
'Church bell',
|
| 967 |
+
'Civil defense siren',
|
| 968 |
+
'Clang',
|
| 969 |
+
'Clapping',
|
| 970 |
+
'Clarinet',
|
| 971 |
+
'Classical music',
|
| 972 |
+
'Clatter',
|
| 973 |
+
'Clickety-clack',
|
| 974 |
+
'Clicking',
|
| 975 |
+
'Clip-clop',
|
| 976 |
+
'Clock',
|
| 977 |
+
'Cluck',
|
| 978 |
+
'Coin (dropping)',
|
| 979 |
+
'Computer keyboard',
|
| 980 |
+
'Conversation',
|
| 981 |
+
'Coo',
|
| 982 |
+
'Cough',
|
| 983 |
+
'Country',
|
| 984 |
+
'Cowbell',
|
| 985 |
+
'Crack',
|
| 986 |
+
'Crackle',
|
| 987 |
+
'Creak',
|
| 988 |
+
'Cricket',
|
| 989 |
+
'Croak',
|
| 990 |
+
'Crow',
|
| 991 |
+
'Crowd',
|
| 992 |
+
'Crowing, cock-a-doodle-doo',
|
| 993 |
+
'Crumpling, crinkling',
|
| 994 |
+
'Crunch',
|
| 995 |
+
'Crushing',
|
| 996 |
+
'Crying, sobbing',
|
| 997 |
+
'Cupboard open or close',
|
| 998 |
+
'Cutlery, silverware',
|
| 999 |
+
'Cymbal',
|
| 1000 |
+
'Dance music',
|
| 1001 |
+
"Dental drill, dentist's drill",
|
| 1002 |
+
'Dial tone',
|
| 1003 |
+
'Didgeridoo',
|
| 1004 |
+
'Ding',
|
| 1005 |
+
'Ding-dong',
|
| 1006 |
+
'Disco',
|
| 1007 |
+
'Dishes, pots, and pans',
|
| 1008 |
+
'Distortion',
|
| 1009 |
+
'Dog',
|
| 1010 |
+
'Domestic animals, pets',
|
| 1011 |
+
'Door',
|
| 1012 |
+
'Doorbell',
|
| 1013 |
+
'Double bass',
|
| 1014 |
+
'Drawer open or close',
|
| 1015 |
+
'Drill',
|
| 1016 |
+
'Drip',
|
| 1017 |
+
'Drum',
|
| 1018 |
+
'Drum and bass',
|
| 1019 |
+
'Drum kit',
|
| 1020 |
+
'Drum machine',
|
| 1021 |
+
'Drum roll',
|
| 1022 |
+
'Dubstep',
|
| 1023 |
+
'Duck',
|
| 1024 |
+
'Echo',
|
| 1025 |
+
'Effects unit',
|
| 1026 |
+
'Electric guitar',
|
| 1027 |
+
'Electric piano',
|
| 1028 |
+
'Electric shaver, electric razor',
|
| 1029 |
+
'Electric toothbrush',
|
| 1030 |
+
'Electronic dance music',
|
| 1031 |
+
'Electronic music',
|
| 1032 |
+
'Electronic organ',
|
| 1033 |
+
'Electronic tuner',
|
| 1034 |
+
'Electronica',
|
| 1035 |
+
'Emergency vehicle',
|
| 1036 |
+
'Engine',
|
| 1037 |
+
'Engine knocking',
|
| 1038 |
+
'Engine starting',
|
| 1039 |
+
'Environmental noise',
|
| 1040 |
+
'Eruption',
|
| 1041 |
+
'Exciting music',
|
| 1042 |
+
'Explosion',
|
| 1043 |
+
'Fart',
|
| 1044 |
+
'Female singing',
|
| 1045 |
+
'Female speech, woman speaking',
|
| 1046 |
+
'Field recording',
|
| 1047 |
+
'Filing (rasp)',
|
| 1048 |
+
'Fill (with liquid)',
|
| 1049 |
+
'Finger snapping',
|
| 1050 |
+
'Fire',
|
| 1051 |
+
'Fire alarm',
|
| 1052 |
+
'Fire engine, fire truck (siren)',
|
| 1053 |
+
'Firecracker',
|
| 1054 |
+
'Fireworks',
|
| 1055 |
+
'Fixed-wing aircraft, airplane',
|
| 1056 |
+
'Flamenco',
|
| 1057 |
+
'Flap',
|
| 1058 |
+
'Flute',
|
| 1059 |
+
'Fly, housefly',
|
| 1060 |
+
'Foghorn',
|
| 1061 |
+
'Folk music',
|
| 1062 |
+
'Fowl',
|
| 1063 |
+
'French horn',
|
| 1064 |
+
'Frog',
|
| 1065 |
+
'Frying (food)',
|
| 1066 |
+
'Funk',
|
| 1067 |
+
'Funny music',
|
| 1068 |
+
'Fusillade',
|
| 1069 |
+
'Gargling',
|
| 1070 |
+
'Gasp',
|
| 1071 |
+
'Gears',
|
| 1072 |
+
'Giggle',
|
| 1073 |
+
'Glass',
|
| 1074 |
+
'Glockenspiel',
|
| 1075 |
+
'Goat',
|
| 1076 |
+
'Gobble',
|
| 1077 |
+
'Gong',
|
| 1078 |
+
'Goose',
|
| 1079 |
+
'Gospel music',
|
| 1080 |
+
'Groan',
|
| 1081 |
+
'Growling',
|
| 1082 |
+
'Grunge',
|
| 1083 |
+
'Grunt',
|
| 1084 |
+
'Guitar',
|
| 1085 |
+
'Gunshot, gunfire',
|
| 1086 |
+
'Gurgling',
|
| 1087 |
+
'Gush',
|
| 1088 |
+
'Hair dryer',
|
| 1089 |
+
'Hammer',
|
| 1090 |
+
'Hammond organ',
|
| 1091 |
+
'Hands',
|
| 1092 |
+
'Happy music',
|
| 1093 |
+
'Harmonic',
|
| 1094 |
+
'Harmonica',
|
| 1095 |
+
'Harp',
|
| 1096 |
+
'Harpsichord',
|
| 1097 |
+
'Heart murmur',
|
| 1098 |
+
'Heart sounds, heartbeat',
|
| 1099 |
+
'Heavy engine (low frequency)',
|
| 1100 |
+
'Heavy metal',
|
| 1101 |
+
'Helicopter',
|
| 1102 |
+
'Hi-hat',
|
| 1103 |
+
'Hiccup',
|
| 1104 |
+
'Hip hop music',
|
| 1105 |
+
'Hiss',
|
| 1106 |
+
'Honk',
|
| 1107 |
+
'Hoot',
|
| 1108 |
+
'Horse',
|
| 1109 |
+
'House music',
|
| 1110 |
+
'Howl',
|
| 1111 |
+
'Hubbub, speech noise, speech babble',
|
| 1112 |
+
'Hum',
|
| 1113 |
+
'Humming',
|
| 1114 |
+
'Ice cream truck, ice cream van',
|
| 1115 |
+
'Idling',
|
| 1116 |
+
'Independent music',
|
| 1117 |
+
'Insect',
|
| 1118 |
+
'Inside, large room or hall',
|
| 1119 |
+
'Inside, public space',
|
| 1120 |
+
'Inside, small room',
|
| 1121 |
+
'Jackhammer',
|
| 1122 |
+
'Jazz',
|
| 1123 |
+
'Jet engine',
|
| 1124 |
+
'Jingle (music)',
|
| 1125 |
+
'Jingle bell',
|
| 1126 |
+
'Jingle, tinkle',
|
| 1127 |
+
'Keyboard (musical)',
|
| 1128 |
+
'Keys jangling',
|
| 1129 |
+
'Knock',
|
| 1130 |
+
'Laughter',
|
| 1131 |
+
'Lawn mower',
|
| 1132 |
+
'Light engine (high frequency)',
|
| 1133 |
+
'Liquid',
|
| 1134 |
+
'Livestock, farm animals, working animals',
|
| 1135 |
+
'Lullaby',
|
| 1136 |
+
'Machine gun',
|
| 1137 |
+
'Mains hum',
|
| 1138 |
+
'Male singing',
|
| 1139 |
+
'Male speech, man speaking',
|
| 1140 |
+
'Mallet percussion',
|
| 1141 |
+
'Mandolin',
|
| 1142 |
+
'Mantra',
|
| 1143 |
+
'Maraca',
|
| 1144 |
+
'Marimba, xylophone',
|
| 1145 |
+
'Mechanical fan',
|
| 1146 |
+
'Mechanisms',
|
| 1147 |
+
'Medium engine (mid frequency)',
|
| 1148 |
+
'Meow',
|
| 1149 |
+
'Microwave oven',
|
| 1150 |
+
'Middle Eastern music',
|
| 1151 |
+
'Moo',
|
| 1152 |
+
'Mosquito',
|
| 1153 |
+
'Motor vehicle (road)',
|
| 1154 |
+
'Motorboat, speedboat',
|
| 1155 |
+
'Motorcycle',
|
| 1156 |
+
'Mouse',
|
| 1157 |
+
'Music',
|
| 1158 |
+
'Music for children',
|
| 1159 |
+
'Music of Africa',
|
| 1160 |
+
'Music of Asia',
|
| 1161 |
+
'Music of Bollywood',
|
| 1162 |
+
'Music of Latin America',
|
| 1163 |
+
'Musical instrument',
|
| 1164 |
+
'Narration, monologue',
|
| 1165 |
+
'Neigh, whinny',
|
| 1166 |
+
'New-age music',
|
| 1167 |
+
'Noise',
|
| 1168 |
+
'Ocean',
|
| 1169 |
+
'Oink',
|
| 1170 |
+
'Opera',
|
| 1171 |
+
'Orchestra',
|
| 1172 |
+
'Organ',
|
| 1173 |
+
'Outside, rural or natural',
|
| 1174 |
+
'Outside, urban or manmade',
|
| 1175 |
+
'Owl',
|
| 1176 |
+
'Pant',
|
| 1177 |
+
'Patter',
|
| 1178 |
+
'Percussion',
|
| 1179 |
+
'Piano',
|
| 1180 |
+
'Pig',
|
| 1181 |
+
'Pigeon, dove',
|
| 1182 |
+
'Ping',
|
| 1183 |
+
'Pink noise',
|
| 1184 |
+
'Pizzicato',
|
| 1185 |
+
'Plop',
|
| 1186 |
+
'Plucked string instrument',
|
| 1187 |
+
'Police car (siren)',
|
| 1188 |
+
'Pop music',
|
| 1189 |
+
'Pour',
|
| 1190 |
+
'Power tool',
|
| 1191 |
+
'Power windows, electric windows',
|
| 1192 |
+
'Printer',
|
| 1193 |
+
'Progressive rock',
|
| 1194 |
+
'Propeller, airscrew',
|
| 1195 |
+
'Psychedelic rock',
|
| 1196 |
+
'Pulleys',
|
| 1197 |
+
'Pulse',
|
| 1198 |
+
'Pump (liquid)',
|
| 1199 |
+
'Punk rock',
|
| 1200 |
+
'Purr',
|
| 1201 |
+
'Quack',
|
| 1202 |
+
'Race car, auto racing',
|
| 1203 |
+
'Radio',
|
| 1204 |
+
'Rail transport',
|
| 1205 |
+
'Railroad car, train wagon',
|
| 1206 |
+
'Rain',
|
| 1207 |
+
'Rain on surface',
|
| 1208 |
+
'Raindrop',
|
| 1209 |
+
'Rapping',
|
| 1210 |
+
'Ratchet, pawl',
|
| 1211 |
+
'Rattle',
|
| 1212 |
+
'Rattle (instrument)',
|
| 1213 |
+
'Reggae',
|
| 1214 |
+
'Reverberation',
|
| 1215 |
+
'Reversing beeps',
|
| 1216 |
+
'Rhythm and blues',
|
| 1217 |
+
'Rimshot',
|
| 1218 |
+
'Ringtone',
|
| 1219 |
+
'Roar',
|
| 1220 |
+
'Roaring cats (lions, tigers)',
|
| 1221 |
+
'Rock and roll',
|
| 1222 |
+
'Rock music',
|
| 1223 |
+
'Rodents, rats, mice',
|
| 1224 |
+
'Roll',
|
| 1225 |
+
'Rowboat, canoe, kayak',
|
| 1226 |
+
'Rub',
|
| 1227 |
+
'Rumble',
|
| 1228 |
+
'Run',
|
| 1229 |
+
'Rustle',
|
| 1230 |
+
'Rustling leaves',
|
| 1231 |
+
'Sad music',
|
| 1232 |
+
'Sailboat, sailing ship',
|
| 1233 |
+
'Salsa music',
|
| 1234 |
+
'Sampler',
|
| 1235 |
+
'Sanding',
|
| 1236 |
+
'Sawing',
|
| 1237 |
+
'Saxophone',
|
| 1238 |
+
'Scary music',
|
| 1239 |
+
'Scissors',
|
| 1240 |
+
'Scrape',
|
| 1241 |
+
'Scratch',
|
| 1242 |
+
'Scratching (performance technique)',
|
| 1243 |
+
'Screaming',
|
| 1244 |
+
'Sewing machine',
|
| 1245 |
+
'Shatter',
|
| 1246 |
+
'Sheep',
|
| 1247 |
+
'Ship',
|
| 1248 |
+
'Shofar',
|
| 1249 |
+
'Shout',
|
| 1250 |
+
'Shuffle',
|
| 1251 |
+
'Shuffling cards',
|
| 1252 |
+
'Sidetone',
|
| 1253 |
+
'Sigh',
|
| 1254 |
+
'Silence',
|
| 1255 |
+
'Sine wave',
|
| 1256 |
+
'Singing',
|
| 1257 |
+
'Singing bowl',
|
| 1258 |
+
'Single-lens reflex camera',
|
| 1259 |
+
'Sink (filling or washing)',
|
| 1260 |
+
'Siren',
|
| 1261 |
+
'Sitar',
|
| 1262 |
+
'Sizzle',
|
| 1263 |
+
'Ska',
|
| 1264 |
+
'Skateboard',
|
| 1265 |
+
'Skidding',
|
| 1266 |
+
'Slam',
|
| 1267 |
+
'Slap, smack',
|
| 1268 |
+
'Sliding door',
|
| 1269 |
+
'Slosh',
|
| 1270 |
+
'Smash, crash',
|
| 1271 |
+
'Smoke detector, smoke alarm',
|
| 1272 |
+
'Snake',
|
| 1273 |
+
'Snare drum',
|
| 1274 |
+
'Sneeze',
|
| 1275 |
+
'Snicker',
|
| 1276 |
+
'Sniff',
|
| 1277 |
+
'Snoring',
|
| 1278 |
+
'Snort',
|
| 1279 |
+
'Sonar',
|
| 1280 |
+
'Song',
|
| 1281 |
+
'Soul music',
|
| 1282 |
+
'Sound effect',
|
| 1283 |
+
'Soundtrack music',
|
| 1284 |
+
'Speech',
|
| 1285 |
+
'Speech synthesizer',
|
| 1286 |
+
'Splash, splatter',
|
| 1287 |
+
'Splinter',
|
| 1288 |
+
'Spray',
|
| 1289 |
+
'Squawk',
|
| 1290 |
+
'Squeak',
|
| 1291 |
+
'Squeal',
|
| 1292 |
+
'Squish',
|
| 1293 |
+
'Static',
|
| 1294 |
+
'Steam',
|
| 1295 |
+
'Steam whistle',
|
| 1296 |
+
'Steel guitar, slide guitar',
|
| 1297 |
+
'Steelpan',
|
| 1298 |
+
'Stir',
|
| 1299 |
+
'Stomach rumble',
|
| 1300 |
+
'Stream',
|
| 1301 |
+
'String section',
|
| 1302 |
+
'Strum',
|
| 1303 |
+
'Subway, metro, underground',
|
| 1304 |
+
'Swing music',
|
| 1305 |
+
'Synthesizer',
|
| 1306 |
+
'Synthetic singing',
|
| 1307 |
+
'Tabla',
|
| 1308 |
+
'Tambourine',
|
| 1309 |
+
'Tap',
|
| 1310 |
+
'Tapping (guitar technique)',
|
| 1311 |
+
'Tearing',
|
| 1312 |
+
'Techno',
|
| 1313 |
+
'Telephone',
|
| 1314 |
+
'Telephone bell ringing',
|
| 1315 |
+
'Telephone dialing, DTMF',
|
| 1316 |
+
'Television',
|
| 1317 |
+
'Tender music',
|
| 1318 |
+
'Theme music',
|
| 1319 |
+
'Theremin',
|
| 1320 |
+
'Throat clearing',
|
| 1321 |
+
'Throbbing',
|
| 1322 |
+
'Thump, thud',
|
| 1323 |
+
'Thunder',
|
| 1324 |
+
'Thunderstorm',
|
| 1325 |
+
'Thunk',
|
| 1326 |
+
'Tick',
|
| 1327 |
+
'Tick-tock',
|
| 1328 |
+
'Timpani',
|
| 1329 |
+
'Tire squeal',
|
| 1330 |
+
'Toilet flush',
|
| 1331 |
+
'Tools',
|
| 1332 |
+
'Toot',
|
| 1333 |
+
'Toothbrush',
|
| 1334 |
+
'Traditional music',
|
| 1335 |
+
'Traffic noise, roadway noise',
|
| 1336 |
+
'Train',
|
| 1337 |
+
'Train horn',
|
| 1338 |
+
'Train wheels squealing',
|
| 1339 |
+
'Train whistle',
|
| 1340 |
+
'Trance music',
|
| 1341 |
+
'Trickle, dribble',
|
| 1342 |
+
'Trombone',
|
| 1343 |
+
'Truck',
|
| 1344 |
+
'Trumpet',
|
| 1345 |
+
'Tubular bells',
|
| 1346 |
+
'Tuning fork',
|
| 1347 |
+
'Turkey',
|
| 1348 |
+
'Typewriter',
|
| 1349 |
+
'Typing',
|
| 1350 |
+
'Ukulele',
|
| 1351 |
+
'Vacuum cleaner',
|
| 1352 |
+
'Vehicle',
|
| 1353 |
+
'Vehicle horn, car horn, honking',
|
| 1354 |
+
'Vibraphone',
|
| 1355 |
+
'Vibration',
|
| 1356 |
+
'Video game music',
|
| 1357 |
+
'Violin, fiddle',
|
| 1358 |
+
'Vocal music',
|
| 1359 |
+
'Wail, moan',
|
| 1360 |
+
'Walk, footsteps',
|
| 1361 |
+
'Water',
|
| 1362 |
+
'Water tap, faucet',
|
| 1363 |
+
'Waterfall',
|
| 1364 |
+
'Waves, surf',
|
| 1365 |
+
'Wedding music',
|
| 1366 |
+
'Whack, thwack',
|
| 1367 |
+
'Whale vocalization',
|
| 1368 |
+
'Wheeze',
|
| 1369 |
+
'Whimper',
|
| 1370 |
+
'Whimper (dog)',
|
| 1371 |
+
'Whip',
|
| 1372 |
+
'Whir',
|
| 1373 |
+
'Whispering',
|
| 1374 |
+
'Whistle',
|
| 1375 |
+
'Whistling',
|
| 1376 |
+
'White noise',
|
| 1377 |
+
'Whoop',
|
| 1378 |
+
'Whoosh, swoosh, swish',
|
| 1379 |
+
'Wild animals',
|
| 1380 |
+
'Wind',
|
| 1381 |
+
'Wind chime',
|
| 1382 |
+
'Wind instrument, woodwind instrument',
|
| 1383 |
+
'Wind noise (microphone)',
|
| 1384 |
+
'Wood',
|
| 1385 |
+
'Wood block',
|
| 1386 |
+
'Writing',
|
| 1387 |
+
'Yell',
|
| 1388 |
+
'Yip',
|
| 1389 |
+
'Yodeling',
|
| 1390 |
+
'Zing',
|
| 1391 |
+
'Zipper (clothing)',
|
| 1392 |
+
'Zither'
|
| 1393 |
+
]
|
data_util/audioset_strong.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from time import perf_counter
|
| 3 |
+
import datasets
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import (
|
| 8 |
+
Dataset as TorchDataset,
|
| 9 |
+
DistributedSampler,
|
| 10 |
+
WeightedRandomSampler,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from data_util.audioset_classes import as_strong_train_classes
|
| 14 |
+
from data_util.transforms import (
|
| 15 |
+
Mp3DecodeTransform,
|
| 16 |
+
SequentialTransform,
|
| 17 |
+
AddPseudoLabelsTransform,
|
| 18 |
+
strong_label_transform,
|
| 19 |
+
target_transform
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
logger = datasets.logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def init_hf_config(max_shard_size="2GB", verbose=True, in_mem_max=None):
|
| 26 |
+
datasets.config.MAX_SHARD_SIZE = max_shard_size
|
| 27 |
+
if verbose:
|
| 28 |
+
datasets.logging.set_verbosity_info()
|
| 29 |
+
if in_mem_max is not None:
|
| 30 |
+
datasets.config.IN_MEMORY_MAX_SIZE = in_mem_max
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_hf_local_path(path, local_datasets_path=None):
|
| 34 |
+
if local_datasets_path is None:
|
| 35 |
+
local_datasets_path = os.environ.get(
|
| 36 |
+
"HF_DATASETS_LOCAL",
|
| 37 |
+
os.path.join(os.environ.get("HF_DATASETS_CACHE"), "../local"),
|
| 38 |
+
)
|
| 39 |
+
path = os.path.join(local_datasets_path, path)
|
| 40 |
+
return path
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class catchtime:
|
| 44 |
+
# context to measure loading time: https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
|
| 45 |
+
def __init__(self, debug_print="Time", logger=logger):
|
| 46 |
+
self.debug_print = debug_print
|
| 47 |
+
self.logger = logger
|
| 48 |
+
|
| 49 |
+
def __enter__(self):
|
| 50 |
+
self.start = perf_counter()
|
| 51 |
+
return self
|
| 52 |
+
|
| 53 |
+
def __exit__(self, type, value, traceback):
|
| 54 |
+
self.time = perf_counter() - self.start
|
| 55 |
+
readout = f"{self.debug_print}: {self.time:.3f} seconds"
|
| 56 |
+
self.logger.info(readout)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def merge_overlapping_events(sample):
|
| 60 |
+
events = pd.DataFrame(sample['events'][0])
|
| 61 |
+
events = events.sort_values(by='onset')
|
| 62 |
+
sample['events'] = [None]
|
| 63 |
+
|
| 64 |
+
for l in events['event_label'].unique():
|
| 65 |
+
rows = []
|
| 66 |
+
for i, r in events.loc[events['event_label'] == l].iterrows():
|
| 67 |
+
if len(rows) == 0 or rows[-1]['offset'] < r['onset']:
|
| 68 |
+
rows.append(r)
|
| 69 |
+
else:
|
| 70 |
+
onset = min(rows[-1]['onset'], r['onset'])
|
| 71 |
+
offset = max(rows[-1]['offset'], r['offset'])
|
| 72 |
+
rows[-1]['onset'] = onset
|
| 73 |
+
rows[-1]['offset'] = offset
|
| 74 |
+
if sample["events"][0] is None:
|
| 75 |
+
sample['events'][0] = pd.DataFrame(rows)
|
| 76 |
+
else:
|
| 77 |
+
sample["events"][0] = pd.concat([sample['events'][0], pd.DataFrame(rows)])
|
| 78 |
+
return sample
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_training_dataset(
|
| 82 |
+
label_encoder,
|
| 83 |
+
audio_length=10.0,
|
| 84 |
+
sample_rate=16000,
|
| 85 |
+
wavmix_p=0.0,
|
| 86 |
+
pseudo_labels_file=None,
|
| 87 |
+
):
|
| 88 |
+
init_hf_config()
|
| 89 |
+
|
| 90 |
+
decode_transform = Mp3DecodeTransform(
|
| 91 |
+
sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
ds_list = []
|
| 95 |
+
|
| 96 |
+
with catchtime("Loading audioset_strong"):
|
| 97 |
+
as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
|
| 98 |
+
|
| 99 |
+
# label encode transformation
|
| 100 |
+
if label_encoder is not None:
|
| 101 |
+
# set list of label names to be encoded
|
| 102 |
+
label_encoder.labels = as_strong_train_classes
|
| 103 |
+
encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
|
| 104 |
+
else:
|
| 105 |
+
encode_label_fun = lambda x: x
|
| 106 |
+
|
| 107 |
+
as_transforms = [
|
| 108 |
+
decode_transform,
|
| 109 |
+
merge_overlapping_events,
|
| 110 |
+
encode_label_fun,
|
| 111 |
+
target_transform,
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
if pseudo_labels_file:
|
| 115 |
+
as_transforms.append(AddPseudoLabelsTransform(pseudo_labels_file=pseudo_labels_file).add_pseudo_label_transform)
|
| 116 |
+
|
| 117 |
+
as_ds.set_transform(SequentialTransform(as_transforms))
|
| 118 |
+
|
| 119 |
+
ds_list.append(as_ds["balanced_train"])
|
| 120 |
+
ds_list.append(as_ds["unbalanced_train"])
|
| 121 |
+
dataset = torch.utils.data.ConcatDataset(ds_list)
|
| 122 |
+
|
| 123 |
+
if wavmix_p > 0:
|
| 124 |
+
print("Using Wavmix!")
|
| 125 |
+
dataset = MixupDataset(dataset, rate=wavmix_p)
|
| 126 |
+
return dataset
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_eval_dataset(
|
| 130 |
+
label_encoder,
|
| 131 |
+
audio_length=10.0,
|
| 132 |
+
sample_rate=16000
|
| 133 |
+
):
|
| 134 |
+
init_hf_config()
|
| 135 |
+
ds_list = []
|
| 136 |
+
|
| 137 |
+
decode_transform = Mp3DecodeTransform(
|
| 138 |
+
sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
with catchtime(f"Loading audioset:"):
|
| 142 |
+
as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
|
| 143 |
+
|
| 144 |
+
# label encode transformation
|
| 145 |
+
if label_encoder is not None:
|
| 146 |
+
label_encoder.labels = as_strong_train_classes
|
| 147 |
+
encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
|
| 148 |
+
else:
|
| 149 |
+
encode_label_fun = lambda x: x
|
| 150 |
+
|
| 151 |
+
as_transforms = [
|
| 152 |
+
decode_transform,
|
| 153 |
+
merge_overlapping_events,
|
| 154 |
+
encode_label_fun,
|
| 155 |
+
target_transform
|
| 156 |
+
]
|
| 157 |
+
as_ds.set_transform(SequentialTransform(as_transforms))
|
| 158 |
+
as_ds_eval = (
|
| 159 |
+
as_ds["eval"]
|
| 160 |
+
)
|
| 161 |
+
ds_list.append(as_ds_eval)
|
| 162 |
+
dataset = torch.utils.data.ConcatDataset(ds_list)
|
| 163 |
+
return dataset
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def get_full_dataset(label_encoder, audio_length=10.0, sample_rate=16000):
|
| 167 |
+
init_hf_config()
|
| 168 |
+
|
| 169 |
+
decode_transform = Mp3DecodeTransform(
|
| 170 |
+
sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
with catchtime(f"Loading audioset:"):
|
| 174 |
+
as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
|
| 175 |
+
|
| 176 |
+
# label encode transformation
|
| 177 |
+
if label_encoder is not None:
|
| 178 |
+
label_encoder.labels = as_strong_train_classes
|
| 179 |
+
encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
|
| 180 |
+
else:
|
| 181 |
+
encode_label_fun = lambda x: x
|
| 182 |
+
|
| 183 |
+
as_transforms = [
|
| 184 |
+
decode_transform,
|
| 185 |
+
merge_overlapping_events,
|
| 186 |
+
encode_label_fun,
|
| 187 |
+
]
|
| 188 |
+
|
| 189 |
+
as_ds.set_transform(SequentialTransform(as_transforms))
|
| 190 |
+
ds_list = []
|
| 191 |
+
ds_list.append(as_ds["balanced_train"])
|
| 192 |
+
ds_list.append(as_ds["unbalanced_train"])
|
| 193 |
+
ds_list.append(as_ds["eval"])
|
| 194 |
+
|
| 195 |
+
dataset = torch.utils.data.ConcatDataset(ds_list)
|
| 196 |
+
return dataset
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def get_uniform_sample_weights(dataset):
|
| 200 |
+
"""
|
| 201 |
+
:return: float tensor of shape len(full_training_set) representing the weights of each sample.
|
| 202 |
+
"""
|
| 203 |
+
return torch.ones(len(dataset)).float()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_temporal_count_balanced_sample_weights(dataset, sample_weight_offset=30,
|
| 207 |
+
save_folder="/share/rk8/shared/as_strong"):
|
| 208 |
+
"""
|
| 209 |
+
:return: float tensor of shape len(full_training_set) representing the weights of each sample.
|
| 210 |
+
"""
|
| 211 |
+
# the order of balanced_train_hdf5, unbalanced_train_hdf5 is important.
|
| 212 |
+
# should match get_full_training_set
|
| 213 |
+
os.makedirs(save_folder, exist_ok=True)
|
| 214 |
+
save_file = os.path.join(save_folder, f"weights_temporal_count_offset_{sample_weight_offset}.pt")
|
| 215 |
+
if os.path.exists(save_file):
|
| 216 |
+
return torch.load(save_file)
|
| 217 |
+
|
| 218 |
+
from tqdm import tqdm
|
| 219 |
+
|
| 220 |
+
all_y = []
|
| 221 |
+
for sample in tqdm(dataset, desc="Calculating sample weights."):
|
| 222 |
+
all_y.append(sample["event_count"])
|
| 223 |
+
all_y = torch.from_numpy(np.stack(all_y, axis=0))
|
| 224 |
+
per_class = all_y.long().sum(0).float().reshape(1, -1) # frequencies per class
|
| 225 |
+
|
| 226 |
+
per_class = sample_weight_offset + per_class # offset low freq classes
|
| 227 |
+
if sample_weight_offset > 0:
|
| 228 |
+
print(f"Warning: sample_weight_offset={sample_weight_offset} minnow={per_class.min()}")
|
| 229 |
+
per_class_weights = 1000. / per_class
|
| 230 |
+
all_weight = all_y * per_class_weights
|
| 231 |
+
all_weight = all_weight.sum(dim=1)
|
| 232 |
+
|
| 233 |
+
torch.save(all_weight, save_file)
|
| 234 |
+
return all_weight
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class MixupDataset(TorchDataset):
|
| 238 |
+
""" Mixing Up wave forms
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
def __init__(self, dataset, beta=2, rate=0.5):
|
| 242 |
+
self.beta = beta
|
| 243 |
+
self.rate = rate
|
| 244 |
+
self.dataset = dataset
|
| 245 |
+
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
|
| 246 |
+
|
| 247 |
+
def __getitem__(self, index):
|
| 248 |
+
if torch.rand(1) < self.rate:
|
| 249 |
+
batch1 = self.dataset[index]
|
| 250 |
+
idx2 = torch.randint(len(self.dataset), (1,)).item()
|
| 251 |
+
batch2 = self.dataset[idx2]
|
| 252 |
+
x1, x2 = batch1['audio'], batch2['audio']
|
| 253 |
+
y1, y2 = batch1['strong'], batch2['strong']
|
| 254 |
+
if 'pseudo_strong' in batch1:
|
| 255 |
+
p1, p2 = batch1['pseudo_strong'], batch2['pseudo_strong']
|
| 256 |
+
l = np.random.beta(self.beta, self.beta)
|
| 257 |
+
l = max(l, 1. - l)
|
| 258 |
+
x1 = x1 - x1.mean()
|
| 259 |
+
x2 = x2 - x2.mean()
|
| 260 |
+
x = (x1 * l + x2 * (1. - l))
|
| 261 |
+
x = x - x.mean()
|
| 262 |
+
batch1['audio'] = x
|
| 263 |
+
batch1['strong'] = (y1 * l + y2 * (1. - l))
|
| 264 |
+
if 'pseudo_strong' in batch1:
|
| 265 |
+
batch1['pseudo_strong'] = (p1 * l + p2 * (1. - l))
|
| 266 |
+
return batch1
|
| 267 |
+
return self.dataset[index]
|
| 268 |
+
|
| 269 |
+
def __len__(self):
|
| 270 |
+
return len(self.dataset)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class DistributedSamplerWrapper(DistributedSampler):
|
| 274 |
+
def __init__(
|
| 275 |
+
self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True
|
| 276 |
+
):
|
| 277 |
+
super(DistributedSamplerWrapper, self).__init__(
|
| 278 |
+
dataset, num_replicas, rank, shuffle
|
| 279 |
+
)
|
| 280 |
+
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
|
| 281 |
+
self.sampler = sampler
|
| 282 |
+
|
| 283 |
+
def __iter__(self):
|
| 284 |
+
if self.sampler.generator is None:
|
| 285 |
+
self.sampler.generator = torch.Generator()
|
| 286 |
+
self.sampler.generator.manual_seed(self.seed + self.epoch)
|
| 287 |
+
indices = list(self.sampler)
|
| 288 |
+
if self.epoch < 2:
|
| 289 |
+
logger.info(
|
| 290 |
+
f"\n DistributedSamplerWrapper (rank {self.rank}) : {indices[:3]} \n\n"
|
| 291 |
+
)
|
| 292 |
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
| 293 |
+
return iter(indices)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def get_weighted_sampler(
|
| 297 |
+
samples_weights,
|
| 298 |
+
epoch_len=100_000,
|
| 299 |
+
sampler_replace=False,
|
| 300 |
+
):
|
| 301 |
+
num_nodes = int(os.environ.get("WORLD_SIZE", 1))
|
| 302 |
+
ddp = int(os.environ.get("DDP", 1))
|
| 303 |
+
num_nodes = max(ddp, num_nodes)
|
| 304 |
+
rank = int(os.environ.get("NODE_RANK", 0))
|
| 305 |
+
return DistributedSamplerWrapper(
|
| 306 |
+
sampler=WeightedRandomSampler(
|
| 307 |
+
samples_weights, num_samples=epoch_len, replacement=sampler_replace
|
| 308 |
+
),
|
| 309 |
+
dataset=range(epoch_len),
|
| 310 |
+
num_replicas=num_nodes,
|
| 311 |
+
rank=rank,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
if __name__ == "__main__":
|
| 316 |
+
from helpers.encode import ManyHotEncoder
|
| 317 |
+
|
| 318 |
+
encoder = ManyHotEncoder([], 10., 160, net_pooling=4, fs=16_000)
|
| 319 |
+
|
| 320 |
+
train_ds = get_training_dataset(
|
| 321 |
+
encoder, audio_length=10.0, sample_rate=16_000
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
valid_ds = get_eval_dataset(
|
| 325 |
+
encoder, audio_length=10.0, sample_rate=16_000
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
print("Len train dataset: ", len(train_ds))
|
| 329 |
+
print("Len valid dataset: ", len(valid_ds))
|
data_util/dcase2016task2.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, List, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
import torch
|
| 9 |
+
from intervaltree import IntervalTree
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FixCropDataset(Dataset):
|
| 14 |
+
"""
|
| 15 |
+
Read in a JSON file and return audio and audio filenames
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, data: Dict,
|
| 19 |
+
audio_dir: Path,
|
| 20 |
+
sample_rate: int,
|
| 21 |
+
label_fps: int,
|
| 22 |
+
label_to_idx: Dict,
|
| 23 |
+
nlabels: int):
|
| 24 |
+
self.clip_len = 120
|
| 25 |
+
self.target_len = 10
|
| 26 |
+
self.pieces_per_clip = self.clip_len // self.target_len
|
| 27 |
+
self.filenames = list(data.keys())
|
| 28 |
+
self.audio_dir = audio_dir
|
| 29 |
+
assert self.audio_dir.is_dir(), f"{audio_dir} is not a directory"
|
| 30 |
+
self.sample_rate = sample_rate
|
| 31 |
+
# all files are 120 seconds long, split them into 12 x 10 second pieces
|
| 32 |
+
self.pieces = []
|
| 33 |
+
self.labels = []
|
| 34 |
+
self.timestamps = []
|
| 35 |
+
for filename in self.filenames:
|
| 36 |
+
self.pieces += [(filename, i) for i in range(self.pieces_per_clip)]
|
| 37 |
+
labels = data[filename]
|
| 38 |
+
frame_len = 1000 / label_fps
|
| 39 |
+
timestamps = np.arange(label_fps * self.clip_len) * frame_len + 0.5 * frame_len
|
| 40 |
+
timestamp_labels = get_labels_for_timestamps(labels, timestamps)
|
| 41 |
+
ys = []
|
| 42 |
+
for timestamp_label in timestamp_labels:
|
| 43 |
+
timestamp_label_idxs = [label_to_idx[str(event)] for event in timestamp_label]
|
| 44 |
+
y_timestamp = label_to_binary_vector(timestamp_label_idxs, nlabels)
|
| 45 |
+
ys.append(y_timestamp)
|
| 46 |
+
ys = torch.stack(ys)
|
| 47 |
+
frames_per_clip = ys.size(0) // self.pieces_per_clip
|
| 48 |
+
self.labels += [ys[frames_per_clip * i: frames_per_clip * (i + 1)] for i in range(self.pieces_per_clip)]
|
| 49 |
+
self.timestamps += [timestamps[frames_per_clip * i: frames_per_clip * (i + 1)] for i in
|
| 50 |
+
range(self.pieces_per_clip)]
|
| 51 |
+
|
| 52 |
+
assert len(self.labels) == len(self.pieces) == len(self.filenames) * self.pieces_per_clip
|
| 53 |
+
|
| 54 |
+
def __len__(self):
|
| 55 |
+
return len(self.pieces)
|
| 56 |
+
|
| 57 |
+
def __getitem__(self, idx):
|
| 58 |
+
filename = self.pieces[idx][0]
|
| 59 |
+
piece = self.pieces[idx][1]
|
| 60 |
+
audio_path = self.audio_dir.joinpath(filename)
|
| 61 |
+
audio, sr = sf.read(str(audio_path), dtype=np.float32)
|
| 62 |
+
assert sr == self.sample_rate
|
| 63 |
+
start = self.sample_rate * piece * self.target_len
|
| 64 |
+
end = start + self.sample_rate * self.target_len
|
| 65 |
+
audio = audio[start:end]
|
| 66 |
+
return audio, self.labels[idx].transpose(0, 1), filename, self.timestamps[idx]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class RandomCropDataset(Dataset):
|
| 70 |
+
"""
|
| 71 |
+
Read in a JSON file and return audio and audio filenames
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(self, data: Dict,
|
| 75 |
+
audio_dir: Path,
|
| 76 |
+
sample_rate: int,
|
| 77 |
+
label_fps: int,
|
| 78 |
+
label_to_idx: Dict,
|
| 79 |
+
nlabels: int):
|
| 80 |
+
self.clip_len = 120
|
| 81 |
+
self.target_len = 10
|
| 82 |
+
self.pieces_per_clip = self.clip_len // self.target_len
|
| 83 |
+
self.filenames = list(data.keys())
|
| 84 |
+
self.audio_dir = audio_dir
|
| 85 |
+
assert self.audio_dir.is_dir(), f"{audio_dir} is not a directory"
|
| 86 |
+
self.sample_rate = sample_rate
|
| 87 |
+
self.label_fps = label_fps
|
| 88 |
+
# all files are 120 seconds long, randomly crop 10 seconds snippets
|
| 89 |
+
self.labels = []
|
| 90 |
+
self.timestamps = []
|
| 91 |
+
for filename in self.filenames:
|
| 92 |
+
labels = data[filename]
|
| 93 |
+
frame_len = 1000 / label_fps
|
| 94 |
+
timestamps = np.arange(label_fps * self.clip_len) * frame_len + 0.5 * frame_len
|
| 95 |
+
timestamp_labels = get_labels_for_timestamps(labels, timestamps)
|
| 96 |
+
ys = []
|
| 97 |
+
for timestamp_label in timestamp_labels:
|
| 98 |
+
timestamp_label_idxs = [label_to_idx[str(event)] for event in timestamp_label]
|
| 99 |
+
y_timestamp = label_to_binary_vector(timestamp_label_idxs, nlabels)
|
| 100 |
+
ys.append(y_timestamp)
|
| 101 |
+
ys = torch.stack(ys)
|
| 102 |
+
self.labels.append(ys)
|
| 103 |
+
self.timestamps.append(timestamps)
|
| 104 |
+
|
| 105 |
+
assert len(self.labels) == len(self.filenames)
|
| 106 |
+
|
| 107 |
+
def __len__(self):
|
| 108 |
+
return len(self.filenames) * self.clip_len // self.target_len
|
| 109 |
+
|
| 110 |
+
def __getitem__(self, idx):
|
| 111 |
+
idx = idx % len(self.filenames)
|
| 112 |
+
filename = self.filenames[idx]
|
| 113 |
+
audio_path = self.audio_dir.joinpath(filename)
|
| 114 |
+
audio, sr = sf.read(str(audio_path), dtype=np.float32)
|
| 115 |
+
assert sr == self.sample_rate
|
| 116 |
+
|
| 117 |
+
# crop random 10 seconds piece
|
| 118 |
+
labels_to_pick = self.target_len * self.label_fps
|
| 119 |
+
max_offset = len(self.labels[idx]) - labels_to_pick + 1
|
| 120 |
+
offset = torch.randint(max_offset, (1,)).item()
|
| 121 |
+
labels = self.labels[idx][offset:offset + labels_to_pick]
|
| 122 |
+
scale = self.sample_rate // self.label_fps
|
| 123 |
+
audio = audio[offset * scale:offset * scale + labels_to_pick * scale]
|
| 124 |
+
timestamps = self.timestamps[idx][offset:offset + labels_to_pick]
|
| 125 |
+
return audio, labels.transpose(0, 1), filename, timestamps
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_training_dataset(
|
| 129 |
+
task_path,
|
| 130 |
+
sample_rate=16000,
|
| 131 |
+
label_fps=25,
|
| 132 |
+
wavmix_p=0.0,
|
| 133 |
+
random_crop=True
|
| 134 |
+
):
|
| 135 |
+
task_path = Path(task_path)
|
| 136 |
+
|
| 137 |
+
label_vocab, nlabels = label_vocab_nlabels(task_path)
|
| 138 |
+
label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
|
| 139 |
+
|
| 140 |
+
train_fold = task_path.joinpath("train.json")
|
| 141 |
+
audio_dir = task_path.joinpath(str(sample_rate), "train")
|
| 142 |
+
train_fold_data = json.load(train_fold.open())
|
| 143 |
+
if random_crop:
|
| 144 |
+
dataset = RandomCropDataset(train_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
|
| 145 |
+
else:
|
| 146 |
+
dataset = FixCropDataset(train_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
|
| 147 |
+
if wavmix_p > 0:
|
| 148 |
+
dataset = MixupDataset(dataset, rate=wavmix_p)
|
| 149 |
+
return dataset
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_validation_dataset(
|
| 153 |
+
task_path,
|
| 154 |
+
sample_rate=16000,
|
| 155 |
+
label_fps=25,
|
| 156 |
+
):
|
| 157 |
+
task_path = Path(task_path)
|
| 158 |
+
|
| 159 |
+
label_vocab, nlabels = label_vocab_nlabels(task_path)
|
| 160 |
+
label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
|
| 161 |
+
|
| 162 |
+
valid_fold = task_path.joinpath("valid.json")
|
| 163 |
+
audio_dir = task_path.joinpath(str(sample_rate), "valid")
|
| 164 |
+
valid_fold_data = json.load(valid_fold.open())
|
| 165 |
+
dataset = FixCropDataset(valid_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
|
| 166 |
+
return dataset
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_test_dataset(
|
| 170 |
+
task_path,
|
| 171 |
+
sample_rate=16000,
|
| 172 |
+
label_fps=25,
|
| 173 |
+
):
|
| 174 |
+
task_path = Path(task_path)
|
| 175 |
+
|
| 176 |
+
label_vocab, nlabels = label_vocab_nlabels(task_path)
|
| 177 |
+
label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
|
| 178 |
+
|
| 179 |
+
test_fold = task_path.joinpath("test.json")
|
| 180 |
+
audio_dir = task_path.joinpath(str(sample_rate), "test")
|
| 181 |
+
test_fold_data = json.load(test_fold.open())
|
| 182 |
+
dataset = FixCropDataset(test_fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
|
| 183 |
+
return dataset
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_labels_for_timestamps(labels: List, timestamps: np.ndarray) -> List:
|
| 187 |
+
# A list of labels present at each timestamp
|
| 188 |
+
tree = IntervalTree()
|
| 189 |
+
# Add all events to the label tree
|
| 190 |
+
for event in labels:
|
| 191 |
+
# We add 0.0001 so that the end also includes the event
|
| 192 |
+
tree.addi(event["start"], event["end"] + 0.0001, event["label"])
|
| 193 |
+
|
| 194 |
+
timestamp_labels = []
|
| 195 |
+
# Update the binary vector of labels with intervals for each timestamp
|
| 196 |
+
for j, t in enumerate(timestamps):
|
| 197 |
+
interval_labels: List[str] = [interval.data for interval in tree[t]]
|
| 198 |
+
timestamp_labels.append(interval_labels)
|
| 199 |
+
# If we want to store the timestamp too
|
| 200 |
+
# labels_for_sound.append([float(t), interval_labels])
|
| 201 |
+
|
| 202 |
+
assert len(timestamp_labels) == len(timestamps)
|
| 203 |
+
return timestamp_labels
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def label_vocab_nlabels(task_path: Path) -> Tuple[pd.DataFrame, int]:
|
| 207 |
+
label_vocab = pd.read_csv(task_path.joinpath("labelvocabulary.csv"))
|
| 208 |
+
|
| 209 |
+
nlabels = len(label_vocab)
|
| 210 |
+
assert nlabels == label_vocab["idx"].max() + 1
|
| 211 |
+
return (label_vocab, nlabels)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def label_vocab_as_dict(df: pd.DataFrame, key: str, value: str) -> Dict:
|
| 215 |
+
"""
|
| 216 |
+
Returns a dictionary of the label vocabulary mapping the label column to
|
| 217 |
+
the idx column. key sets whether the label or idx is the key in the dict. The
|
| 218 |
+
other column will be the value.
|
| 219 |
+
"""
|
| 220 |
+
if key == "label":
|
| 221 |
+
# Make sure the key is a string
|
| 222 |
+
df["label"] = df["label"].astype(str)
|
| 223 |
+
value = "idx"
|
| 224 |
+
else:
|
| 225 |
+
assert key == "idx", "key argument must be either 'label' or 'idx'"
|
| 226 |
+
value = "label"
|
| 227 |
+
return df.set_index(key).to_dict()[value]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def label_to_binary_vector(label: List, num_labels: int) -> torch.Tensor:
|
| 231 |
+
"""
|
| 232 |
+
Converts a list of labels into a binary vector
|
| 233 |
+
Args:
|
| 234 |
+
label: list of integer labels
|
| 235 |
+
num_labels: total number of labels
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
A float Tensor that is multi-hot binary vector
|
| 239 |
+
"""
|
| 240 |
+
# Lame special case for multilabel with no labels
|
| 241 |
+
if len(label) == 0:
|
| 242 |
+
# BCEWithLogitsLoss wants float not long targets
|
| 243 |
+
binary_labels = torch.zeros((num_labels,), dtype=torch.float)
|
| 244 |
+
else:
|
| 245 |
+
binary_labels = torch.zeros((num_labels,)).scatter(0, torch.tensor(label), 1.0)
|
| 246 |
+
|
| 247 |
+
# Validate the binary vector we just created
|
| 248 |
+
assert set(torch.where(binary_labels == 1.0)[0].numpy()) == set(label)
|
| 249 |
+
return binary_labels
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class MixupDataset(Dataset):
|
| 253 |
+
""" Mixing Up wave forms
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(self, dataset, beta=0.2, rate=0.5):
|
| 257 |
+
self.beta = beta
|
| 258 |
+
self.rate = rate
|
| 259 |
+
self.dataset = dataset
|
| 260 |
+
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
|
| 261 |
+
|
| 262 |
+
def __getitem__(self, index):
|
| 263 |
+
if torch.rand(1) < self.rate:
|
| 264 |
+
batch1 = self.dataset[index]
|
| 265 |
+
idx2 = torch.randint(len(self.dataset), (1,)).item()
|
| 266 |
+
batch2 = self.dataset[idx2]
|
| 267 |
+
x1, x2 = batch1[0], batch2[0]
|
| 268 |
+
y1, y2 = batch1[1], batch2[1]
|
| 269 |
+
l = np.random.beta(self.beta, self.beta)
|
| 270 |
+
l = max(l, 1. - l)
|
| 271 |
+
x1 = x1 - x1.mean()
|
| 272 |
+
x2 = x2 - x2.mean()
|
| 273 |
+
x = (x1 * l + x2 * (1. - l))
|
| 274 |
+
x = x - x.mean()
|
| 275 |
+
y = (y1 * l + y2 * (1. - l))
|
| 276 |
+
return x, y, batch1[2], batch1[3]
|
| 277 |
+
return self.dataset[index]
|
| 278 |
+
|
| 279 |
+
def __len__(self):
|
| 280 |
+
return len(self.dataset)
|
data_util/transforms.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import datasets
|
| 4 |
+
import h5py
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio
|
| 9 |
+
|
| 10 |
+
from data_util.audioset_classes import as_strong_train_classes
|
| 11 |
+
|
| 12 |
+
## Transforms with a similar style to https://github.com/descriptinc/audiotools/blob/master/audiotools/data/transforms.py
|
| 13 |
+
logger = datasets.logging.get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def target_transform(sample):
|
| 17 |
+
del sample["labels"]
|
| 18 |
+
del sample["label_ids"]
|
| 19 |
+
return sample
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def strong_label_transform(sample, strong_label_encoder=None):
|
| 23 |
+
assert strong_label_encoder is not None
|
| 24 |
+
events = pd.DataFrame(sample['events'][0])
|
| 25 |
+
events = events[events['event_label'].isin(set(as_strong_train_classes))]
|
| 26 |
+
strong = strong_label_encoder.encode_strong_df(events).T
|
| 27 |
+
sample["strong"] = [strong]
|
| 28 |
+
sample["event_count"] = [strong.sum(1)]
|
| 29 |
+
# encode ground truth events as string - we will use this for evaluation
|
| 30 |
+
sample["gt_string"] = ["++".join([";;".join([str(e[0]), str(e[1]), e[2]]) for e in
|
| 31 |
+
zip(sample['events'][0]['onset'], sample['events'][0]['offset'],
|
| 32 |
+
sample['events'][0]['event_label'])])]
|
| 33 |
+
del sample['events']
|
| 34 |
+
return sample
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AddPseudoLabelsTransform:
|
| 38 |
+
def __init__(self, pseudo_labels_file):
|
| 39 |
+
self.pseudo_labels_file = pseudo_labels_file
|
| 40 |
+
|
| 41 |
+
if self.pseudo_labels_file is not None:
|
| 42 |
+
# fetch dict of positions for each example
|
| 43 |
+
self.ex2pseudo_idx = {}
|
| 44 |
+
f = h5py.File(self.pseudo_labels_file, "r")
|
| 45 |
+
for i, fname in enumerate(f["filenames"]):
|
| 46 |
+
self.ex2pseudo_idx[fname.decode("UTF-8")] = i
|
| 47 |
+
self._opened_pseudo_hdf5 = None
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def pseudo_hdf5_file(self):
|
| 51 |
+
if self._opened_pseudo_hdf5 is None:
|
| 52 |
+
self._opened_pseudo_hdf5 = h5py.File(self.pseudo_labels_file, "r")
|
| 53 |
+
return self._opened_pseudo_hdf5
|
| 54 |
+
|
| 55 |
+
def add_pseudo_label_transform(self, sample):
|
| 56 |
+
indices = [self.ex2pseudo_idx[fn.rstrip(".mp3")] for fn in sample['filename']]
|
| 57 |
+
pseudo_strong = [torch.from_numpy(np.stack(self.pseudo_hdf5_file["strong_logits"][index])).float()
|
| 58 |
+
for index in indices]
|
| 59 |
+
pseudo_strong = [torch.sigmoid(pseudo_strong[i]) for i in range(len(pseudo_strong))]
|
| 60 |
+
sample['pseudo_strong'] = pseudo_strong
|
| 61 |
+
return sample
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class SequentialTransform:
|
| 65 |
+
"""Apply a sequence of transforms to a batch."""
|
| 66 |
+
|
| 67 |
+
def __init__(self, transforms):
|
| 68 |
+
"""
|
| 69 |
+
Args:
|
| 70 |
+
transforms: list of transforms to apply
|
| 71 |
+
"""
|
| 72 |
+
self.transforms = transforms
|
| 73 |
+
|
| 74 |
+
def append(self, transform):
|
| 75 |
+
self.transforms.append(transform)
|
| 76 |
+
|
| 77 |
+
def __call__(self, batch):
|
| 78 |
+
for t in self.transforms:
|
| 79 |
+
batch = t(batch)
|
| 80 |
+
return batch
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Mp3DecodeTransform:
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
mp3_bytes_key="mp3_bytes",
|
| 87 |
+
audio_key="audio",
|
| 88 |
+
sample_rate=32000,
|
| 89 |
+
max_length=10.0,
|
| 90 |
+
min_length=None,
|
| 91 |
+
random_sample_crop=True,
|
| 92 |
+
allow_resample=True,
|
| 93 |
+
resampling_method="sinc_interp_kaiser",
|
| 94 |
+
keep_mp3_bytes=False,
|
| 95 |
+
debug_info_key=None,
|
| 96 |
+
):
|
| 97 |
+
"""Decode mp3 bytes to audio waveform
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
mp3_bytes_key (str, optional): The key to mp3 bytes in the input batch. Defaults to "mp3_bytes".
|
| 101 |
+
audio_key (str, optional): The key to save the decoded audio in the output batch. Defaults to "audio".
|
| 102 |
+
sample_rate (int, optional): The expected output audio_key. Defaults to 32000.
|
| 103 |
+
max_length (int, float, optional): the maximum output audio length in seconds if float, otherwise in samples. Defaults to 10.
|
| 104 |
+
min_length (int, optional): the minimum output audio length in seconds. Defaults to max_length.
|
| 105 |
+
random_sample_crop (bool, optional): Randomly crop the audio to max_length if its longer otherwise return the first crop. Defaults to True.
|
| 106 |
+
allow_resample (bool, optional): Resample the singal if the sampling rate don't match. Defaults to True.
|
| 107 |
+
resampling_method (str, optional): reampling method from torchaudio.transforms.Resample . Defaults to "sinc_interp_kaiser".
|
| 108 |
+
keep_mp3_bytes (bool, optional): keep the original bytes in the output dict. Defaults to False.
|
| 109 |
+
|
| 110 |
+
Raises:
|
| 111 |
+
Exception: if minimp3py is not installed
|
| 112 |
+
"""
|
| 113 |
+
self.mp3_bytes_key = mp3_bytes_key
|
| 114 |
+
self.audio_key = audio_key
|
| 115 |
+
self.sample_rate = sample_rate
|
| 116 |
+
self.max_length = max_length
|
| 117 |
+
if min_length is None:
|
| 118 |
+
min_length = max_length
|
| 119 |
+
self.min_length = min_length
|
| 120 |
+
self.random_sample_crop = random_sample_crop
|
| 121 |
+
self.allow_resample = allow_resample
|
| 122 |
+
self.resampling_method = resampling_method
|
| 123 |
+
self.keep_mp3_bytes = keep_mp3_bytes
|
| 124 |
+
self.debug_info_key = debug_info_key
|
| 125 |
+
self.resamplers_cache = {}
|
| 126 |
+
try:
|
| 127 |
+
import minimp3py # noqa: F401
|
| 128 |
+
except:
|
| 129 |
+
raise Exception(
|
| 130 |
+
"minimp3py is not installed, please install it using: `CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip`"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def __call__(self, batch):
|
| 134 |
+
import minimp3py
|
| 135 |
+
|
| 136 |
+
data_list = batch[self.mp3_bytes_key]
|
| 137 |
+
if self.debug_info_key is not None:
|
| 138 |
+
file_name_list = batch[self.debug_info_key]
|
| 139 |
+
else:
|
| 140 |
+
file_name_list = range(len(data_list))
|
| 141 |
+
audio_list = []
|
| 142 |
+
for data, file_name in zip(data_list, file_name_list):
|
| 143 |
+
try:
|
| 144 |
+
duration, ch, sr = minimp3py.probe(data)
|
| 145 |
+
if isinstance(self.max_length, float):
|
| 146 |
+
max_length = int(self.max_length * sr)
|
| 147 |
+
else:
|
| 148 |
+
max_length = int(self.max_length * sr // self.sample_rate)
|
| 149 |
+
offset = 0
|
| 150 |
+
if self.random_sample_crop and duration > max_length:
|
| 151 |
+
max_offset = max(int(duration - max_length), 0) + 1
|
| 152 |
+
offset = torch.randint(max_offset, (1,)).item()
|
| 153 |
+
waveform, _ = minimp3py.read(data, start=offset, length=max_length)
|
| 154 |
+
waveform = waveform[:, 0] # 0 for the first channel only
|
| 155 |
+
if waveform.dtype != "float32":
|
| 156 |
+
raise RuntimeError("Unexpected wave type")
|
| 157 |
+
|
| 158 |
+
waveform = torch.from_numpy(waveform)
|
| 159 |
+
if len(waveform) == 0:
|
| 160 |
+
logger.warning(
|
| 161 |
+
f"Empty waveform for {file_name}, duration {duration}, offset {offset}, max_length {max_length}, sr {sr}, ch {ch}"
|
| 162 |
+
)
|
| 163 |
+
elif sr != self.sample_rate:
|
| 164 |
+
assert self.allow_resample, f"Unexpected sample rate {sr} instead of {self.sample_rate} at {file_name}"
|
| 165 |
+
if self.resamplers_cache.get(sr) is None:
|
| 166 |
+
self.resamplers_cache[sr] = torchaudio.transforms.Resample(
|
| 167 |
+
sr,
|
| 168 |
+
self.sample_rate,
|
| 169 |
+
resampling_method=self.resampling_method,
|
| 170 |
+
)
|
| 171 |
+
waveform = self.resamplers_cache[sr](waveform)
|
| 172 |
+
min_length = self.min_length
|
| 173 |
+
if isinstance(self.min_length, float):
|
| 174 |
+
min_length = int(self.min_length * self.sample_rate)
|
| 175 |
+
if min_length is not None and len(waveform) < min_length:
|
| 176 |
+
waveform = torch.concatenate(
|
| 177 |
+
(
|
| 178 |
+
waveform,
|
| 179 |
+
torch.zeros(
|
| 180 |
+
min_length - len(waveform),
|
| 181 |
+
dtype=waveform.dtype,
|
| 182 |
+
device=waveform.device,
|
| 183 |
+
),
|
| 184 |
+
),
|
| 185 |
+
dim=0,
|
| 186 |
+
)
|
| 187 |
+
audio_list.append(waveform)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
print(f"Error decoding {file_name}: {e}")
|
| 190 |
+
raise e
|
| 191 |
+
batch[self.audio_key] = audio_list
|
| 192 |
+
batch["sampling_rate"] = [self.sample_rate] * len(audio_list)
|
| 193 |
+
if not self.keep_mp3_bytes:
|
| 194 |
+
del batch[self.mp3_bytes_key]
|
| 195 |
+
return batch
|
ex_audioset_strong.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import argparse
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import wandb
|
| 8 |
+
import transformers
|
| 9 |
+
import random
|
| 10 |
+
import pytorch_lightning as pl
|
| 11 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 12 |
+
import sed_scores_eval
|
| 13 |
+
|
| 14 |
+
from helpers.decode import batched_decode_preds
|
| 15 |
+
from helpers.encode import ManyHotEncoder
|
| 16 |
+
from models.atstframe.ATSTF_wrapper import ATSTWrapper
|
| 17 |
+
from models.beats.BEATs_wrapper import BEATsWrapper
|
| 18 |
+
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
|
| 19 |
+
from models.m2d.M2D_wrapper import M2DWrapper
|
| 20 |
+
from models.asit.ASIT_wrapper import ASiTWrapper
|
| 21 |
+
from models.prediction_wrapper import PredictionsWrapper
|
| 22 |
+
from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
|
| 23 |
+
from helpers.utils import worker_init_fn
|
| 24 |
+
from data_util.audioset_strong import get_training_dataset, get_eval_dataset
|
| 25 |
+
from data_util.audioset_strong import get_temporal_count_balanced_sample_weights, get_uniform_sample_weights, \
|
| 26 |
+
get_weighted_sampler
|
| 27 |
+
from data_util.audioset_classes import as_strong_train_classes, as_strong_eval_classes
|
| 28 |
+
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
|
| 29 |
+
from models.frame_mn.utils import NAME_TO_WIDTH
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class PLModule(pl.LightningModule):
|
| 33 |
+
def __init__(self, config, encoder):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.config = config
|
| 36 |
+
self.encoder = encoder
|
| 37 |
+
|
| 38 |
+
if config.pretrained == "scratch":
|
| 39 |
+
checkpoint = None
|
| 40 |
+
elif config.pretrained == "ssl":
|
| 41 |
+
checkpoint = "ssl"
|
| 42 |
+
elif config.pretrained == "weak":
|
| 43 |
+
checkpoint = "weak"
|
| 44 |
+
elif config.pretrained == "strong":
|
| 45 |
+
checkpoint = "strong_1"
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")
|
| 48 |
+
|
| 49 |
+
# load transformer model
|
| 50 |
+
if config.model_name == "BEATs":
|
| 51 |
+
beats = BEATsWrapper()
|
| 52 |
+
model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
|
| 53 |
+
seq_model_type=config.seq_model_type)
|
| 54 |
+
elif config.model_name == "ATST-F":
|
| 55 |
+
atst = ATSTWrapper()
|
| 56 |
+
model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
|
| 57 |
+
seq_model_type=config.seq_model_type)
|
| 58 |
+
elif config.model_name == "fpasst":
|
| 59 |
+
fpasst = FPaSSTWrapper()
|
| 60 |
+
model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
|
| 61 |
+
seq_model_type=config.seq_model_type)
|
| 62 |
+
elif config.model_name == "M2D":
|
| 63 |
+
m2d = M2DWrapper()
|
| 64 |
+
model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
|
| 65 |
+
seq_model_type=config.seq_model_type,
|
| 66 |
+
embed_dim=m2d.m2d.cfg.feature_d)
|
| 67 |
+
elif config.model_name == "ASIT":
|
| 68 |
+
asit = ASiTWrapper()
|
| 69 |
+
model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
|
| 70 |
+
seq_model_type=config.seq_model_type)
|
| 71 |
+
elif config.model_name.startswith("frame_mn"):
|
| 72 |
+
width = NAME_TO_WIDTH(config.model_name)
|
| 73 |
+
frame_mn = FrameMNWrapper(width)
|
| 74 |
+
embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
|
| 75 |
+
model = PredictionsWrapper(frame_mn, checkpoint=f"{config.model_name}_strong_1", embed_dim=embed_dim)
|
| 76 |
+
else:
|
| 77 |
+
raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
|
| 78 |
+
|
| 79 |
+
self.model = model
|
| 80 |
+
|
| 81 |
+
# prepare ingredients for knowledge distillation
|
| 82 |
+
assert 0 <= config.distillation_loss_weight <= 1, "Lambda for Knowledge Distillation must be between 0 and 1."
|
| 83 |
+
self.strong_loss = nn.BCEWithLogitsLoss()
|
| 84 |
+
|
| 85 |
+
self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))
|
| 86 |
+
|
| 87 |
+
self.val_durations_df = pd.read_csv(f"resources/eval_durations.csv",
|
| 88 |
+
sep=",", header=None, names=["filename", "duration"])
|
| 89 |
+
self.val_predictions_strong = {}
|
| 90 |
+
self.val_ground_truth = {}
|
| 91 |
+
self.val_duration = {}
|
| 92 |
+
self.val_loss = []
|
| 93 |
+
|
| 94 |
+
def forward(self, batch):
|
| 95 |
+
x = batch["audio"]
|
| 96 |
+
mel = self.model.mel_forward(x)
|
| 97 |
+
y_strong, _ = self.model(mel)
|
| 98 |
+
return y_strong
|
| 99 |
+
|
| 100 |
+
def get_optimizer(
|
| 101 |
+
self, lr, adamw=False, weight_decay=0.01, betas=(0.9, 0.999)
|
| 102 |
+
):
|
| 103 |
+
# we split the parameters into two groups, one for the pretrained model and one for the downstream model
|
| 104 |
+
# we also split each of them into <=1 dimensional and >=2 dimensional parameters, so we can only
|
| 105 |
+
# apply weight decay to the >=2 dimensional parameters, thus excluding biases and batch norms, an idea from NanoGPT
|
| 106 |
+
params_leq1D = []
|
| 107 |
+
params_geq2D = []
|
| 108 |
+
|
| 109 |
+
for name, param in self.model.named_parameters():
|
| 110 |
+
if param.requires_grad:
|
| 111 |
+
if param.ndimension() >= 2:
|
| 112 |
+
params_geq2D.append(param)
|
| 113 |
+
else:
|
| 114 |
+
params_leq1D.append(param)
|
| 115 |
+
|
| 116 |
+
param_groups = [
|
| 117 |
+
{'params': params_leq1D, 'lr': lr},
|
| 118 |
+
{'params': params_geq2D, 'lr': lr, 'weight_decay': weight_decay},
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
if weight_decay > 0:
|
| 122 |
+
assert adamw
|
| 123 |
+
assert len(param_groups) > 0
|
| 124 |
+
if adamw:
|
| 125 |
+
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
|
| 126 |
+
return torch.optim.AdamW(param_groups, lr=lr, betas=betas)
|
| 127 |
+
return torch.optim.Adam(param_groups, lr=lr, betas=betas)
|
| 128 |
+
|
| 129 |
+
def get_lr_scheduler(
|
| 130 |
+
self,
|
| 131 |
+
optimizer,
|
| 132 |
+
num_training_steps,
|
| 133 |
+
schedule_mode="cos",
|
| 134 |
+
gamma: float = 0.999996,
|
| 135 |
+
num_warmup_steps=20000,
|
| 136 |
+
lr_end=2e-7,
|
| 137 |
+
):
|
| 138 |
+
if schedule_mode in {"exp"}:
|
| 139 |
+
return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
|
| 140 |
+
if schedule_mode in {"cosine", "cos"}:
|
| 141 |
+
return transformers.get_cosine_schedule_with_warmup(
|
| 142 |
+
optimizer,
|
| 143 |
+
num_warmup_steps=num_warmup_steps,
|
| 144 |
+
num_training_steps=num_training_steps,
|
| 145 |
+
)
|
| 146 |
+
if schedule_mode in {"linear"}:
|
| 147 |
+
print("Linear schedule!")
|
| 148 |
+
return transformers.get_polynomial_decay_schedule_with_warmup(
|
| 149 |
+
optimizer,
|
| 150 |
+
num_warmup_steps=num_warmup_steps,
|
| 151 |
+
num_training_steps=num_training_steps,
|
| 152 |
+
power=1.0,
|
| 153 |
+
lr_end=lr_end,
|
| 154 |
+
)
|
| 155 |
+
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
|
| 156 |
+
|
| 157 |
+
def configure_optimizers(self):
|
| 158 |
+
"""
|
| 159 |
+
This is the way pytorch lightening requires optimizers and learning rate schedulers to be defined.
|
| 160 |
+
The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
|
| 161 |
+
:return: dict containing optimizer and learning rate scheduler
|
| 162 |
+
"""
|
| 163 |
+
optimizer = self.get_optimizer(self.config.max_lr, adamw=self.config.adamw,
|
| 164 |
+
weight_decay=self.config.weight_decay)
|
| 165 |
+
|
| 166 |
+
num_training_steps = self.trainer.estimated_stepping_batches
|
| 167 |
+
|
| 168 |
+
scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
|
| 169 |
+
schedule_mode=self.config.schedule_mode,
|
| 170 |
+
lr_end=self.config.lr_end)
|
| 171 |
+
lr_scheduler_config = {
|
| 172 |
+
"scheduler": scheduler,
|
| 173 |
+
"interval": "step",
|
| 174 |
+
"frequency": 1
|
| 175 |
+
}
|
| 176 |
+
return [optimizer], [lr_scheduler_config]
|
| 177 |
+
|
| 178 |
+
def training_step(self, train_batch, batch_idx):
|
| 179 |
+
"""
|
| 180 |
+
:param train_batch: contains one batch from train dataloader
|
| 181 |
+
:param batch_idx
|
| 182 |
+
:return: a dict containing at least loss that is used to update model parameters, can also contain
|
| 183 |
+
other items that can be processed in 'training_epoch_end' to log other metrics than loss
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
x = train_batch["audio"]
|
| 187 |
+
labels = train_batch['strong']
|
| 188 |
+
if 'pseudo_strong' in train_batch:
|
| 189 |
+
pseudo_labels = train_batch['pseudo_strong']
|
| 190 |
+
else:
|
| 191 |
+
# create dummy pseudo labels
|
| 192 |
+
pseudo_labels = torch.zeros_like(labels)
|
| 193 |
+
assert self.config.distillation_loss_weight == 0
|
| 194 |
+
|
| 195 |
+
mel = self.model.mel_forward(x)
|
| 196 |
+
|
| 197 |
+
# time rolling
|
| 198 |
+
if self.config.frame_shift_range > 0:
|
| 199 |
+
mel, labels, pseudo_labels = frame_shift(
|
| 200 |
+
mel,
|
| 201 |
+
labels,
|
| 202 |
+
pseudo_labels=pseudo_labels,
|
| 203 |
+
net_pooling=self.encoder.net_pooling,
|
| 204 |
+
shift_range=self.config.frame_shift_range
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# mixup
|
| 208 |
+
if self.config.mixup_p > random.random():
|
| 209 |
+
mel, labels, pseudo_labels = mixup(
|
| 210 |
+
mel,
|
| 211 |
+
targets=labels,
|
| 212 |
+
pseudo_strong=pseudo_labels
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# mixstyle
|
| 216 |
+
if self.config.mixstyle_p > random.random():
|
| 217 |
+
mel = mixstyle(
|
| 218 |
+
mel
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# time masking
|
| 222 |
+
if self.config.max_time_mask_size > 0:
|
| 223 |
+
mel, labels, pseudo_labels = time_mask(
|
| 224 |
+
mel,
|
| 225 |
+
labels,
|
| 226 |
+
pseudo_labels=pseudo_labels,
|
| 227 |
+
net_pooling=self.encoder.net_pooling,
|
| 228 |
+
max_mask_ratio=self.config.max_time_mask_size
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# frequency masking
|
| 232 |
+
if self.config.filter_augment_p > random.random():
|
| 233 |
+
mel, _ = filter_augmentation(
|
| 234 |
+
mel
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# frequency warping
|
| 238 |
+
if self.config.freq_warp_p > random.random():
|
| 239 |
+
mel = mel.squeeze(1)
|
| 240 |
+
mel = self.freq_warp(mel)
|
| 241 |
+
mel = mel.unsqueeze(1)
|
| 242 |
+
|
| 243 |
+
# forward through network; use strong head
|
| 244 |
+
y_hat_strong, _ = self.model(mel)
|
| 245 |
+
|
| 246 |
+
strong_supervised_loss = self.strong_loss(y_hat_strong, labels)
|
| 247 |
+
|
| 248 |
+
if self.config.distillation_loss_weight > 0:
|
| 249 |
+
strong_distillation_loss = self.strong_loss(y_hat_strong, pseudo_labels)
|
| 250 |
+
else:
|
| 251 |
+
strong_distillation_loss = torch.tensor(0., device=y_hat_strong.device, dtype=y_hat_strong.dtype)
|
| 252 |
+
|
| 253 |
+
loss = self.config.distillation_loss_weight * strong_distillation_loss \
|
| 254 |
+
+ (1 - self.config.distillation_loss_weight) * strong_supervised_loss
|
| 255 |
+
|
| 256 |
+
# logging
|
| 257 |
+
self.log('epoch', self.current_epoch)
|
| 258 |
+
for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
|
| 259 |
+
self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
|
| 260 |
+
self.log("train/loss", loss.detach().cpu(), prog_bar=True)
|
| 261 |
+
self.log("train/strong_supervised_loss", strong_supervised_loss.detach().cpu())
|
| 262 |
+
self.log("train/strong_distillation_loss", strong_distillation_loss.detach().cpu())
|
| 263 |
+
|
| 264 |
+
return loss
|
| 265 |
+
|
| 266 |
+
def validation_step(self, val_batch, batch_idx):
|
| 267 |
+
# bring ground truth into shape needed for evaluation
|
| 268 |
+
for f, gt_string in zip(val_batch["filename"], val_batch["gt_string"]):
|
| 269 |
+
f = f[:-len(".mp3")]
|
| 270 |
+
events = [e.split(";;") for e in gt_string.split("++")]
|
| 271 |
+
self.val_ground_truth[f] = [(float(e[0]), float(e[1]), e[2]) for e in events]
|
| 272 |
+
self.val_duration[f] = self.val_durations_df[self.val_durations_df["filename"] == f]["duration"].values[0]
|
| 273 |
+
|
| 274 |
+
y_hat_strong = self(val_batch)
|
| 275 |
+
y_strong = val_batch["strong"]
|
| 276 |
+
|
| 277 |
+
loss = self.strong_loss(y_hat_strong, y_strong)
|
| 278 |
+
self.val_loss.append(loss.cpu())
|
| 279 |
+
|
| 280 |
+
scores_raw, scores_postprocessed, prediction_dfs = batched_decode_preds(
|
| 281 |
+
y_hat_strong.float(),
|
| 282 |
+
val_batch['filename'],
|
| 283 |
+
self.encoder,
|
| 284 |
+
median_filter=self.config.median_window
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
self.val_predictions_strong.update(
|
| 288 |
+
scores_postprocessed
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
def on_validation_epoch_end(self):
|
| 292 |
+
gt_unique_events = set([e[2] for f, events in self.val_ground_truth.items() for e in events])
|
| 293 |
+
train_unique_events = set(self.encoder.labels)
|
| 294 |
+
# evaluate on all classes that are in both train and test sets (407 classes)
|
| 295 |
+
class_intersection = gt_unique_events.intersection(train_unique_events)
|
| 296 |
+
|
| 297 |
+
assert len(class_intersection) == len(set(as_strong_train_classes).intersection(as_strong_eval_classes)) == 407, \
|
| 298 |
+
f"Intersection unique events. Expected: {len(set(as_strong_train_classes).intersection(as_strong_eval_classes))}," \
|
| 299 |
+
f" Actual: {len(class_intersection)}"
|
| 300 |
+
|
| 301 |
+
# filter ground truth according to class_intersection
|
| 302 |
+
val_ground_truth = {fid: [event for event in self.val_ground_truth[fid] if event[2] in class_intersection]
|
| 303 |
+
for fid in self.val_ground_truth}
|
| 304 |
+
# drop audios without events - aligned with DESED evaluation procedure
|
| 305 |
+
val_ground_truth = {fid: events for fid, events in val_ground_truth.items() if len(events) > 0}
|
| 306 |
+
# keep only corresponding audio durations
|
| 307 |
+
audio_durations = {
|
| 308 |
+
fid: self.val_duration[fid] for fid in val_ground_truth.keys()
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
# filter files in predictions
|
| 312 |
+
as_strong_preds = {
|
| 313 |
+
fid: self.val_predictions_strong[fid] for fid in val_ground_truth.keys()
|
| 314 |
+
}
|
| 315 |
+
# filter classes in predictions
|
| 316 |
+
unused_classes = list(set(self.encoder.labels).difference(class_intersection))
|
| 317 |
+
for f, df in as_strong_preds.items():
|
| 318 |
+
df.drop(columns=list(unused_classes), axis=1, inplace=True)
|
| 319 |
+
|
| 320 |
+
segment_based_pauroc = sed_scores_eval.segment_based.auroc(
|
| 321 |
+
as_strong_preds,
|
| 322 |
+
val_ground_truth,
|
| 323 |
+
audio_durations,
|
| 324 |
+
max_fpr=0.1,
|
| 325 |
+
segment_length=1.0,
|
| 326 |
+
num_jobs=1
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
psds1 = sed_scores_eval.intersection_based.psds(
|
| 330 |
+
as_strong_preds,
|
| 331 |
+
val_ground_truth,
|
| 332 |
+
audio_durations,
|
| 333 |
+
dtc_threshold=0.7,
|
| 334 |
+
gtc_threshold=0.7,
|
| 335 |
+
cttc_threshold=None,
|
| 336 |
+
alpha_ct=0,
|
| 337 |
+
alpha_st=1,
|
| 338 |
+
num_jobs=1
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# "val/psds1_macro_averaged" is psds1 without penalization for performance
|
| 342 |
+
# variations across classes
|
| 343 |
+
logs = {"val/loss": torch.as_tensor(self.val_loss).mean().cuda(),
|
| 344 |
+
"val/psds1": psds1[0],
|
| 345 |
+
"val/psds1_macro_averaged": np.array([v for k, v in psds1[1].items()]).mean(),
|
| 346 |
+
"val/pauroc": segment_based_pauroc[0]['mean'],
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
self.log_dict(logs, sync_dist=False)
|
| 350 |
+
self.val_predictions_strong = {}
|
| 351 |
+
self.val_ground_truth = {}
|
| 352 |
+
self.val_duration = {}
|
| 353 |
+
self.val_loss = []
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def train(config):
|
| 357 |
+
# Train Models on temporally-strong portion of AudioSet.
|
| 358 |
+
|
| 359 |
+
# logging is done using wandb
|
| 360 |
+
wandb_logger = WandbLogger(
|
| 361 |
+
project="PTSED",
|
| 362 |
+
notes="Pre-Training Transformers for Sound Event Detection on AudioSet Strong.",
|
| 363 |
+
tags=["AudioSet Strong", "Sound Event Detection", "Pseudo Labels", "Knowledge Disitillation"],
|
| 364 |
+
config=config,
|
| 365 |
+
name=config.experiment_name
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# encoder manages encoding and decoding of model predictions
|
| 369 |
+
encoder = ManyHotEncoder(as_strong_train_classes)
|
| 370 |
+
|
| 371 |
+
train_set = get_training_dataset(encoder, wavmix_p=config.wavmix_p,
|
| 372 |
+
pseudo_labels_file=config.pseudo_labels_file)
|
| 373 |
+
eval_set = get_eval_dataset(encoder)
|
| 374 |
+
|
| 375 |
+
if config.use_balanced_sampler:
|
| 376 |
+
sample_weights = get_temporal_count_balanced_sample_weights(train_set, save_folder="resources")
|
| 377 |
+
else:
|
| 378 |
+
sample_weights = get_uniform_sample_weights(train_set)
|
| 379 |
+
|
| 380 |
+
train_sampler = get_weighted_sampler(sample_weights, epoch_len=config.epoch_len)
|
| 381 |
+
|
| 382 |
+
# train dataloader
|
| 383 |
+
train_dl = DataLoader(dataset=train_set,
|
| 384 |
+
sampler=train_sampler,
|
| 385 |
+
worker_init_fn=worker_init_fn,
|
| 386 |
+
num_workers=config.num_workers,
|
| 387 |
+
batch_size=config.batch_size,
|
| 388 |
+
shuffle=False)
|
| 389 |
+
|
| 390 |
+
# eval dataloader
|
| 391 |
+
eval_dl = DataLoader(dataset=eval_set,
|
| 392 |
+
worker_init_fn=worker_init_fn,
|
| 393 |
+
num_workers=config.num_workers,
|
| 394 |
+
batch_size=config.batch_size)
|
| 395 |
+
|
| 396 |
+
# create pytorch lightening module
|
| 397 |
+
pl_module = PLModule(config, encoder)
|
| 398 |
+
|
| 399 |
+
# create the pytorch lightening trainer by specifying the number of epochs to train, the logger,
|
| 400 |
+
# on which kind of device(s) to train and possible callbacks
|
| 401 |
+
trainer = pl.Trainer(max_epochs=config.n_epochs,
|
| 402 |
+
logger=wandb_logger,
|
| 403 |
+
accelerator='auto',
|
| 404 |
+
devices=config.num_devices,
|
| 405 |
+
precision=config.precision,
|
| 406 |
+
num_sanity_val_steps=0,
|
| 407 |
+
check_val_every_n_epoch=config.check_val_every_n_epoch
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# start training and validation for the specified number of epochs
|
| 411 |
+
trainer.fit(pl_module, train_dl, eval_dl)
|
| 412 |
+
|
| 413 |
+
wandb.finish()
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def evaluate(config):
|
| 417 |
+
# only evaluation of pre-trained models
|
| 418 |
+
# encoder manages encoding and decoding of model predictions
|
| 419 |
+
encoder = ManyHotEncoder(as_strong_train_classes)
|
| 420 |
+
eval_set = get_eval_dataset(encoder)
|
| 421 |
+
|
| 422 |
+
# eval dataloader
|
| 423 |
+
eval_dl = DataLoader(dataset=eval_set,
|
| 424 |
+
worker_init_fn=worker_init_fn,
|
| 425 |
+
num_workers=config.num_workers,
|
| 426 |
+
batch_size=config.batch_size)
|
| 427 |
+
|
| 428 |
+
# create pytorch lightening module
|
| 429 |
+
pl_module = PLModule(config, encoder)
|
| 430 |
+
|
| 431 |
+
# create the pytorch lightening trainer by specifying the number of epochs to train, the logger,
|
| 432 |
+
# on which kind of device(s) to train and possible callbacks
|
| 433 |
+
trainer = pl.Trainer(max_epochs=config.n_epochs,
|
| 434 |
+
accelerator='auto',
|
| 435 |
+
devices=config.num_devices,
|
| 436 |
+
precision=config.precision,
|
| 437 |
+
num_sanity_val_steps=0,
|
| 438 |
+
check_val_every_n_epoch=config.check_val_every_n_epoch)
|
| 439 |
+
|
| 440 |
+
# start evaluation
|
| 441 |
+
trainer.validate(pl_module, eval_dl)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
if __name__ == '__main__':
|
| 445 |
+
parser = argparse.ArgumentParser(description='Configuration Parser. ')
|
| 446 |
+
|
| 447 |
+
# general
|
| 448 |
+
parser.add_argument('--experiment_name', type=str, default="AudioSet_Strong")
|
| 449 |
+
parser.add_argument('--batch_size', type=int, default=256)
|
| 450 |
+
parser.add_argument('--num_workers', type=int, default=16)
|
| 451 |
+
parser.add_argument('--num_devices', type=int, default=1)
|
| 452 |
+
parser.add_argument('--precision', type=int, default=16)
|
| 453 |
+
parser.add_argument('--evaluate', action='store_true', default=False)
|
| 454 |
+
parser.add_argument('--check_val_every_n_epoch', type=int, default=5)
|
| 455 |
+
|
| 456 |
+
# model
|
| 457 |
+
parser.add_argument('--model_name', type=str,
|
| 458 |
+
choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"] + \
|
| 459 |
+
[f"frame_mn{width}" for width in ["06", "10"]],
|
| 460 |
+
default="ATST-F") # used also for training
|
| 461 |
+
# "scratch" = no pretraining
|
| 462 |
+
# "ssl" = SSL pre-trained
|
| 463 |
+
# "weak" = AudioSet Weak pre-trained
|
| 464 |
+
# "strong" = AudioSet Strong pre-trained
|
| 465 |
+
parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
|
| 466 |
+
default="weak")
|
| 467 |
+
parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
|
| 468 |
+
default=None)
|
| 469 |
+
|
| 470 |
+
# training
|
| 471 |
+
parser.add_argument('--n_epochs', type=int, default=30)
|
| 472 |
+
parser.add_argument('--use_balanced_sampler', action='store_true', default=False)
|
| 473 |
+
parser.add_argument('--distillation_loss_weight', type=float, default=0.0)
|
| 474 |
+
parser.add_argument('--epoch_len', type=int, default=100000)
|
| 475 |
+
parser.add_argument('--median_window', type=int, default=9)
|
| 476 |
+
|
| 477 |
+
# augmentation
|
| 478 |
+
parser.add_argument('--wavmix_p', type=float, default=0.8)
|
| 479 |
+
parser.add_argument('--freq_warp_p', type=float, default=0.8)
|
| 480 |
+
parser.add_argument('--filter_augment_p', type=float, default=0.8)
|
| 481 |
+
parser.add_argument('--frame_shift_range', type=float, default=0.125) # in seconds
|
| 482 |
+
parser.add_argument('--mixup_p', type=float, default=0.3)
|
| 483 |
+
parser.add_argument('--mixstyle_p', type=float, default=0.3)
|
| 484 |
+
parser.add_argument('--max_time_mask_size', type=float, default=0.0)
|
| 485 |
+
|
| 486 |
+
# optimizer
|
| 487 |
+
parser.add_argument('--adamw', action='store_true', default=False)
|
| 488 |
+
parser.add_argument('--weight_decay', type=float, default=0.0)
|
| 489 |
+
|
| 490 |
+
# lr schedule
|
| 491 |
+
parser.add_argument('--schedule_mode', type=str, default="cos")
|
| 492 |
+
parser.add_argument('--max_lr', type=float, default=7e-5)
|
| 493 |
+
parser.add_argument('--lr_end', type=float, default=2e-7)
|
| 494 |
+
parser.add_argument('--warmup_steps', type=int, default=5000)
|
| 495 |
+
|
| 496 |
+
# knowledge distillation
|
| 497 |
+
parser.add_argument('--pseudo_labels_file', type=str,
|
| 498 |
+
default=None)
|
| 499 |
+
|
| 500 |
+
args = parser.parse_args()
|
| 501 |
+
if args.evaluate:
|
| 502 |
+
evaluate(args)
|
| 503 |
+
else:
|
| 504 |
+
train(args)
|
ex_dcase2016task2.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import transformers
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
|
| 14 |
+
import wandb
|
| 15 |
+
from data_util.dcase2016task2 import (get_training_dataset, get_validation_dataset, get_test_dataset,
|
| 16 |
+
label_vocab_nlabels, label_vocab_as_dict)
|
| 17 |
+
from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
|
| 18 |
+
from helpers.score import get_events_for_all_files, combine_target_events, EventBasedScore, SegmentBasedScore
|
| 19 |
+
from helpers.utils import worker_init_fn
|
| 20 |
+
from models.asit.ASIT_wrapper import ASiTWrapper
|
| 21 |
+
from models.atstframe.ATSTF_wrapper import ATSTWrapper
|
| 22 |
+
from models.beats.BEATs_wrapper import BEATsWrapper
|
| 23 |
+
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
|
| 24 |
+
from models.m2d.M2D_wrapper import M2DWrapper
|
| 25 |
+
from models.prediction_wrapper import PredictionsWrapper
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PLModule(pl.LightningModule):
|
| 29 |
+
def __init__(self, config):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.config = config
|
| 32 |
+
|
| 33 |
+
if config.pretrained == "scratch":
|
| 34 |
+
checkpoint = None
|
| 35 |
+
elif config.pretrained == "ssl":
|
| 36 |
+
checkpoint = "ssl"
|
| 37 |
+
elif config.pretrained == "weak":
|
| 38 |
+
checkpoint = "weak"
|
| 39 |
+
elif config.pretrained == "strong":
|
| 40 |
+
checkpoint = "strong_1"
|
| 41 |
+
else:
|
| 42 |
+
raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")
|
| 43 |
+
|
| 44 |
+
# load transformer model
|
| 45 |
+
if config.model_name == "BEATs":
|
| 46 |
+
beats = BEATsWrapper()
|
| 47 |
+
model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
|
| 48 |
+
seq_model_type=config.seq_model_type,
|
| 49 |
+
n_classes_strong=self.config.n_classes)
|
| 50 |
+
elif config.model_name == "ATST-F":
|
| 51 |
+
atst = ATSTWrapper()
|
| 52 |
+
model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
|
| 53 |
+
seq_model_type=config.seq_model_type,
|
| 54 |
+
n_classes_strong=self.config.n_classes)
|
| 55 |
+
elif config.model_name == "fpasst":
|
| 56 |
+
fpasst = FPaSSTWrapper()
|
| 57 |
+
model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
|
| 58 |
+
seq_model_type=config.seq_model_type,
|
| 59 |
+
n_classes_strong=self.config.n_classes)
|
| 60 |
+
elif config.model_name == "M2D":
|
| 61 |
+
m2d = M2DWrapper()
|
| 62 |
+
model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
|
| 63 |
+
seq_model_type=config.seq_model_type,
|
| 64 |
+
n_classes_strong=self.config.n_classes,
|
| 65 |
+
embed_dim=m2d.m2d.cfg.feature_d)
|
| 66 |
+
elif config.model_name == "ASIT":
|
| 67 |
+
asit = ASiTWrapper()
|
| 68 |
+
model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
|
| 69 |
+
seq_model_type=config.seq_model_type,
|
| 70 |
+
n_classes_strong=self.config.n_classes)
|
| 71 |
+
else:
|
| 72 |
+
raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
|
| 73 |
+
|
| 74 |
+
self.model = model
|
| 75 |
+
self.strong_loss = nn.BCEWithLogitsLoss()
|
| 76 |
+
|
| 77 |
+
self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))
|
| 78 |
+
|
| 79 |
+
task_path = Path(self.config.task_path)
|
| 80 |
+
label_vocab, nlabels = label_vocab_nlabels(task_path)
|
| 81 |
+
self.label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
|
| 82 |
+
|
| 83 |
+
self.idx_to_label: Dict[int, str] = {
|
| 84 |
+
idx: label for (label, idx) in self.label_to_idx.items()
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
self.event_onset_200ms_fms = EventBasedScore(
|
| 88 |
+
label_to_idx=self.label_to_idx,
|
| 89 |
+
name="event_onset_200ms_fms",
|
| 90 |
+
scores=("f_measure", "precision", "recall"),
|
| 91 |
+
params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.2}
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
self.event_onset_50ms_fms = EventBasedScore(
|
| 95 |
+
label_to_idx=self.label_to_idx,
|
| 96 |
+
name="event_onset_50ms_fms",
|
| 97 |
+
scores=("f_measure", "precision", "recall"),
|
| 98 |
+
params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.05}
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self.segment_1s_er = SegmentBasedScore(
|
| 102 |
+
label_to_idx=self.label_to_idx,
|
| 103 |
+
name="segment_1s_er",
|
| 104 |
+
scores=("error_rate",),
|
| 105 |
+
params={"time_resolution": 1.0},
|
| 106 |
+
maximize=False,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
self.postprocessing_grid = {
|
| 110 |
+
"median_filter_ms": [
|
| 111 |
+
250
|
| 112 |
+
],
|
| 113 |
+
"min_duration": [
|
| 114 |
+
125
|
| 115 |
+
]
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []
|
| 119 |
+
|
| 120 |
+
def forward(self, audio):
|
| 121 |
+
mel = self.model.mel_forward(audio)
|
| 122 |
+
y_strong, _ = self.model(mel)
|
| 123 |
+
return y_strong
|
| 124 |
+
|
| 125 |
+
def separate_params(self):
|
| 126 |
+
pt_params = []
|
| 127 |
+
seq_params = []
|
| 128 |
+
head_params = []
|
| 129 |
+
|
| 130 |
+
for name, p in self.named_parameters():
|
| 131 |
+
name = name[len("model."):]
|
| 132 |
+
if name.startswith('model'):
|
| 133 |
+
# the transformer
|
| 134 |
+
pt_params.append(p)
|
| 135 |
+
elif name.startswith('seq_model'):
|
| 136 |
+
# the optional sequence model
|
| 137 |
+
seq_params.append(p)
|
| 138 |
+
elif name.startswith('strong_head') or name.startswith('weak_head'):
|
| 139 |
+
# the prediction head
|
| 140 |
+
head_params.append(p)
|
| 141 |
+
else:
|
| 142 |
+
raise ValueError(f"Unexpected key in model: {name}")
|
| 143 |
+
|
| 144 |
+
if self.model.has_separate_params():
|
| 145 |
+
# split parameters into groups according to their depth in the network
|
| 146 |
+
# based on this, we can apply layer-wise learning rate decay
|
| 147 |
+
pt_params = self.model.separate_params()
|
| 148 |
+
else:
|
| 149 |
+
if self.config.lr_decay != 1.0:
|
| 150 |
+
raise ValueError(f"Model has no separate_params function. Can't apply layer-wise lr decay, but "
|
| 151 |
+
f"learning rate decay is set to {self.config.lr_decay}.")
|
| 152 |
+
|
| 153 |
+
return pt_params, seq_params, head_params
|
| 154 |
+
|
| 155 |
+
def get_optimizer(
|
| 156 |
+
self,
|
| 157 |
+
lr,
|
| 158 |
+
lr_decay=1.0,
|
| 159 |
+
transformer_lr=None,
|
| 160 |
+
transformer_frozen=False,
|
| 161 |
+
adamw=False,
|
| 162 |
+
weight_decay=0.01,
|
| 163 |
+
betas=(0.9, 0.999)
|
| 164 |
+
):
|
| 165 |
+
pt_params, seq_params, head_params = self.separate_params()
|
| 166 |
+
|
| 167 |
+
param_groups = [
|
| 168 |
+
{'params': head_params, 'lr': lr}, # model head (besides base model and seq model)
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
if transformer_frozen:
|
| 172 |
+
for p in pt_params + seq_params:
|
| 173 |
+
if isinstance(p, list):
|
| 174 |
+
for p_i in p:
|
| 175 |
+
p_i.detach_()
|
| 176 |
+
else:
|
| 177 |
+
p.detach_()
|
| 178 |
+
else:
|
| 179 |
+
if transformer_lr is None:
|
| 180 |
+
transformer_lr = lr
|
| 181 |
+
if isinstance(pt_params, list) and isinstance(pt_params[0], list):
|
| 182 |
+
# apply lr decay
|
| 183 |
+
scale_lrs = [transformer_lr * (lr_decay ** i) for i in range(1, len(pt_params) + 1)]
|
| 184 |
+
param_groups = param_groups + [{"params": pt_params[i], "lr": scale_lrs[i]} for i in
|
| 185 |
+
range(len(pt_params))]
|
| 186 |
+
else:
|
| 187 |
+
param_groups.append(
|
| 188 |
+
{'params': pt_params, 'lr': transformer_lr}, # pretrained model
|
| 189 |
+
)
|
| 190 |
+
param_groups.append(
|
| 191 |
+
{'params': seq_params, 'lr': lr}, # pretrained model
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# do not apply weight decay to biases and batch norms
|
| 195 |
+
param_groups_split = []
|
| 196 |
+
for param_group in param_groups:
|
| 197 |
+
params_1D, params_2D = [], []
|
| 198 |
+
lr = param_group['lr']
|
| 199 |
+
for param in param_group['params']:
|
| 200 |
+
if param.ndimension() >= 2:
|
| 201 |
+
params_2D.append(param)
|
| 202 |
+
elif param.ndimension() <= 1:
|
| 203 |
+
params_1D.append(param)
|
| 204 |
+
param_groups_split += [{'params': params_2D, 'lr': lr, 'weight_decay': weight_decay},
|
| 205 |
+
{'params': params_1D, 'lr': lr}]
|
| 206 |
+
if weight_decay > 0:
|
| 207 |
+
assert adamw
|
| 208 |
+
if adamw:
|
| 209 |
+
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
|
| 210 |
+
return torch.optim.AdamW(param_groups_split, lr=lr, weight_decay=weight_decay, betas=betas)
|
| 211 |
+
return torch.optim.Adam(param_groups_split, lr=lr, betas=betas)
|
| 212 |
+
|
| 213 |
+
def get_lr_scheduler(
|
| 214 |
+
self,
|
| 215 |
+
optimizer,
|
| 216 |
+
num_training_steps,
|
| 217 |
+
schedule_mode="cos",
|
| 218 |
+
gamma: float = 0.999996,
|
| 219 |
+
num_warmup_steps=4000,
|
| 220 |
+
lr_end=1e-7,
|
| 221 |
+
):
|
| 222 |
+
if schedule_mode in {"exp"}:
|
| 223 |
+
return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
|
| 224 |
+
if schedule_mode in {"cosine", "cos"}:
|
| 225 |
+
return transformers.get_cosine_schedule_with_warmup(
|
| 226 |
+
optimizer,
|
| 227 |
+
num_warmup_steps=num_warmup_steps,
|
| 228 |
+
num_training_steps=num_training_steps,
|
| 229 |
+
)
|
| 230 |
+
if schedule_mode in {"linear"}:
|
| 231 |
+
print("Linear schedule!")
|
| 232 |
+
return transformers.get_polynomial_decay_schedule_with_warmup(
|
| 233 |
+
optimizer,
|
| 234 |
+
num_warmup_steps=num_warmup_steps,
|
| 235 |
+
num_training_steps=num_training_steps,
|
| 236 |
+
power=1.0,
|
| 237 |
+
lr_end=lr_end,
|
| 238 |
+
)
|
| 239 |
+
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
|
| 240 |
+
|
| 241 |
+
def configure_optimizers(self):
|
| 242 |
+
"""
|
| 243 |
+
This is the way pytorch lightening requires optimizers and learning rate schedulers to be defined.
|
| 244 |
+
The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
|
| 245 |
+
:return: dict containing optimizer and learning rate scheduler
|
| 246 |
+
"""
|
| 247 |
+
optimizer = self.get_optimizer(self.config.max_lr,
|
| 248 |
+
lr_decay=self.config.lr_decay,
|
| 249 |
+
transformer_lr=self.config.transformer_lr,
|
| 250 |
+
transformer_frozen=self.config.transformer_frozen,
|
| 251 |
+
adamw=False if self.config.no_adamw else True,
|
| 252 |
+
weight_decay=self.config.weight_decay)
|
| 253 |
+
|
| 254 |
+
num_training_steps = self.trainer.estimated_stepping_batches
|
| 255 |
+
|
| 256 |
+
scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
|
| 257 |
+
schedule_mode=self.config.schedule_mode,
|
| 258 |
+
lr_end=self.config.lr_end)
|
| 259 |
+
lr_scheduler_config = {
|
| 260 |
+
"scheduler": scheduler,
|
| 261 |
+
"interval": "step",
|
| 262 |
+
"frequency": 1
|
| 263 |
+
}
|
| 264 |
+
return [optimizer], [lr_scheduler_config]
|
| 265 |
+
|
| 266 |
+
def training_step(self, train_batch, batch_idx):
|
| 267 |
+
"""
|
| 268 |
+
:param train_batch: contains one batch from train dataloader
|
| 269 |
+
:param batch_idx
|
| 270 |
+
:return: a dict containing at least loss that is used to update model parameters, can also contain
|
| 271 |
+
other items that can be processed in 'training_epoch_end' to log other metrics than loss
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
audios, labels, fnames, timestamps = train_batch
|
| 275 |
+
|
| 276 |
+
if self.config.transformer_frozen:
|
| 277 |
+
self.model.model.eval()
|
| 278 |
+
self.model.seq_model.eval()
|
| 279 |
+
mel = self.model.mel_forward(audios)
|
| 280 |
+
|
| 281 |
+
# time rolling
|
| 282 |
+
if self.config.frame_shift_range > 0:
|
| 283 |
+
mel, labels = frame_shift(
|
| 284 |
+
mel,
|
| 285 |
+
labels,
|
| 286 |
+
shift_range=self.config.frame_shift_range
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# mixup
|
| 290 |
+
if self.config.mixup_p > random.random():
|
| 291 |
+
mel, labels = mixup(
|
| 292 |
+
mel,
|
| 293 |
+
targets=labels
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# mixstyle
|
| 297 |
+
if self.config.mixstyle_p > random.random():
|
| 298 |
+
mel = mixstyle(
|
| 299 |
+
mel
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# time masking
|
| 303 |
+
if self.config.max_time_mask_size > 0:
|
| 304 |
+
mel, labels, pseudo_labels = time_mask(
|
| 305 |
+
mel,
|
| 306 |
+
labels,
|
| 307 |
+
max_mask_ratio=self.config.max_time_mask_size
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# frequency masking
|
| 311 |
+
if self.config.filter_augment_p > random.random():
|
| 312 |
+
mel, _ = filter_augmentation(
|
| 313 |
+
mel
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# frequency warping
|
| 317 |
+
if self.config.freq_warp_p > random.random():
|
| 318 |
+
mel = mel.squeeze(1)
|
| 319 |
+
mel = self.freq_warp(mel)
|
| 320 |
+
mel = mel.unsqueeze(1)
|
| 321 |
+
|
| 322 |
+
# forward through network; use strong head
|
| 323 |
+
y_hat_strong, _ = self.model(mel)
|
| 324 |
+
|
| 325 |
+
loss = self.strong_loss(y_hat_strong, labels)
|
| 326 |
+
|
| 327 |
+
# logging
|
| 328 |
+
self.log('epoch', self.current_epoch)
|
| 329 |
+
for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
|
| 330 |
+
self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
|
| 331 |
+
self.log("train/loss", loss.detach().cpu(), prog_bar=True)
|
| 332 |
+
|
| 333 |
+
return loss
|
| 334 |
+
|
| 335 |
+
def _score_step(self, batch):
|
| 336 |
+
audios, labels, fnames, timestamps = batch
|
| 337 |
+
|
| 338 |
+
strong_preds = self.forward(audios)
|
| 339 |
+
|
| 340 |
+
self.preds.append(strong_preds)
|
| 341 |
+
self.tgts.append(labels)
|
| 342 |
+
self.fnames.append(fnames)
|
| 343 |
+
self.timestamps.append(timestamps)
|
| 344 |
+
|
| 345 |
+
def _score_epoch_end(self, name="val"):
|
| 346 |
+
preds = torch.cat(self.preds)
|
| 347 |
+
tgts = torch.cat(self.tgts)
|
| 348 |
+
fnames = [item for sublist in self.fnames for item in sublist]
|
| 349 |
+
timestamps = torch.cat(self.timestamps)
|
| 350 |
+
val_loss = self.strong_loss(preds, tgts)
|
| 351 |
+
self.log(f"{name}/loss", val_loss, prog_bar=True)
|
| 352 |
+
|
| 353 |
+
# the following function expects one prediction per timestamp (sequence dimension must be flattened)
|
| 354 |
+
seq_len = preds.size(-1)
|
| 355 |
+
preds = rearrange(preds, 'bs c t -> (bs t) c').float()
|
| 356 |
+
timestamps = rearrange(timestamps, 'bs t -> (bs t)').float()
|
| 357 |
+
fnames = [fname for fname in fnames for _ in range(seq_len)]
|
| 358 |
+
|
| 359 |
+
predicted_events_by_postprocessing = get_events_for_all_files(
|
| 360 |
+
preds,
|
| 361 |
+
fnames,
|
| 362 |
+
timestamps,
|
| 363 |
+
self.idx_to_label,
|
| 364 |
+
self.postprocessing_grid
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# we only have one postprocessing configurations (aligned with HEAR challenge)
|
| 368 |
+
key = list(predicted_events_by_postprocessing.keys())[0]
|
| 369 |
+
predicted_events = predicted_events_by_postprocessing[key]
|
| 370 |
+
|
| 371 |
+
# load ground truth for test fold
|
| 372 |
+
task_path = Path(self.config.task_path)
|
| 373 |
+
test_target_events = combine_target_events(["valid" if name == "val" else "test"], task_path)
|
| 374 |
+
onset_fms = self.event_onset_200ms_fms(predicted_events, test_target_events)
|
| 375 |
+
onset_fms_50 = self.event_onset_50ms_fms(predicted_events, test_target_events)
|
| 376 |
+
segment_1s_er = self.segment_1s_er(predicted_events, test_target_events)
|
| 377 |
+
|
| 378 |
+
self.log(f"{name}/onset_fms", onset_fms[0][1])
|
| 379 |
+
self.log(f"{name}/onset_fms_50", onset_fms_50[0][1])
|
| 380 |
+
self.log(f"{name}/segment_1s_er", segment_1s_er[0][1])
|
| 381 |
+
|
| 382 |
+
# free buffers
|
| 383 |
+
self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []
|
| 384 |
+
|
| 385 |
+
def validation_step(self, batch, batch_idx):
|
| 386 |
+
self._score_step(batch)
|
| 387 |
+
|
| 388 |
+
def on_validation_epoch_end(self):
|
| 389 |
+
self._score_epoch_end(name="val")
|
| 390 |
+
|
| 391 |
+
def test_step(self, batch, batch_idx):
|
| 392 |
+
self._score_step(batch)
|
| 393 |
+
|
| 394 |
+
def on_test_epoch_end(self):
|
| 395 |
+
self._score_epoch_end(name="test")
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def train(config):
|
| 399 |
+
# Example for fine-tuning pre-trained transformers on a downstream task.
|
| 400 |
+
|
| 401 |
+
# logging is done using wandb
|
| 402 |
+
wandb_logger = WandbLogger(
|
| 403 |
+
project="PTSED",
|
| 404 |
+
notes="Downstream Training on office sound event detection.",
|
| 405 |
+
tags=["DCASE 2016 Task 2", "Sound Event Detection"],
|
| 406 |
+
config=config,
|
| 407 |
+
name=config.experiment_name
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
train_set = get_training_dataset(config.task_path, wavmix_p=config.wavmix_p)
|
| 411 |
+
val_ds = get_validation_dataset(config.task_path)
|
| 412 |
+
test_ds = get_test_dataset(config.task_path)
|
| 413 |
+
|
| 414 |
+
# train dataloader
|
| 415 |
+
train_dl = DataLoader(dataset=train_set,
|
| 416 |
+
worker_init_fn=worker_init_fn,
|
| 417 |
+
num_workers=config.num_workers,
|
| 418 |
+
batch_size=config.batch_size,
|
| 419 |
+
shuffle=True)
|
| 420 |
+
|
| 421 |
+
# validation dataloader
|
| 422 |
+
valid_dl = DataLoader(dataset=val_ds,
|
| 423 |
+
worker_init_fn=worker_init_fn,
|
| 424 |
+
num_workers=config.num_workers,
|
| 425 |
+
batch_size=config.batch_size,
|
| 426 |
+
shuffle=False,
|
| 427 |
+
drop_last=False)
|
| 428 |
+
|
| 429 |
+
# test dataloader
|
| 430 |
+
test_dl = DataLoader(dataset=test_ds,
|
| 431 |
+
worker_init_fn=worker_init_fn,
|
| 432 |
+
num_workers=config.num_workers,
|
| 433 |
+
batch_size=config.batch_size,
|
| 434 |
+
shuffle=False,
|
| 435 |
+
drop_last=False)
|
| 436 |
+
|
| 437 |
+
# create pytorch lightening module
|
| 438 |
+
pl_module = PLModule(config)
|
| 439 |
+
|
| 440 |
+
# create the pytorch lightening trainer by specifying the number of epochs to train, the logger,
|
| 441 |
+
# on which kind of device(s) to train and possible callbacks
|
| 442 |
+
trainer = pl.Trainer(max_epochs=config.n_epochs,
|
| 443 |
+
logger=wandb_logger,
|
| 444 |
+
accelerator='auto',
|
| 445 |
+
devices=config.num_devices,
|
| 446 |
+
precision=config.precision,
|
| 447 |
+
num_sanity_val_steps=0,
|
| 448 |
+
check_val_every_n_epoch=config.check_val_every_n_epoch
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# start training and validation for the specified number of epochs
|
| 452 |
+
trainer.fit(
|
| 453 |
+
pl_module,
|
| 454 |
+
train_dataloaders=train_dl,
|
| 455 |
+
val_dataloaders=valid_dl,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
test_results = trainer.test(pl_module, dataloaders=test_dl)
|
| 459 |
+
print(test_results)
|
| 460 |
+
wandb.finish()
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
if __name__ == '__main__':
|
| 464 |
+
parser = argparse.ArgumentParser(description='Configuration Parser. ')
|
| 465 |
+
|
| 466 |
+
# general
|
| 467 |
+
parser.add_argument('--task_path', type=str, required=True)
|
| 468 |
+
parser.add_argument('--experiment_name', type=str, default="DCASE2016Task2")
|
| 469 |
+
parser.add_argument('--batch_size', type=int, default=256)
|
| 470 |
+
parser.add_argument('--num_workers', type=int, default=16)
|
| 471 |
+
parser.add_argument('--num_devices', type=int, default=1)
|
| 472 |
+
parser.add_argument('--precision', type=int, default=16)
|
| 473 |
+
parser.add_argument('--check_val_every_n_epoch', type=int, default=10)
|
| 474 |
+
|
| 475 |
+
# model
|
| 476 |
+
parser.add_argument('--model_name', type=str,
|
| 477 |
+
choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"],
|
| 478 |
+
default="ATST-F") # used also for training
|
| 479 |
+
# "scratch" = no pretraining
|
| 480 |
+
# "ssl" = SSL pre-trained
|
| 481 |
+
# "weak" = AudioSet Weak pre-trained
|
| 482 |
+
# "strong" = AudioSet Strong pre-trained
|
| 483 |
+
parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
|
| 484 |
+
default="strong")
|
| 485 |
+
parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
|
| 486 |
+
default=None)
|
| 487 |
+
parser.add_argument('--n_classes', type=int, default=11)
|
| 488 |
+
|
| 489 |
+
# training
|
| 490 |
+
parser.add_argument('--n_epochs', type=int, default=300)
|
| 491 |
+
|
| 492 |
+
# augmentation
|
| 493 |
+
parser.add_argument('--wavmix_p', type=float, default=0.5)
|
| 494 |
+
parser.add_argument('--freq_warp_p', type=float, default=0.0)
|
| 495 |
+
parser.add_argument('--filter_augment_p', type=float, default=0.0)
|
| 496 |
+
parser.add_argument('--frame_shift_range', type=float, default=0.0) # in seconds
|
| 497 |
+
parser.add_argument('--mixup_p', type=float, default=0.5)
|
| 498 |
+
parser.add_argument('--mixstyle_p', type=float, default=0.0)
|
| 499 |
+
parser.add_argument('--max_time_mask_size', type=float, default=0.0)
|
| 500 |
+
|
| 501 |
+
# optimizer
|
| 502 |
+
parser.add_argument('--no_adamw', action='store_true', default=False)
|
| 503 |
+
parser.add_argument('--weight_decay', type=float, default=0.001)
|
| 504 |
+
parser.add_argument('--transformer_frozen', action='store_true', dest='transformer_frozen',
|
| 505 |
+
default=False,
|
| 506 |
+
help='Disable training for the transformer.')
|
| 507 |
+
|
| 508 |
+
# lr schedule
|
| 509 |
+
parser.add_argument('--schedule_mode', type=str, default="cos")
|
| 510 |
+
parser.add_argument('--max_lr', type=float, default=1.06e-4)
|
| 511 |
+
parser.add_argument('--transformer_lr', type=float, default=None)
|
| 512 |
+
parser.add_argument('--lr_decay', type=float, default=1.0)
|
| 513 |
+
parser.add_argument('--lr_end', type=float, default=1e-7)
|
| 514 |
+
parser.add_argument('--warmup_steps', type=int, default=100)
|
| 515 |
+
|
| 516 |
+
args = parser.parse_args()
|
| 517 |
+
train(args)
|
helpers/augment.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.distributions.beta import Beta
|
| 7 |
+
|
| 8 |
+
def frame_shift(mels, labels, embeddings=None, pseudo_labels=None,
|
| 9 |
+
net_pooling=4, shift_range=0.125):
|
| 10 |
+
bsz, channels, n_bands, frames = mels.shape
|
| 11 |
+
abs_shift_mel = int(frames * shift_range)
|
| 12 |
+
|
| 13 |
+
if embeddings is not None:
|
| 14 |
+
embed_frames = embeddings.shape[-1]
|
| 15 |
+
embed_pool_fact = frames / embed_frames
|
| 16 |
+
|
| 17 |
+
for bindx in range(bsz):
|
| 18 |
+
shift = int(random.gauss(0, abs_shift_mel))
|
| 19 |
+
mels[bindx] = torch.roll(mels[bindx], shift, dims=-1)
|
| 20 |
+
label_shift = -abs(shift) / net_pooling if shift < 0 else shift / net_pooling
|
| 21 |
+
label_shift = round(label_shift)
|
| 22 |
+
labels[bindx] = torch.roll(labels[bindx], label_shift, dims=-1)
|
| 23 |
+
|
| 24 |
+
if pseudo_labels is not None:
|
| 25 |
+
pseudo_labels[bindx] = torch.roll(pseudo_labels[bindx], label_shift, dims=-1)
|
| 26 |
+
|
| 27 |
+
if embeddings is not None:
|
| 28 |
+
embed_shift = -abs(shift) / embed_pool_fact if shift < 0 else shift / embed_pool_fact
|
| 29 |
+
embed_shift = round(embed_shift)
|
| 30 |
+
embeddings[bindx] = torch.roll(embeddings[bindx], embed_shift, dims=-1)
|
| 31 |
+
|
| 32 |
+
out_args = [mels]
|
| 33 |
+
if embeddings is not None:
|
| 34 |
+
out_args.append(embeddings)
|
| 35 |
+
out_args.append(labels)
|
| 36 |
+
if pseudo_labels is not None:
|
| 37 |
+
out_args.append(pseudo_labels)
|
| 38 |
+
return tuple(out_args)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def time_mask(features, labels, embeddings=None, pseudo_labels=None, net_pooling=4,
|
| 42 |
+
min_mask_ratio=0.05, max_mask_ratio=0.2):
|
| 43 |
+
_, _, n_frame = labels.shape
|
| 44 |
+
|
| 45 |
+
if embeddings is not None:
|
| 46 |
+
embed_frames = embeddings.shape[-1]
|
| 47 |
+
embed_pool_fact = embed_frames / n_frame
|
| 48 |
+
|
| 49 |
+
t_width = torch.randint(low=int(n_frame * min_mask_ratio), high=int(n_frame * max_mask_ratio), size=(1,))
|
| 50 |
+
t_low = torch.randint(low=0, high=n_frame-t_width[0], size=(1,))
|
| 51 |
+
features[:, :, :, t_low * net_pooling:(t_low+t_width)*net_pooling] = 0
|
| 52 |
+
labels[:, :, t_low:t_low+t_width] = 0
|
| 53 |
+
|
| 54 |
+
if pseudo_labels is not None:
|
| 55 |
+
labels[:, :, t_low:t_low + t_width] = 0
|
| 56 |
+
|
| 57 |
+
if embeddings is not None:
|
| 58 |
+
low = round((t_low * embed_pool_fact).item())
|
| 59 |
+
high = round(((t_low + t_width) * embed_pool_fact).item())
|
| 60 |
+
embeddings[..., low:high] = 0
|
| 61 |
+
|
| 62 |
+
out_args = [features]
|
| 63 |
+
|
| 64 |
+
if embeddings is not None:
|
| 65 |
+
out_args.append(embeddings)
|
| 66 |
+
out_args.append(labels)
|
| 67 |
+
if pseudo_labels is not None:
|
| 68 |
+
out_args.append(pseudo_labels)
|
| 69 |
+
return tuple(out_args)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def mixup(data, embeddings=None, targets=None, pseudo_strong=None, alpha=0.2, beta=0.2, return_mix_coef=False):
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
batch_size = data.size(0)
|
| 75 |
+
c = np.random.beta(alpha, beta, size=batch_size)
|
| 76 |
+
c = np.maximum(c, 1 - c)
|
| 77 |
+
|
| 78 |
+
perm = torch.randperm(batch_size)
|
| 79 |
+
cd = torch.tensor(c, dtype=data.dtype, device=data.device).view(batch_size, *([1] * (data.ndim - 1)))
|
| 80 |
+
mixed_data = cd * data + (1 - cd) * data[perm, :]
|
| 81 |
+
|
| 82 |
+
if embeddings is not None:
|
| 83 |
+
ce = torch.tensor(c, dtype=embeddings.dtype, device=embeddings.device).view(batch_size, *([1] * (embeddings.ndim - 1)))
|
| 84 |
+
mixed_embeddings = ce * embeddings + (1 - ce) * embeddings[perm, :]
|
| 85 |
+
|
| 86 |
+
if targets is not None:
|
| 87 |
+
ct = torch.tensor(c, dtype=data.dtype, device=data.device).view(batch_size, *([1] * (targets.ndim - 1)))
|
| 88 |
+
mixed_target = torch.clamp(
|
| 89 |
+
ct * targets + (1 - ct) * targets[perm, :], min=0, max=1
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if pseudo_strong is not None:
|
| 93 |
+
cp = torch.tensor(c, dtype=pseudo_strong.dtype, device=pseudo_strong.device).view(batch_size,
|
| 94 |
+
*([1] * (pseudo_strong.ndim - 1)))
|
| 95 |
+
mixed_pseudo_strong = cp * pseudo_strong + (1 - cp) * pseudo_strong[perm, :]
|
| 96 |
+
|
| 97 |
+
out_args = [mixed_data]
|
| 98 |
+
if embeddings is not None:
|
| 99 |
+
out_args.append(mixed_embeddings)
|
| 100 |
+
if targets is not None:
|
| 101 |
+
out_args.append(mixed_target)
|
| 102 |
+
if pseudo_strong is not None:
|
| 103 |
+
out_args.append(mixed_pseudo_strong)
|
| 104 |
+
|
| 105 |
+
if return_mix_coef:
|
| 106 |
+
out_args.append(perm)
|
| 107 |
+
out_args.append(c)
|
| 108 |
+
return tuple(out_args)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def filt_aug_(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6):
|
| 112 |
+
batch_size, channels, n_freq_bin, _ = features.shape
|
| 113 |
+
n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item() # [low, high)
|
| 114 |
+
if n_freq_band > 1:
|
| 115 |
+
while n_freq_bin - n_freq_band * min_bw + 1 < 0:
|
| 116 |
+
min_bw -= 1
|
| 117 |
+
band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * min_bw + 1,
|
| 118 |
+
(n_freq_band - 1,)))[0] + \
|
| 119 |
+
torch.arange(1, n_freq_band) * min_bw
|
| 120 |
+
band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs, torch.tensor([n_freq_bin])))
|
| 121 |
+
|
| 122 |
+
band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * (db_range[1] - db_range[0]) + db_range[0]
|
| 123 |
+
freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
|
| 124 |
+
for i in range(n_freq_band):
|
| 125 |
+
for j in range(batch_size):
|
| 126 |
+
freq_filt[j, band_bndry_freqs[i]:band_bndry_freqs[i+1], :] = \
|
| 127 |
+
torch.linspace(band_factors[j, i], band_factors[j, i+1],
|
| 128 |
+
band_bndry_freqs[i+1] - band_bndry_freqs[i]).unsqueeze(-1)
|
| 129 |
+
freq_filt = 10 ** (freq_filt / 20)
|
| 130 |
+
return features * freq_filt.unsqueeze(1)
|
| 131 |
+
else:
|
| 132 |
+
return features
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def filter_augmentation(features, n_transform=1, filter_db_range=(-6, 6),
|
| 136 |
+
filter_bands=(3, 6), filter_minimum_bandwidth=6):
|
| 137 |
+
if n_transform == 2:
|
| 138 |
+
feature_list = []
|
| 139 |
+
for _ in range(n_transform):
|
| 140 |
+
features_temp = features
|
| 141 |
+
features_temp = filt_aug_(features_temp, db_range=filter_db_range, n_band=filter_bands,
|
| 142 |
+
min_bw=filter_minimum_bandwidth)
|
| 143 |
+
feature_list.append(features_temp)
|
| 144 |
+
return feature_list
|
| 145 |
+
elif n_transform == 1:
|
| 146 |
+
features = filt_aug_(features, db_range=filter_db_range, n_band=filter_bands,
|
| 147 |
+
min_bw=filter_minimum_bandwidth)
|
| 148 |
+
return [features, features]
|
| 149 |
+
else:
|
| 150 |
+
return [features, features]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def mixstyle(x, alpha=0.4, eps=1e-6):
|
| 154 |
+
batch_size = x.size(0)
|
| 155 |
+
|
| 156 |
+
# frequency-wise statistics
|
| 157 |
+
f_mu = x.mean(dim=3, keepdim=True)
|
| 158 |
+
f_var = x.var(dim=3, keepdim=True)
|
| 159 |
+
|
| 160 |
+
f_sig = (f_var + eps).sqrt() # compute instance standard deviation
|
| 161 |
+
f_mu, f_sig = f_mu.detach(), f_sig.detach() # block gradients
|
| 162 |
+
x_normed = (x - f_mu) / f_sig # normalize input
|
| 163 |
+
lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device, dtype=x.dtype) # sample instance-wise convex weights
|
| 164 |
+
lmda = torch.max(lmda, 1-lmda)
|
| 165 |
+
perm = torch.randperm(batch_size).to(x.device) # generate shuffling indices
|
| 166 |
+
f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm] # shuffling
|
| 167 |
+
mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda) # generate mixed mean
|
| 168 |
+
sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda) # generate mixed standard deviation
|
| 169 |
+
x = x_normed * sig_mix + mu_mix # denormalize input using the mixed statistics
|
| 170 |
+
return x
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class RandomResizeCrop(nn.Module):
|
| 174 |
+
"""Random Resize Crop block.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
virtual_crop_scale: Virtual crop area `(F ratio, T ratio)` in ratio to input size.
|
| 178 |
+
freq_scale: Random frequency range `(min, max)`.
|
| 179 |
+
time_scale: Random time frame range `(min, max)`.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(self, virtual_crop_scale=(1.0, 1.5), freq_scale=(0.6, 1.0), time_scale=(0.6, 1.5)):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.virtual_crop_scale = virtual_crop_scale
|
| 185 |
+
self.freq_scale = freq_scale
|
| 186 |
+
self.time_scale = time_scale
|
| 187 |
+
self.interpolation = 'bicubic'
|
| 188 |
+
assert time_scale[1] >= 1.0 and freq_scale[1] >= 1.0
|
| 189 |
+
|
| 190 |
+
@staticmethod
|
| 191 |
+
def get_params(virtual_crop_size, in_size, time_scale, freq_scale):
|
| 192 |
+
canvas_h, canvas_w = virtual_crop_size
|
| 193 |
+
src_h, src_w = in_size
|
| 194 |
+
h = np.clip(int(np.random.uniform(*freq_scale) * src_h), 1, canvas_h)
|
| 195 |
+
w = np.clip(int(np.random.uniform(*time_scale) * src_w), 1, canvas_w)
|
| 196 |
+
i = random.randint(0, canvas_h - h) if canvas_h > h else 0
|
| 197 |
+
j = random.randint(0, canvas_w - w) if canvas_w > w else 0
|
| 198 |
+
return i, j, h, w
|
| 199 |
+
|
| 200 |
+
def forward(self, lms):
|
| 201 |
+
# spec_output = []
|
| 202 |
+
# for lms in specs:
|
| 203 |
+
# lms = lms.unsqueeze(0)
|
| 204 |
+
# make virtual_crop_arear empty space (virtual crop area) and copy the input log mel spectrogram to th the center
|
| 205 |
+
virtual_crop_size = [int(s * c) for s, c in zip(lms.shape[-2:], self.virtual_crop_scale)]
|
| 206 |
+
virtual_crop_area = (torch.zeros((lms.shape[0], virtual_crop_size[0], virtual_crop_size[1]))
|
| 207 |
+
.to(torch.float).to(lms.device))
|
| 208 |
+
_, lh, lw = virtual_crop_area.shape
|
| 209 |
+
c, h, w = lms.shape
|
| 210 |
+
x, y = (lw - w) // 2, (lh - h) // 2
|
| 211 |
+
virtual_crop_area[:, y:y+h, x:x+w] = lms
|
| 212 |
+
# get random area
|
| 213 |
+
i, j, h, w = self.get_params(virtual_crop_area.shape[-2:], lms.shape[-2:], self.time_scale, self.freq_scale)
|
| 214 |
+
crop = virtual_crop_area[:, i:i+h, j:j+w]
|
| 215 |
+
# print(f'shapes {virtual_crop_area.shape} {crop.shape} -> {lms.shape}')
|
| 216 |
+
lms = F.interpolate(crop.unsqueeze(1), size=lms.shape[-2:],
|
| 217 |
+
mode=self.interpolation, align_corners=True).squeeze(1)
|
| 218 |
+
# spec_output.append(lms.float())
|
| 219 |
+
return lms.float() # torch.concat(lms, dim=0)
|
| 220 |
+
|
| 221 |
+
def __repr__(self):
|
| 222 |
+
format_string = self.__class__.__name__ + f'(virtual_crop_size={self.virtual_crop_scale}'
|
| 223 |
+
format_string += ', time_scale={0}'.format(tuple(round(s, 4) for s in self.time_scale))
|
| 224 |
+
format_string += ', freq_scale={0})'.format(tuple(round(r, 4) for r in self.freq_scale))
|
| 225 |
+
return format_string
|
helpers/decode.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code from:
|
| 3 |
+
https://github.com/DCASE-REPO/DESED_task
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import scipy
|
| 11 |
+
from sed_scores_eval.base_modules.scores import create_score_dataframe
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def batched_decode_preds(
|
| 15 |
+
strong_preds,
|
| 16 |
+
filenames,
|
| 17 |
+
encoder,
|
| 18 |
+
thresholds=[0.5],
|
| 19 |
+
median_filter=None,
|
| 20 |
+
pad_indx=None,
|
| 21 |
+
):
|
| 22 |
+
"""Decode a batch of predictions to dataframes. Each threshold gives a different dataframe and stored in a
|
| 23 |
+
dictionary
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
strong_preds: torch.Tensor, batch of strong predictions.
|
| 27 |
+
filenames: list, the list of filenames of the current batch.
|
| 28 |
+
encoder: ManyHotEncoder object, object used to decode predictions.
|
| 29 |
+
thresholds: list, the list of thresholds to be used for predictions.
|
| 30 |
+
median_filter: int, the number of frames for which to apply median window (smoothing).
|
| 31 |
+
pad_indx: list, the list of indexes which have been used for padding.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
dict of predictions, each keys is a threshold and the value is the DataFrame of predictions.
|
| 35 |
+
"""
|
| 36 |
+
# Init a dataframe per threshold
|
| 37 |
+
scores_raw = {}
|
| 38 |
+
scores_postprocessed = {}
|
| 39 |
+
prediction_dfs = {}
|
| 40 |
+
for threshold in thresholds:
|
| 41 |
+
prediction_dfs[threshold] = pd.DataFrame()
|
| 42 |
+
|
| 43 |
+
for j in range(strong_preds.shape[0]): # over batches
|
| 44 |
+
audio_id = Path(filenames[j]).stem
|
| 45 |
+
filename = audio_id + ".wav"
|
| 46 |
+
c_scores = strong_preds[j]
|
| 47 |
+
if pad_indx is not None:
|
| 48 |
+
true_len = int(c_scores.shape[-1] * pad_indx[j].item())
|
| 49 |
+
c_scores = c_scores[:true_len]
|
| 50 |
+
c_scores = c_scores.transpose(0, 1).detach().cpu().numpy()
|
| 51 |
+
scores_raw[audio_id] = create_score_dataframe(
|
| 52 |
+
scores=c_scores,
|
| 53 |
+
timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
|
| 54 |
+
event_classes=encoder.labels,
|
| 55 |
+
)
|
| 56 |
+
if median_filter is not None:
|
| 57 |
+
c_scores = scipy.ndimage.filters.median_filter(c_scores, (median_filter, 1))
|
| 58 |
+
scores_postprocessed[audio_id] = create_score_dataframe(
|
| 59 |
+
scores=c_scores,
|
| 60 |
+
timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
|
| 61 |
+
event_classes=encoder.labels,
|
| 62 |
+
)
|
| 63 |
+
for c_th in thresholds:
|
| 64 |
+
pred = c_scores > c_th
|
| 65 |
+
pred = encoder.decode_strong(pred)
|
| 66 |
+
pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
|
| 67 |
+
pred["filename"] = filename
|
| 68 |
+
prediction_dfs[c_th] = pd.concat(
|
| 69 |
+
[prediction_dfs[c_th], pred], ignore_index=True
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
return scores_raw, scores_postprocessed, prediction_dfs
|
helpers/encode.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code from:
|
| 3 |
+
https://github.com/DCASE-REPO/DESED_task
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from dcase_util.data import DecisionEncoder
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ManyHotEncoder:
|
| 14 |
+
""""
|
| 15 |
+
Adapted after DecisionEncoder.find_contiguous_regions method in
|
| 16 |
+
https://github.com/DCASE-REPO/dcase_util/blob/master/dcase_util/data/decisions.py
|
| 17 |
+
|
| 18 |
+
Encode labels into numpy arrays where 1 correspond to presence of the class and 0 absence.
|
| 19 |
+
Multiple 1 can appear on the same line, it is for multi label problem.
|
| 20 |
+
Args:
|
| 21 |
+
labels: list, the classes which will be encoded
|
| 22 |
+
n_frames: int, (Default value = None) only useful for strong labels. The number of frames of a segment.
|
| 23 |
+
Attributes:
|
| 24 |
+
labels: list, the classes which will be encoded
|
| 25 |
+
n_frames: int, only useful for strong labels. The number of frames of a segment.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self, labels, audio_len=10, frame_hop=160, net_pooling=4, fs=16000
|
| 30 |
+
):
|
| 31 |
+
if type(labels) in [np.ndarray, np.array]:
|
| 32 |
+
labels = labels.tolist()
|
| 33 |
+
elif isinstance(labels, (dict, OrderedDict)):
|
| 34 |
+
labels = list(labels.keys())
|
| 35 |
+
self.labels = labels
|
| 36 |
+
self.audio_len = audio_len
|
| 37 |
+
self.frame_hop = frame_hop
|
| 38 |
+
self.fs = fs
|
| 39 |
+
self.net_pooling = net_pooling
|
| 40 |
+
n_frames = self.audio_len * self.fs
|
| 41 |
+
self.n_frames = int(int((n_frames / self.frame_hop)) / self.net_pooling)
|
| 42 |
+
|
| 43 |
+
def encode_weak(self, labels):
|
| 44 |
+
""" Encode a list of weak labels into a numpy array
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
labels: list, list of labels to encode (to a vector of 0 and 1)
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
numpy.array
|
| 51 |
+
A vector containing 1 for each label, and 0 everywhere else
|
| 52 |
+
"""
|
| 53 |
+
# useful for tensor empty labels
|
| 54 |
+
if type(labels) is str:
|
| 55 |
+
if labels == "empty":
|
| 56 |
+
y = np.zeros(len(self.labels)) - 1
|
| 57 |
+
return y
|
| 58 |
+
else:
|
| 59 |
+
labels = labels.split(",")
|
| 60 |
+
if type(labels) is pd.DataFrame:
|
| 61 |
+
if labels.empty:
|
| 62 |
+
labels = []
|
| 63 |
+
elif "event_label" in labels.columns:
|
| 64 |
+
labels = labels["event_label"]
|
| 65 |
+
y = np.zeros(len(self.labels))
|
| 66 |
+
for label in labels:
|
| 67 |
+
if not pd.isna(label):
|
| 68 |
+
i = self.labels.index(label)
|
| 69 |
+
y[i] = 1
|
| 70 |
+
return y
|
| 71 |
+
|
| 72 |
+
def _time_to_frame(self, time):
|
| 73 |
+
samples = time * self.fs
|
| 74 |
+
frame = (samples) / self.frame_hop
|
| 75 |
+
return np.clip(frame / self.net_pooling, a_min=0, a_max=self.n_frames)
|
| 76 |
+
|
| 77 |
+
def _frame_to_time(self, frame):
|
| 78 |
+
frame = frame * self.net_pooling / (self.fs / self.frame_hop)
|
| 79 |
+
return np.clip(frame, a_min=0, a_max=self.audio_len)
|
| 80 |
+
|
| 81 |
+
def encode_strong_df(self, label_df):
|
| 82 |
+
"""Encode a list (or pandas Dataframe or Serie) of strong labels, they correspond to a given filename
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
label_df: pandas DataFrame or Series, contains filename, onset (in frames) and offset (in frames)
|
| 86 |
+
If only filename (no onset offset) is specified, it will return the event on all the frames
|
| 87 |
+
onset and offset should be in frames
|
| 88 |
+
Returns:
|
| 89 |
+
numpy.array
|
| 90 |
+
Encoded labels, 1 where the label is present, 0 otherwise
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
assert any(
|
| 94 |
+
[x is not None for x in [self.audio_len, self.frame_hop]]
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
samples_len = self.n_frames
|
| 98 |
+
if type(label_df) is str:
|
| 99 |
+
if label_df == "empty":
|
| 100 |
+
y = np.zeros((samples_len, len(self.labels))) - 1
|
| 101 |
+
return y
|
| 102 |
+
y = np.zeros((samples_len, len(self.labels)))
|
| 103 |
+
if type(label_df) is pd.DataFrame:
|
| 104 |
+
if {"onset", "offset", "event_label"}.issubset(label_df.columns):
|
| 105 |
+
for _, row in label_df.iterrows():
|
| 106 |
+
if not pd.isna(row["event_label"]):
|
| 107 |
+
i = self.labels.index(row["event_label"])
|
| 108 |
+
onset = int(self._time_to_frame(row["onset"]))
|
| 109 |
+
offset = int(np.ceil(self._time_to_frame(row["offset"])))
|
| 110 |
+
if "confidence" in label_df.columns:
|
| 111 |
+
y[onset:offset, i] = row["confidence"] # support confidence
|
| 112 |
+
else:
|
| 113 |
+
y[
|
| 114 |
+
onset:offset, i
|
| 115 |
+
] = 1 # means offset not included (hypothesis of overlapping frames, so ok)
|
| 116 |
+
|
| 117 |
+
elif type(label_df) in [
|
| 118 |
+
pd.Series,
|
| 119 |
+
list,
|
| 120 |
+
np.ndarray,
|
| 121 |
+
]: # list of list or list of strings
|
| 122 |
+
if type(label_df) is pd.Series:
|
| 123 |
+
if {"onset", "offset", "event_label"}.issubset(
|
| 124 |
+
label_df.index
|
| 125 |
+
): # means only one value
|
| 126 |
+
if not pd.isna(label_df["event_label"]):
|
| 127 |
+
i = self.labels.index(label_df["event_label"])
|
| 128 |
+
onset = int(self._time_to_frame(label_df["onset"]))
|
| 129 |
+
offset = int(np.ceil(self._time_to_frame(label_df["offset"])))
|
| 130 |
+
|
| 131 |
+
if "confidence" in label_df.columns:
|
| 132 |
+
y[onset:offset, i] = label_df["confidence"]
|
| 133 |
+
else:
|
| 134 |
+
y[onset:offset, i] = 1
|
| 135 |
+
return y
|
| 136 |
+
|
| 137 |
+
for event_label in label_df:
|
| 138 |
+
# List of string, so weak labels to be encoded in strong
|
| 139 |
+
if type(event_label) is str:
|
| 140 |
+
if event_label != "":
|
| 141 |
+
i = self.labels.index(event_label)
|
| 142 |
+
y[:, i] = 1
|
| 143 |
+
|
| 144 |
+
# List of list, with [label, onset, offset]
|
| 145 |
+
elif len(event_label) == 3:
|
| 146 |
+
if event_label[0] != "":
|
| 147 |
+
i = self.labels.index(event_label[0])
|
| 148 |
+
onset = int(self._time_to_frame(event_label[1]))
|
| 149 |
+
offset = int(np.ceil(self._time_to_frame(event_label[2])))
|
| 150 |
+
y[onset:offset, i] = 1
|
| 151 |
+
# List of list, with [label, onset, offset, confidence]
|
| 152 |
+
elif len(event_label) == 4:
|
| 153 |
+
if event_label[0] != "":
|
| 154 |
+
i = self.labels.index(event_label[0])
|
| 155 |
+
onset = int(self._time_to_frame(event_label[1]))
|
| 156 |
+
offset = int(np.ceil(self._time_to_frame(event_label[2])))
|
| 157 |
+
y[onset:offset, i] = event_label[3]
|
| 158 |
+
|
| 159 |
+
else:
|
| 160 |
+
raise NotImplementedError(
|
| 161 |
+
"cannot encode strong, type mismatch: {}".format(
|
| 162 |
+
type(event_label)
|
| 163 |
+
)
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
else:
|
| 167 |
+
raise NotImplementedError(
|
| 168 |
+
"To encode_strong, type is pandas.Dataframe with onset, offset and event_label"
|
| 169 |
+
"columns, or it is a list or pandas Series of event labels, "
|
| 170 |
+
"type given: {}".format(type(label_df))
|
| 171 |
+
)
|
| 172 |
+
return y
|
| 173 |
+
|
| 174 |
+
def decode_weak(self, labels):
|
| 175 |
+
""" Decode the encoded weak labels
|
| 176 |
+
Args:
|
| 177 |
+
labels: numpy.array, the encoded labels to be decoded
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
list
|
| 181 |
+
Decoded labels, list of string
|
| 182 |
+
|
| 183 |
+
"""
|
| 184 |
+
result_labels = []
|
| 185 |
+
for i, value in enumerate(labels):
|
| 186 |
+
if value == 1:
|
| 187 |
+
result_labels.append(self.labels[i])
|
| 188 |
+
return result_labels
|
| 189 |
+
|
| 190 |
+
def decode_strong(self, labels):
|
| 191 |
+
""" Decode the encoded strong labels
|
| 192 |
+
Args:
|
| 193 |
+
labels: numpy.array, the encoded labels to be decoded
|
| 194 |
+
Returns:
|
| 195 |
+
list
|
| 196 |
+
Decoded labels, list of list: [[label, onset offset], ...]
|
| 197 |
+
|
| 198 |
+
"""
|
| 199 |
+
result_labels = []
|
| 200 |
+
for i, label_column in enumerate(labels.T):
|
| 201 |
+
change_indices = DecisionEncoder().find_contiguous_regions(label_column)
|
| 202 |
+
|
| 203 |
+
# append [label, onset, offset] in the result list
|
| 204 |
+
for row in change_indices:
|
| 205 |
+
result_labels.append(
|
| 206 |
+
[
|
| 207 |
+
self.labels[i],
|
| 208 |
+
self._frame_to_time(row[0]),
|
| 209 |
+
self._frame_to_time(row[1]),
|
| 210 |
+
]
|
| 211 |
+
)
|
| 212 |
+
return result_labels
|
| 213 |
+
|
| 214 |
+
def state_dict(self):
|
| 215 |
+
return {
|
| 216 |
+
"labels": self.labels,
|
| 217 |
+
"audio_len": self.audio_len,
|
| 218 |
+
"frame_hop": self.frame_hop,
|
| 219 |
+
"net_pooling": self.net_pooling,
|
| 220 |
+
"fs": self.fs,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
@classmethod
|
| 224 |
+
def load_state_dict(cls, state_dict):
|
| 225 |
+
labels = state_dict["labels"]
|
| 226 |
+
audio_len = state_dict["audio_len"]
|
| 227 |
+
frame_hop = state_dict["frame_hop"]
|
| 228 |
+
net_pooling = state_dict["net_pooling"]
|
| 229 |
+
fs = state_dict["fs"]
|
| 230 |
+
return cls(labels, audio_len, frame_hop, net_pooling, fs)
|
helpers/score.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
score functions from: https://hearbenchmark.com/hear-tasks.html
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from collections import ChainMap
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, Optional, Tuple, Union, List, Any
|
| 9 |
+
|
| 10 |
+
import more_itertools
|
| 11 |
+
import numpy as np
|
| 12 |
+
import sed_eval
|
| 13 |
+
import torch
|
| 14 |
+
from dcase_util.containers import MetaDataContainer
|
| 15 |
+
from scipy.ndimage import median_filter
|
| 16 |
+
from sklearn.model_selection import ParameterGrid
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def validate_score_return_type(ret: Union[Tuple[Tuple[str, float], ...], float]):
|
| 21 |
+
"""
|
| 22 |
+
Valid return types for the metric are
|
| 23 |
+
- tuple(tuple(string: name of the subtype, float: the value)): This is the
|
| 24 |
+
case with sed eval metrics. They can return (("f_measure", value),
|
| 25 |
+
("precision", value), ...), depending on the scores
|
| 26 |
+
the metric should is supposed to return. This is set as `scores`
|
| 27 |
+
attribute in the metric.
|
| 28 |
+
- float: Standard metric behaviour
|
| 29 |
+
|
| 30 |
+
The downstream prediction pipeline is able to handle these two types.
|
| 31 |
+
In case of the tuple return type, the value of the first entry in the
|
| 32 |
+
tuple will be used as an optimisation criterion wherever required.
|
| 33 |
+
For instance, if the return is (("f_measure", value), ("precision", value)),
|
| 34 |
+
the value corresponding to the f_measure will be used ( for instance in
|
| 35 |
+
early stopping if this metric is the primary score for the task )
|
| 36 |
+
"""
|
| 37 |
+
if isinstance(ret, tuple):
|
| 38 |
+
assert all(
|
| 39 |
+
type(s) == tuple and type(s[0]) == str and type(s[1]) == float for s in ret
|
| 40 |
+
), (
|
| 41 |
+
"If the return type of the score is a tuple, all the elements "
|
| 42 |
+
"in the tuple should be tuple of type (string, float)"
|
| 43 |
+
)
|
| 44 |
+
elif isinstance(ret, float):
|
| 45 |
+
pass
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError(
|
| 48 |
+
f"Return type {type(ret)} is unexpected. Return type of "
|
| 49 |
+
"the score function should either be a "
|
| 50 |
+
"tuple(tuple) or float. "
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ScoreFunction:
|
| 55 |
+
"""
|
| 56 |
+
A simple abstract base class for score functions
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
# TODO: Remove label_to_idx?
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
label_to_idx: Dict[str, int],
|
| 63 |
+
name: Optional[str] = None,
|
| 64 |
+
maximize: bool = True,
|
| 65 |
+
):
|
| 66 |
+
"""
|
| 67 |
+
:param label_to_idx: Map from label string to integer index.
|
| 68 |
+
:param name: Override the name of this scoring function.
|
| 69 |
+
:param maximize: Maximize this score? (Otherwise, it's a loss or energy
|
| 70 |
+
we want to minimize, and I guess technically isn't a score.)
|
| 71 |
+
"""
|
| 72 |
+
self.label_to_idx = label_to_idx
|
| 73 |
+
if name:
|
| 74 |
+
self.name = name
|
| 75 |
+
self.maximize = maximize
|
| 76 |
+
|
| 77 |
+
def __call__(self, *args, **kwargs) -> Union[Tuple[Tuple[str, float], ...], float]:
|
| 78 |
+
"""
|
| 79 |
+
Calls the compute function of the metric, and after validating the output,
|
| 80 |
+
returns the metric score
|
| 81 |
+
"""
|
| 82 |
+
ret = self._compute(*args, **kwargs)
|
| 83 |
+
validate_score_return_type(ret)
|
| 84 |
+
return ret
|
| 85 |
+
|
| 86 |
+
def _compute(
|
| 87 |
+
self, predictions: Any, targets: Any, **kwargs
|
| 88 |
+
) -> Union[Tuple[Tuple[str, float], ...], float]:
|
| 89 |
+
"""
|
| 90 |
+
Compute the score based on the predictions and targets.
|
| 91 |
+
This is a private function and the metric should be used as a functor
|
| 92 |
+
by calling the `__call__` method which calls this and also validates
|
| 93 |
+
the return type
|
| 94 |
+
"""
|
| 95 |
+
raise NotImplementedError("Inheriting classes must implement this function")
|
| 96 |
+
|
| 97 |
+
def __str__(self):
|
| 98 |
+
return self.name
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class SoundEventScore(ScoreFunction):
|
| 102 |
+
"""
|
| 103 |
+
Scores for sound event detection tasks using sed_eval
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
# Score class must be defined in inheriting classes
|
| 107 |
+
score_class: sed_eval.sound_event.SoundEventMetrics = None
|
| 108 |
+
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
label_to_idx: Dict[str, int],
|
| 112 |
+
scores: Tuple[str],
|
| 113 |
+
params: Dict = None,
|
| 114 |
+
name: Optional[str] = None,
|
| 115 |
+
maximize: bool = True,
|
| 116 |
+
):
|
| 117 |
+
"""
|
| 118 |
+
:param scores: Scores to use, from the list of overall SED eval scores.
|
| 119 |
+
The first score in the tuple will be the primary score for this metric
|
| 120 |
+
:param params: Parameters to pass to the scoring function,
|
| 121 |
+
see inheriting children for details.
|
| 122 |
+
"""
|
| 123 |
+
if params is None:
|
| 124 |
+
params = {}
|
| 125 |
+
super().__init__(label_to_idx=label_to_idx, name=name, maximize=maximize)
|
| 126 |
+
self.scores = scores
|
| 127 |
+
self.params = params
|
| 128 |
+
assert self.score_class is not None
|
| 129 |
+
|
| 130 |
+
def _compute(
|
| 131 |
+
self, predictions: Dict, targets: Dict, **kwargs
|
| 132 |
+
) -> Tuple[Tuple[str, float], ...]:
|
| 133 |
+
# Containers of events for sed_eval
|
| 134 |
+
reference_event_list = self.sed_eval_event_container(targets)
|
| 135 |
+
estimated_event_list = self.sed_eval_event_container(predictions)
|
| 136 |
+
|
| 137 |
+
# This will break in Python < 3.6 if the dict order is not
|
| 138 |
+
# the insertion order I think. I'm a little worried about this line
|
| 139 |
+
scores = self.score_class(
|
| 140 |
+
event_label_list=list(self.label_to_idx.keys()), **self.params
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
for filename in predictions:
|
| 144 |
+
scores.evaluate(
|
| 145 |
+
reference_event_list=reference_event_list.filter(filename=filename),
|
| 146 |
+
estimated_event_list=estimated_event_list.filter(filename=filename),
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# results_overall_metrics return a pretty large nested selection of scores,
|
| 150 |
+
# with dicts of scores keyed on the type of scores, like f_measure, error_rate,
|
| 151 |
+
# accuracy
|
| 152 |
+
nested_overall_scores: Dict[
|
| 153 |
+
str, Dict[str, float]
|
| 154 |
+
] = scores.results_overall_metrics()
|
| 155 |
+
# Open up nested overall scores
|
| 156 |
+
overall_scores: Dict[str, float] = dict(
|
| 157 |
+
ChainMap(*nested_overall_scores.values())
|
| 158 |
+
)
|
| 159 |
+
# Return the required scores as tuples. The scores are returned in the
|
| 160 |
+
# order they are passed in the `scores` argument
|
| 161 |
+
return tuple([(score, overall_scores[score]) for score in self.scores])
|
| 162 |
+
|
| 163 |
+
@staticmethod
|
| 164 |
+
def sed_eval_event_container(
|
| 165 |
+
x: Dict[str, List[Dict[str, Any]]]
|
| 166 |
+
) -> MetaDataContainer:
|
| 167 |
+
# Reformat event list for sed_eval
|
| 168 |
+
reference_events = []
|
| 169 |
+
for filename, event_list in x.items():
|
| 170 |
+
for event in event_list:
|
| 171 |
+
reference_events.append(
|
| 172 |
+
{
|
| 173 |
+
# Convert from ms to seconds for sed_eval
|
| 174 |
+
"event_label": str(event["label"]),
|
| 175 |
+
"event_onset": event["start"] / 1000.0,
|
| 176 |
+
"event_offset": event["end"] / 1000.0,
|
| 177 |
+
"file": filename,
|
| 178 |
+
}
|
| 179 |
+
)
|
| 180 |
+
return MetaDataContainer(reference_events)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class EventBasedScore(SoundEventScore):
|
| 184 |
+
"""
|
| 185 |
+
event-based scores - the ground truth and system output are compared at
|
| 186 |
+
event instance level;
|
| 187 |
+
|
| 188 |
+
See https://tut-arg.github.io/sed_eval/generated/sed_eval.sound_event.EventBasedMetrics.html # noqa: E501
|
| 189 |
+
for params.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
score_class = sed_eval.sound_event.EventBasedMetrics
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class SegmentBasedScore(SoundEventScore):
|
| 196 |
+
"""
|
| 197 |
+
segment-based scores - the ground truth and system output are compared in a
|
| 198 |
+
fixed time grid; sound events are marked as active or inactive in each segment;
|
| 199 |
+
|
| 200 |
+
See https://tut-arg.github.io/sed_eval/sound_event.html#sed_eval.sound_event.SegmentBasedMetrics # noqa: E501
|
| 201 |
+
for params.
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
score_class = sed_eval.sound_event.SegmentBasedMetrics
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_events_for_all_files(
|
| 208 |
+
predictions: torch.Tensor,
|
| 209 |
+
filenames: List[str],
|
| 210 |
+
timestamps: torch.Tensor,
|
| 211 |
+
idx_to_label: Dict[int, str],
|
| 212 |
+
postprocessing_grid: Dict[str, List[float]],
|
| 213 |
+
postprocessing: Optional[Tuple[Tuple[str, Any], ...]] = None,
|
| 214 |
+
) -> Dict[Tuple[Tuple[str, Any], ...], Dict[str, List[Dict[str, Union[str, float]]]]]:
|
| 215 |
+
"""
|
| 216 |
+
Produces lists of events from a set of frame based label probabilities.
|
| 217 |
+
The input prediction tensor may contain frame predictions from a set of different
|
| 218 |
+
files concatenated together. file_timestamps has a list of filenames and
|
| 219 |
+
timestamps for each frame in the predictions tensor.
|
| 220 |
+
|
| 221 |
+
We split the predictions into separate tensors based on the filename and compute
|
| 222 |
+
events based on those individually.
|
| 223 |
+
|
| 224 |
+
If no postprocessing is specified (during training), we try a
|
| 225 |
+
variety of ways of postprocessing the predictions into events,
|
| 226 |
+
from the postprocessing_grid including median filtering and
|
| 227 |
+
minimum event length.
|
| 228 |
+
|
| 229 |
+
If postprocessing is specified (during test, chosen at the best
|
| 230 |
+
validation epoch), we use this postprocessing.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
predictions: a tensor of frame based multi-label predictions.
|
| 234 |
+
filenames: a list of filenames where each entry corresponds
|
| 235 |
+
to a frame in the predictions tensor.
|
| 236 |
+
timestamps: a list of timestamps where each entry corresponds
|
| 237 |
+
to a frame in the predictions tensor.
|
| 238 |
+
idx_to_label: Index to label mapping.
|
| 239 |
+
postprocessing: See above.
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
A dictionary from filtering params to the following values:
|
| 243 |
+
A dictionary of lists of events keyed on the filename slug.
|
| 244 |
+
The event list is of dicts of the following format:
|
| 245 |
+
{"label": str, "start": float ms, "end": float ms}
|
| 246 |
+
"""
|
| 247 |
+
# This probably could be more efficient if we make the assumption that
|
| 248 |
+
# timestamps are in sorted order. But this makes sure of it.
|
| 249 |
+
assert predictions.shape[0] == len(filenames)
|
| 250 |
+
assert predictions.shape[0] == len(timestamps)
|
| 251 |
+
event_files: Dict[str, Dict[float, torch.Tensor]] = {}
|
| 252 |
+
for i, (filename, timestamp) in enumerate(zip(filenames, timestamps)):
|
| 253 |
+
slug = Path(filename).name
|
| 254 |
+
|
| 255 |
+
# Key on the slug to be consistent with the ground truth
|
| 256 |
+
if slug not in event_files:
|
| 257 |
+
event_files[slug] = {}
|
| 258 |
+
|
| 259 |
+
# Save the predictions for the file keyed on the timestamp
|
| 260 |
+
event_files[slug][float(timestamp)] = predictions[i]
|
| 261 |
+
|
| 262 |
+
# Create events for all the different files. Store all the events as a dictionary
|
| 263 |
+
# with the same format as the ground truth from the luigi pipeline.
|
| 264 |
+
# Ex) { slug -> [{"label" : "woof", "start": 0.0, "end": 2.32}, ...], ...}
|
| 265 |
+
event_dict: Dict[
|
| 266 |
+
Tuple[Tuple[str, Any], ...], Dict[str, List[Dict[str, Union[float, str]]]]
|
| 267 |
+
] = {}
|
| 268 |
+
if postprocessing:
|
| 269 |
+
postprocess = postprocessing
|
| 270 |
+
event_dict[postprocess] = {}
|
| 271 |
+
for slug, timestamp_predictions in event_files.items():
|
| 272 |
+
event_dict[postprocess][slug] = create_events_from_prediction(
|
| 273 |
+
timestamp_predictions, idx_to_label, **dict(postprocess)
|
| 274 |
+
)
|
| 275 |
+
else:
|
| 276 |
+
postprocessing_confs = list(ParameterGrid(postprocessing_grid))
|
| 277 |
+
for postprocess_dict in tqdm(postprocessing_confs):
|
| 278 |
+
postprocess = tuple(postprocess_dict.items())
|
| 279 |
+
event_dict[postprocess] = {}
|
| 280 |
+
for slug, timestamp_predictions in event_files.items():
|
| 281 |
+
event_dict[postprocess][slug] = create_events_from_prediction(
|
| 282 |
+
timestamp_predictions, idx_to_label, **postprocess_dict
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
return event_dict
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def create_events_from_prediction(
|
| 289 |
+
prediction_dict: Dict[float, torch.Tensor],
|
| 290 |
+
idx_to_label: Dict[int, str],
|
| 291 |
+
threshold: float = 0.5,
|
| 292 |
+
median_filter_ms: float = 150,
|
| 293 |
+
min_duration: float = 60.0,
|
| 294 |
+
) -> List[Dict[str, Union[float, str]]]:
|
| 295 |
+
"""
|
| 296 |
+
Takes a set of prediction tensors keyed on timestamps and generates events.
|
| 297 |
+
(This is for one particular audio scene.)
|
| 298 |
+
We convert the prediction tensor to a binary label based on the threshold value. Any
|
| 299 |
+
events occurring at adjacent timestamps are considered to be part of the same event.
|
| 300 |
+
This loops through and creates events for each label class.
|
| 301 |
+
We optionally apply median filtering to predictions.
|
| 302 |
+
We disregard events that are less than the min_duration milliseconds.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
prediction_dict: A dictionary of predictions keyed on timestamp
|
| 306 |
+
{timestamp -> prediction}. The prediction is a tensor of label
|
| 307 |
+
probabilities.
|
| 308 |
+
idx_to_label: Index to label mapping.
|
| 309 |
+
threshold: Threshold for determining whether to apply a label
|
| 310 |
+
min_duration: the minimum duration in milliseconds for an
|
| 311 |
+
event to be included.
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
A list of dicts withs keys "label", "start", and "end"
|
| 315 |
+
"""
|
| 316 |
+
# Make sure the timestamps are in the correct order
|
| 317 |
+
timestamps = np.array(sorted(prediction_dict.keys()))
|
| 318 |
+
|
| 319 |
+
# Create a sorted numpy matrix of frame level predictions for this file. We convert
|
| 320 |
+
# to a numpy array here before applying a median filter.
|
| 321 |
+
predictions = np.stack(
|
| 322 |
+
[prediction_dict[t].detach().cpu().numpy() for t in timestamps]
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Optionally apply a median filter here to smooth out events.
|
| 326 |
+
ts_diff = np.mean(np.diff(timestamps))
|
| 327 |
+
if median_filter_ms:
|
| 328 |
+
filter_width = int(round(median_filter_ms / ts_diff))
|
| 329 |
+
if filter_width:
|
| 330 |
+
predictions = median_filter(predictions, size=(filter_width, 1))
|
| 331 |
+
|
| 332 |
+
# Convert probabilities to binary vectors based on threshold
|
| 333 |
+
predictions = (predictions > threshold).astype(np.int8)
|
| 334 |
+
|
| 335 |
+
events = []
|
| 336 |
+
for label in range(predictions.shape[1]):
|
| 337 |
+
for group in more_itertools.consecutive_groups(
|
| 338 |
+
np.where(predictions[:, label])[0]
|
| 339 |
+
):
|
| 340 |
+
grouptuple = tuple(group)
|
| 341 |
+
assert (
|
| 342 |
+
tuple(sorted(grouptuple)) == grouptuple
|
| 343 |
+
), f"{sorted(grouptuple)} != {grouptuple}"
|
| 344 |
+
startidx, endidx = (grouptuple[0], grouptuple[-1])
|
| 345 |
+
|
| 346 |
+
start = timestamps[startidx]
|
| 347 |
+
end = timestamps[endidx]
|
| 348 |
+
# Add event if greater than the minimum duration threshold
|
| 349 |
+
if end - start >= min_duration:
|
| 350 |
+
events.append(
|
| 351 |
+
{"label": idx_to_label[label], "start": start, "end": end}
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# This is just for pretty output, not really necessary
|
| 355 |
+
events.sort(key=lambda k: k["start"])
|
| 356 |
+
return events
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def combine_target_events(split_names: List[str], task_path):
|
| 360 |
+
"""
|
| 361 |
+
This combines the target events from the list of splits and
|
| 362 |
+
returns the combined target events. This is useful when combining
|
| 363 |
+
multiple folds of data to create the training or validation
|
| 364 |
+
dataloader. For example, in k-fold, the training data-loader
|
| 365 |
+
might be made from the first 4/5 folds, and calling this function
|
| 366 |
+
with [fold00, fold01, fold02, fold03] will return the
|
| 367 |
+
aggregated target events across all the folds
|
| 368 |
+
"""
|
| 369 |
+
combined_target_events: Dict = {}
|
| 370 |
+
for split_name in split_names:
|
| 371 |
+
target_events = json.load(
|
| 372 |
+
task_path.joinpath(f"{split_name}.json").open()
|
| 373 |
+
)
|
| 374 |
+
common_keys = set(combined_target_events.keys()).intersection(
|
| 375 |
+
target_events.keys()
|
| 376 |
+
)
|
| 377 |
+
assert len(common_keys) == 0, (
|
| 378 |
+
"Target events from one split should not override "
|
| 379 |
+
"target events from another. This is very unlikely as the "
|
| 380 |
+
"target_event is keyed on the files which are distinct for "
|
| 381 |
+
"each split"
|
| 382 |
+
)
|
| 383 |
+
combined_target_events.update(target_events)
|
| 384 |
+
return combined_target_events
|
helpers/utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def worker_init_fn(x):
|
| 7 |
+
seed = (torch.initial_seed() + x * 1000) % 2 ** 31 # problem with nearly seeded randoms
|
| 8 |
+
|
| 9 |
+
np.random.seed(seed)
|
| 10 |
+
random.seed(seed)
|
| 11 |
+
torch.manual_seed(seed)
|
| 12 |
+
return
|
images/downstream_task_results.png
ADDED
|
inference.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import librosa
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from data_util import audioset_classes
|
| 6 |
+
from helpers.decode import batched_decode_preds
|
| 7 |
+
from helpers.encode import ManyHotEncoder
|
| 8 |
+
from models.atstframe.ATSTF_wrapper import ATSTWrapper
|
| 9 |
+
from models.beats.BEATs_wrapper import BEATsWrapper
|
| 10 |
+
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
|
| 11 |
+
from models.m2d.M2D_wrapper import M2DWrapper
|
| 12 |
+
from models.asit.ASIT_wrapper import ASiTWrapper
|
| 13 |
+
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
|
| 14 |
+
from models.prediction_wrapper import PredictionsWrapper
|
| 15 |
+
from models.frame_mn.utils import NAME_TO_WIDTH
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def sound_event_detection(args):
|
| 19 |
+
"""
|
| 20 |
+
Running Sound Event Detection on an audio clip.
|
| 21 |
+
"""
|
| 22 |
+
device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
|
| 23 |
+
model_name = args.model_name
|
| 24 |
+
|
| 25 |
+
if model_name == "BEATs":
|
| 26 |
+
beats = BEATsWrapper()
|
| 27 |
+
model = PredictionsWrapper(beats, checkpoint="BEATs_strong_1")
|
| 28 |
+
elif model_name == "ATST-F":
|
| 29 |
+
atst = ATSTWrapper()
|
| 30 |
+
model = PredictionsWrapper(atst, checkpoint="ATST-F_strong_1")
|
| 31 |
+
elif model_name == "fpasst":
|
| 32 |
+
fpasst = FPaSSTWrapper()
|
| 33 |
+
model = PredictionsWrapper(fpasst, checkpoint="fpasst_strong_1")
|
| 34 |
+
elif model_name == "M2D":
|
| 35 |
+
m2d = M2DWrapper()
|
| 36 |
+
model = PredictionsWrapper(m2d, checkpoint="M2D_strong_1", embed_dim=m2d.m2d.cfg.feature_d)
|
| 37 |
+
elif model_name == "ASIT":
|
| 38 |
+
asit = ASiTWrapper()
|
| 39 |
+
model = PredictionsWrapper(asit, checkpoint="ASIT_strong_1")
|
| 40 |
+
elif model_name.startswith("frame_mn"):
|
| 41 |
+
width = NAME_TO_WIDTH(model_name)
|
| 42 |
+
frame_mn = FrameMNWrapper(width)
|
| 43 |
+
embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
|
| 44 |
+
model = PredictionsWrapper(frame_mn, checkpoint=f"{model_name}_strong_1", embed_dim=embed_dim)
|
| 45 |
+
else:
|
| 46 |
+
raise NotImplementedError(f"Model {model_name} not (yet) implemented")
|
| 47 |
+
|
| 48 |
+
model.eval()
|
| 49 |
+
model.to(device)
|
| 50 |
+
|
| 51 |
+
sample_rate = 16_000 # all our models are trained on 16 kHz audio
|
| 52 |
+
segment_duration = 10 # all models are trained on 10-second pieces
|
| 53 |
+
segment_samples = segment_duration * sample_rate
|
| 54 |
+
|
| 55 |
+
# load audio
|
| 56 |
+
(waveform, _) = librosa.core.load(args.audio_file, sr=sample_rate, mono=True)
|
| 57 |
+
waveform = torch.from_numpy(waveform[None, :]).to(device)
|
| 58 |
+
waveform_len = waveform.shape[1]
|
| 59 |
+
|
| 60 |
+
audio_len = waveform_len / sample_rate # in seconds
|
| 61 |
+
print("Audio length (seconds): ", audio_len)
|
| 62 |
+
|
| 63 |
+
# encoder manages decoding of model predictions into dataframes
|
| 64 |
+
# containing event labels, onsets and offsets
|
| 65 |
+
encoder = ManyHotEncoder(audioset_classes.as_strong_train_classes, audio_len=audio_len)
|
| 66 |
+
|
| 67 |
+
# split audio file into 10-second chunks
|
| 68 |
+
num_chunks = waveform_len // segment_samples + (waveform_len % segment_samples != 0)
|
| 69 |
+
all_predictions = []
|
| 70 |
+
|
| 71 |
+
# Process each 10-second chunk
|
| 72 |
+
for i in range(num_chunks):
|
| 73 |
+
start_idx = i * segment_samples
|
| 74 |
+
end_idx = min((i + 1) * segment_samples, waveform_len)
|
| 75 |
+
waveform_chunk = waveform[:, start_idx:end_idx]
|
| 76 |
+
|
| 77 |
+
# Pad the last chunk if it's shorter than 10 seconds
|
| 78 |
+
if waveform_chunk.shape[1] < segment_samples:
|
| 79 |
+
pad_size = segment_samples - waveform_chunk.shape[1]
|
| 80 |
+
waveform_chunk = torch.nn.functional.pad(waveform_chunk, (0, pad_size))
|
| 81 |
+
|
| 82 |
+
# Run inference for each chunk
|
| 83 |
+
with torch.no_grad():
|
| 84 |
+
mel = model.mel_forward(waveform_chunk)
|
| 85 |
+
y_strong, _ = model(mel)
|
| 86 |
+
|
| 87 |
+
# Collect predictions
|
| 88 |
+
all_predictions.append(y_strong)
|
| 89 |
+
|
| 90 |
+
# Concatenate all predictions along the time axis
|
| 91 |
+
y_strong = torch.cat(all_predictions, dim=2)
|
| 92 |
+
# convert into probabilities
|
| 93 |
+
y_strong = torch.sigmoid(y_strong)
|
| 94 |
+
|
| 95 |
+
(
|
| 96 |
+
scores_unprocessed,
|
| 97 |
+
scores_postprocessed,
|
| 98 |
+
decoded_predictions
|
| 99 |
+
) = batched_decode_preds(
|
| 100 |
+
y_strong.float(),
|
| 101 |
+
[args.audio_file],
|
| 102 |
+
encoder,
|
| 103 |
+
median_filter=args.median_window,
|
| 104 |
+
thresholds=args.detection_thresholds,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
for th in decoded_predictions:
|
| 108 |
+
print("***************************************")
|
| 109 |
+
print(f"Detected events using threshold {th}:")
|
| 110 |
+
print(decoded_predictions[th].sort_values(by="onset"))
|
| 111 |
+
print("***************************************")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
parser = argparse.ArgumentParser(description='Example of parser. ')
|
| 116 |
+
# model names: [BEATs, ASIT, ATST-F, fpasst, M2D]
|
| 117 |
+
parser.add_argument('--model_name', type=str, default='BEATs')
|
| 118 |
+
parser.add_argument('--audio_file', type=str,
|
| 119 |
+
default='test_files/752547__iscence__milan_metro_coming_in_station.wav')
|
| 120 |
+
parser.add_argument('--detection_thresholds', type=float, default=(0.1, 0.2, 0.5))
|
| 121 |
+
parser.add_argument('--median_window', type=float, default=9)
|
| 122 |
+
parser.add_argument('--cuda', action='store_true', default=False)
|
| 123 |
+
args = parser.parse_args()
|
| 124 |
+
|
| 125 |
+
assert args.model_name in ["BEATs", "ASIT", "ATST-F", "fpasst", "M2D"] or args.model_name.startswith("frame_mn")
|
| 126 |
+
sound_event_detection(args)
|
models/asit/ASIT_wrapper.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.asit.data_transformations import DataAugmentation
|
| 2 |
+
from models.asit.vision_transformer import vit_base
|
| 3 |
+
from models.transformer_wrapper import BaseModelWrapper
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ASiTWrapper(BaseModelWrapper):
|
| 7 |
+
def __init__(self) -> None:
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.asit_mel = DataAugmentation()
|
| 10 |
+
self.asit = vit_base(
|
| 11 |
+
patch_size=[16, 16],
|
| 12 |
+
audio_size=[128, 592],
|
| 13 |
+
stride=[16, 16],
|
| 14 |
+
in_chans=1,
|
| 15 |
+
num_classes=0
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
def mel_forward(self, x):
|
| 19 |
+
return self.asit_mel(x)
|
| 20 |
+
|
| 21 |
+
def forward(self, spec):
|
| 22 |
+
return self.asit(spec)
|
| 23 |
+
|
| 24 |
+
def separate_params(self):
|
| 25 |
+
pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
|
| 26 |
+
for k, p in self.named_parameters():
|
| 27 |
+
if any(['cls_token' in k,
|
| 28 |
+
'pos_embed' in k,
|
| 29 |
+
'norm_stats' in k,
|
| 30 |
+
'patch_embed' in k]):
|
| 31 |
+
pt_params[0].append(p)
|
| 32 |
+
elif 'blocks.0.' in k:
|
| 33 |
+
pt_params[0].append(p)
|
| 34 |
+
elif 'blocks.1.' in k:
|
| 35 |
+
pt_params[1].append(p)
|
| 36 |
+
elif 'blocks.2.' in k:
|
| 37 |
+
pt_params[2].append(p)
|
| 38 |
+
elif 'blocks.3.' in k:
|
| 39 |
+
pt_params[3].append(p)
|
| 40 |
+
elif 'blocks.4.' in k:
|
| 41 |
+
pt_params[4].append(p)
|
| 42 |
+
elif 'blocks.5.' in k:
|
| 43 |
+
pt_params[5].append(p)
|
| 44 |
+
elif 'blocks.6.' in k:
|
| 45 |
+
pt_params[6].append(p)
|
| 46 |
+
elif 'blocks.7.' in k:
|
| 47 |
+
pt_params[7].append(p)
|
| 48 |
+
elif 'blocks.8.' in k:
|
| 49 |
+
pt_params[8].append(p)
|
| 50 |
+
elif 'blocks.9.' in k:
|
| 51 |
+
pt_params[9].append(p)
|
| 52 |
+
elif 'blocks.10.' in k:
|
| 53 |
+
pt_params[10].append(p)
|
| 54 |
+
elif 'blocks.11.' in k:
|
| 55 |
+
pt_params[11].append(p)
|
| 56 |
+
elif 'asit.norm.weight' in k or 'asit.norm.bias' in k:
|
| 57 |
+
pt_params[11].append(p)
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"Check separate params for ASiT! Unknown key: {k}")
|
| 60 |
+
return list(reversed(pt_params))
|
models/asit/data_transformations.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional
|
| 3 |
+
import torchaudio
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DataAugmentation(object):
|
| 7 |
+
def __init__(self, data_mean=-4.2677393, data_std=4.5689974, num_mel_bins=128, sample_rate=16000):
|
| 8 |
+
self.data_mean = data_mean
|
| 9 |
+
self.data_std = data_std
|
| 10 |
+
self.num_mel_bins = num_mel_bins
|
| 11 |
+
self.sample_rate = sample_rate
|
| 12 |
+
|
| 13 |
+
def _wav2fbank(self, waveform):
|
| 14 |
+
waveform = (waveform - waveform.mean())
|
| 15 |
+
fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=self.sample_rate,
|
| 16 |
+
use_energy=False,
|
| 17 |
+
window_type='hanning', num_mel_bins=self.num_mel_bins, dither=0.0,
|
| 18 |
+
frame_shift=10)
|
| 19 |
+
return fbank
|
| 20 |
+
|
| 21 |
+
def convert_waveform(self, waveform):
|
| 22 |
+
w = self._wav2fbank(waveform)
|
| 23 |
+
fbank = (w - self.data_mean) / (self.data_std * 2)
|
| 24 |
+
fbank = fbank.unsqueeze(0)
|
| 25 |
+
return fbank
|
| 26 |
+
|
| 27 |
+
def __call__(self, batch):
|
| 28 |
+
# apply convert_waveform to each sample of the batch and return the result
|
| 29 |
+
return torch.stack([self.convert_waveform(sample.reshape(1, -1)) for sample in batch]).permute(0, 1, 3, 2)
|
models/asit/utils.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
warnings.filterwarnings("ignore")
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
import math
|
| 9 |
+
import random
|
| 10 |
+
import datetime
|
| 11 |
+
import subprocess
|
| 12 |
+
from collections import defaultdict, deque
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.distributed as dist
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
from numpy.random import randint
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def GMML_replace_list(samples, corrup_prev, masks_prev, drop_type='noise', max_replace=0.35, align=16):
|
| 24 |
+
rep_drop = 1 if drop_type == '' else (1 / (len(drop_type.split('-')) + 1))
|
| 25 |
+
|
| 26 |
+
n_imgs = samples.size()[0] # this is batch size, but in case bad inistance happened while loading
|
| 27 |
+
samples_aug = samples.detach().clone()
|
| 28 |
+
masks = torch.zeros_like(samples_aug)
|
| 29 |
+
for i in range(n_imgs):
|
| 30 |
+
idx_rnd = randint(0, n_imgs)
|
| 31 |
+
if random.random() < rep_drop:
|
| 32 |
+
samples_aug[i], masks[i] = GMML_drop_rand_patches(samples_aug[i], samples[idx_rnd], max_replace=max_replace,
|
| 33 |
+
align=align)
|
| 34 |
+
else:
|
| 35 |
+
samples_aug[i], masks[i] = corrup_prev[i], masks_prev[i]
|
| 36 |
+
|
| 37 |
+
return samples_aug, masks
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def GMML_drop_rand_patches(X, X_rep=None, drop_type='noise', max_replace=0.7, align=16, max_block_sz=0.3):
|
| 41 |
+
#######################
|
| 42 |
+
# max_replace: percentage of image to be replaced
|
| 43 |
+
# align: align corruption with the patch sizes
|
| 44 |
+
# max_block_sz: percentage of the maximum block to be dropped
|
| 45 |
+
#######################
|
| 46 |
+
|
| 47 |
+
np.random.seed()
|
| 48 |
+
C, H, W = X.size()
|
| 49 |
+
n_drop_pix = np.random.uniform(min(0.5, max_replace), max_replace) * H * W
|
| 50 |
+
mx_blk_height = int(H * max_block_sz)
|
| 51 |
+
mx_blk_width = int(W * max_block_sz)
|
| 52 |
+
|
| 53 |
+
align = max(1, align)
|
| 54 |
+
|
| 55 |
+
mask = torch.zeros_like(X)
|
| 56 |
+
drop_t = np.random.choice(drop_type.split('-'))
|
| 57 |
+
|
| 58 |
+
while mask[0].sum() < n_drop_pix:
|
| 59 |
+
|
| 60 |
+
####### get a random block to replace
|
| 61 |
+
rnd_r = (randint(0, H - align) // align) * align
|
| 62 |
+
rnd_c = (randint(0, W - align) // align) * align
|
| 63 |
+
|
| 64 |
+
rnd_h = min(randint(align, mx_blk_height), H - rnd_r)
|
| 65 |
+
rnd_h = round(rnd_h / align) * align
|
| 66 |
+
rnd_w = min(randint(align, mx_blk_width), W - rnd_c)
|
| 67 |
+
rnd_w = round(rnd_w / align) * align
|
| 68 |
+
|
| 69 |
+
if X_rep is not None:
|
| 70 |
+
X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = X_rep[:, rnd_r:rnd_r + rnd_h,
|
| 71 |
+
rnd_c:rnd_c + rnd_w].detach().clone()
|
| 72 |
+
else:
|
| 73 |
+
if drop_t == 'noise':
|
| 74 |
+
X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = torch.empty((C, rnd_h, rnd_w), dtype=X.dtype,
|
| 75 |
+
device=X.device).normal_()
|
| 76 |
+
elif drop_t == 'zeros':
|
| 77 |
+
X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = torch.zeros((C, rnd_h, rnd_w), dtype=X.dtype,
|
| 78 |
+
device=X.device)
|
| 79 |
+
else:
|
| 80 |
+
####### get a random block to replace from
|
| 81 |
+
rnd_r2 = (randint(0, H - rnd_h) // align) * align
|
| 82 |
+
rnd_c2 = (randint(0, W - rnd_w) // align) * align
|
| 83 |
+
|
| 84 |
+
X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = X[:, rnd_r2:rnd_r2 + rnd_h,
|
| 85 |
+
rnd_c2:rnd_c2 + rnd_w].detach().clone()
|
| 86 |
+
|
| 87 |
+
mask[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = 1
|
| 88 |
+
|
| 89 |
+
return X, mask
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class collate_batch(object): # replace from other images
|
| 93 |
+
def __init__(self, drop_replace=0., drop_align=1):
|
| 94 |
+
self.drop_replace = drop_replace
|
| 95 |
+
self.drop_align = drop_align
|
| 96 |
+
|
| 97 |
+
def __call__(self, batch):
|
| 98 |
+
batch = torch.utils.data.dataloader.default_collate(batch)
|
| 99 |
+
|
| 100 |
+
if self.drop_replace > 0:
|
| 101 |
+
batch[0][1][0], batch[0][2][0] = GMML_replace_list(batch[0][0][0], batch[0][1][0], batch[0][2][0],
|
| 102 |
+
max_replace=self.drop_replace, align=self.drop_align)
|
| 103 |
+
batch[0][1][1], batch[0][2][1] = GMML_replace_list(batch[0][0][1], batch[0][1][1], batch[0][2][1],
|
| 104 |
+
max_replace=self.drop_replace, align=self.drop_align)
|
| 105 |
+
|
| 106 |
+
return batch
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def clip_gradients(model, clip):
|
| 110 |
+
norms = []
|
| 111 |
+
for name, p in model.named_parameters():
|
| 112 |
+
if p.grad is not None:
|
| 113 |
+
param_norm = p.grad.data.norm(2)
|
| 114 |
+
norms.append(param_norm.item())
|
| 115 |
+
clip_coef = clip / (param_norm + 1e-6)
|
| 116 |
+
if clip_coef < 1:
|
| 117 |
+
p.grad.data.mul_(clip_coef)
|
| 118 |
+
return norms
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
| 122 |
+
if epoch >= freeze_last_layer:
|
| 123 |
+
return
|
| 124 |
+
for n, p in model.named_parameters():
|
| 125 |
+
if "last_layer" in n:
|
| 126 |
+
p.grad = None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
|
| 130 |
+
"""
|
| 131 |
+
Re-start from checkpoint
|
| 132 |
+
"""
|
| 133 |
+
if not os.path.isfile(ckp_path):
|
| 134 |
+
return
|
| 135 |
+
print("Found checkpoint at {}".format(ckp_path))
|
| 136 |
+
|
| 137 |
+
# open checkpoint file
|
| 138 |
+
checkpoint = torch.load(ckp_path, map_location="cpu")
|
| 139 |
+
|
| 140 |
+
# key is what to look for in the checkpoint file
|
| 141 |
+
# value is the object to load
|
| 142 |
+
# example: {'state_dict': model}
|
| 143 |
+
for key, value in kwargs.items():
|
| 144 |
+
if key in checkpoint and value is not None:
|
| 145 |
+
try:
|
| 146 |
+
msg = value.load_state_dict(checkpoint[key], strict=False)
|
| 147 |
+
print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
|
| 148 |
+
except TypeError:
|
| 149 |
+
try:
|
| 150 |
+
msg = value.load_state_dict(checkpoint[key])
|
| 151 |
+
print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
|
| 152 |
+
except ValueError:
|
| 153 |
+
print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
|
| 154 |
+
else:
|
| 155 |
+
print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
|
| 156 |
+
|
| 157 |
+
# re load variable important for the run
|
| 158 |
+
if run_variables is not None:
|
| 159 |
+
for var_name in run_variables:
|
| 160 |
+
if var_name in checkpoint:
|
| 161 |
+
run_variables[var_name] = checkpoint[var_name]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
|
| 165 |
+
warmup_schedule = np.array([])
|
| 166 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
| 167 |
+
if warmup_epochs > 0:
|
| 168 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
| 169 |
+
|
| 170 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
| 171 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
|
| 172 |
+
|
| 173 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
| 174 |
+
assert len(schedule) == epochs * niter_per_ep
|
| 175 |
+
return schedule
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def bool_flag(s):
|
| 179 |
+
"""
|
| 180 |
+
Parse boolean arguments from the command line.
|
| 181 |
+
"""
|
| 182 |
+
FALSY_STRINGS = {"off", "false", "0"}
|
| 183 |
+
TRUTHY_STRINGS = {"on", "true", "1"}
|
| 184 |
+
if s.lower() in FALSY_STRINGS:
|
| 185 |
+
return False
|
| 186 |
+
elif s.lower() in TRUTHY_STRINGS:
|
| 187 |
+
return True
|
| 188 |
+
else:
|
| 189 |
+
raise argparse.ArgumentTypeError("invalid value for a boolean flag")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def fix_random_seeds(seed=31):
|
| 193 |
+
"""
|
| 194 |
+
Fix random seeds.
|
| 195 |
+
"""
|
| 196 |
+
torch.manual_seed(seed)
|
| 197 |
+
torch.cuda.manual_seed_all(seed)
|
| 198 |
+
np.random.seed(seed)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class SmoothedValue(object):
|
| 202 |
+
"""Track a series of values and provide access to smoothed values over a
|
| 203 |
+
window or the global series average.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
def __init__(self, window_size=20, fmt=None):
|
| 207 |
+
if fmt is None:
|
| 208 |
+
fmt = "{median:.6f} ({global_avg:.6f})"
|
| 209 |
+
self.deque = deque(maxlen=window_size)
|
| 210 |
+
self.total = 0.0
|
| 211 |
+
self.count = 0
|
| 212 |
+
self.fmt = fmt
|
| 213 |
+
|
| 214 |
+
def update(self, value, n=1):
|
| 215 |
+
self.deque.append(value)
|
| 216 |
+
self.count += n
|
| 217 |
+
self.total += value * n
|
| 218 |
+
|
| 219 |
+
def synchronize_between_processes(self):
|
| 220 |
+
"""
|
| 221 |
+
Warning: does not synchronize the deque!
|
| 222 |
+
"""
|
| 223 |
+
if not is_dist_avail_and_initialized():
|
| 224 |
+
return
|
| 225 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
| 226 |
+
dist.barrier()
|
| 227 |
+
dist.all_reduce(t)
|
| 228 |
+
t = t.tolist()
|
| 229 |
+
self.count = int(t[0])
|
| 230 |
+
self.total = t[1]
|
| 231 |
+
|
| 232 |
+
@property
|
| 233 |
+
def median(self):
|
| 234 |
+
d = torch.tensor(list(self.deque))
|
| 235 |
+
return d.median().item()
|
| 236 |
+
|
| 237 |
+
@property
|
| 238 |
+
def avg(self):
|
| 239 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
| 240 |
+
return d.mean().item()
|
| 241 |
+
|
| 242 |
+
@property
|
| 243 |
+
def global_avg(self):
|
| 244 |
+
return self.total / self.count
|
| 245 |
+
|
| 246 |
+
@property
|
| 247 |
+
def max(self):
|
| 248 |
+
return max(self.deque)
|
| 249 |
+
|
| 250 |
+
@property
|
| 251 |
+
def value(self):
|
| 252 |
+
return self.deque[-1]
|
| 253 |
+
|
| 254 |
+
def __str__(self):
|
| 255 |
+
return self.fmt.format(
|
| 256 |
+
median=self.median,
|
| 257 |
+
avg=self.avg,
|
| 258 |
+
global_avg=self.global_avg,
|
| 259 |
+
max=self.max,
|
| 260 |
+
value=self.value)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def reduce_dict(input_dict, average=True):
|
| 264 |
+
"""
|
| 265 |
+
Args:
|
| 266 |
+
input_dict (dict): all the values will be reduced
|
| 267 |
+
average (bool): whether to do average or sum
|
| 268 |
+
Reduce the values in the dictionary from all processes so that all processes
|
| 269 |
+
have the averaged results. Returns a dict with the same fields as
|
| 270 |
+
input_dict, after reduction.
|
| 271 |
+
"""
|
| 272 |
+
world_size = get_world_size()
|
| 273 |
+
if world_size < 2:
|
| 274 |
+
return input_dict
|
| 275 |
+
with torch.no_grad():
|
| 276 |
+
names = []
|
| 277 |
+
values = []
|
| 278 |
+
# sort the keys so that they are consistent across processes
|
| 279 |
+
for k in sorted(input_dict.keys()):
|
| 280 |
+
names.append(k)
|
| 281 |
+
values.append(input_dict[k])
|
| 282 |
+
values = torch.stack(values, dim=0)
|
| 283 |
+
dist.all_reduce(values)
|
| 284 |
+
if average:
|
| 285 |
+
values /= world_size
|
| 286 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
| 287 |
+
return reduced_dict
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class MetricLogger(object):
|
| 291 |
+
def __init__(self, delimiter="\t"):
|
| 292 |
+
self.meters = defaultdict(SmoothedValue)
|
| 293 |
+
self.delimiter = delimiter
|
| 294 |
+
|
| 295 |
+
def update(self, **kwargs):
|
| 296 |
+
for k, v in kwargs.items():
|
| 297 |
+
if isinstance(v, torch.Tensor):
|
| 298 |
+
v = v.item()
|
| 299 |
+
assert isinstance(v, (float, int))
|
| 300 |
+
self.meters[k].update(v)
|
| 301 |
+
|
| 302 |
+
def __getattr__(self, attr):
|
| 303 |
+
if attr in self.meters:
|
| 304 |
+
return self.meters[attr]
|
| 305 |
+
if attr in self.__dict__:
|
| 306 |
+
return self.__dict__[attr]
|
| 307 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
| 308 |
+
type(self).__name__, attr))
|
| 309 |
+
|
| 310 |
+
def __str__(self):
|
| 311 |
+
loss_str = []
|
| 312 |
+
for name, meter in self.meters.items():
|
| 313 |
+
loss_str.append(
|
| 314 |
+
"{}: {}".format(name, str(meter))
|
| 315 |
+
)
|
| 316 |
+
return self.delimiter.join(loss_str)
|
| 317 |
+
|
| 318 |
+
def synchronize_between_processes(self):
|
| 319 |
+
for meter in self.meters.values():
|
| 320 |
+
meter.synchronize_between_processes()
|
| 321 |
+
|
| 322 |
+
def add_meter(self, name, meter):
|
| 323 |
+
self.meters[name] = meter
|
| 324 |
+
|
| 325 |
+
def log_every(self, iterable, print_freq, header=None):
|
| 326 |
+
i = 0
|
| 327 |
+
if not header:
|
| 328 |
+
header = ''
|
| 329 |
+
start_time = time.time()
|
| 330 |
+
end = time.time()
|
| 331 |
+
iter_time = SmoothedValue(fmt='{avg:.6f}')
|
| 332 |
+
data_time = SmoothedValue(fmt='{avg:.6f}')
|
| 333 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
| 334 |
+
if torch.cuda.is_available():
|
| 335 |
+
log_msg = self.delimiter.join([
|
| 336 |
+
header,
|
| 337 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 338 |
+
'eta: {eta}',
|
| 339 |
+
'{meters}',
|
| 340 |
+
'time: {time}',
|
| 341 |
+
'data: {data}',
|
| 342 |
+
'max mem: {memory:.0f}'
|
| 343 |
+
])
|
| 344 |
+
else:
|
| 345 |
+
log_msg = self.delimiter.join([
|
| 346 |
+
header,
|
| 347 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 348 |
+
'eta: {eta}',
|
| 349 |
+
'{meters}',
|
| 350 |
+
'time: {time}',
|
| 351 |
+
'data: {data}'
|
| 352 |
+
])
|
| 353 |
+
MB = 1024.0 * 1024.0
|
| 354 |
+
for obj in iterable:
|
| 355 |
+
data_time.update(time.time() - end)
|
| 356 |
+
yield obj
|
| 357 |
+
iter_time.update(time.time() - end)
|
| 358 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
| 359 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
| 360 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
| 361 |
+
if torch.cuda.is_available():
|
| 362 |
+
print(log_msg.format(
|
| 363 |
+
i, len(iterable), eta=eta_string,
|
| 364 |
+
meters=str(self),
|
| 365 |
+
time=str(iter_time), data=str(data_time),
|
| 366 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
| 367 |
+
else:
|
| 368 |
+
print(log_msg.format(
|
| 369 |
+
i, len(iterable), eta=eta_string,
|
| 370 |
+
meters=str(self),
|
| 371 |
+
time=str(iter_time), data=str(data_time)))
|
| 372 |
+
i += 1
|
| 373 |
+
end = time.time()
|
| 374 |
+
total_time = time.time() - start_time
|
| 375 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 376 |
+
print('{} Total time: {} ({:.6f} s / it)'.format(
|
| 377 |
+
header, total_time_str, total_time / len(iterable)))
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def get_sha():
|
| 381 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
| 382 |
+
|
| 383 |
+
def _run(command):
|
| 384 |
+
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
| 385 |
+
|
| 386 |
+
sha = 'N/A'
|
| 387 |
+
diff = "clean"
|
| 388 |
+
branch = 'N/A'
|
| 389 |
+
try:
|
| 390 |
+
sha = _run(['git', 'rev-parse', 'HEAD'])
|
| 391 |
+
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
| 392 |
+
diff = _run(['git', 'diff-index', 'HEAD'])
|
| 393 |
+
diff = "has uncommited changes" if diff else "clean"
|
| 394 |
+
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
| 395 |
+
except Exception:
|
| 396 |
+
pass
|
| 397 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
| 398 |
+
return message
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def is_dist_avail_and_initialized():
|
| 402 |
+
if not dist.is_available():
|
| 403 |
+
return False
|
| 404 |
+
if not dist.is_initialized():
|
| 405 |
+
return False
|
| 406 |
+
return True
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def get_world_size():
|
| 410 |
+
if not is_dist_avail_and_initialized():
|
| 411 |
+
return 1
|
| 412 |
+
return dist.get_world_size()
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def get_rank():
|
| 416 |
+
if not is_dist_avail_and_initialized():
|
| 417 |
+
return 0
|
| 418 |
+
return dist.get_rank()
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def is_main_process():
|
| 422 |
+
return get_rank() == 0
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def save_on_master(*args, **kwargs):
|
| 426 |
+
if is_main_process():
|
| 427 |
+
torch.save(*args, **kwargs)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def setup_for_distributed(is_master):
|
| 431 |
+
"""
|
| 432 |
+
This function disables printing when not in master process
|
| 433 |
+
"""
|
| 434 |
+
import builtins as __builtin__
|
| 435 |
+
builtin_print = __builtin__.print
|
| 436 |
+
|
| 437 |
+
def print(*args, **kwargs):
|
| 438 |
+
force = kwargs.pop('force', False)
|
| 439 |
+
if is_master or force:
|
| 440 |
+
builtin_print(*args, **kwargs)
|
| 441 |
+
|
| 442 |
+
__builtin__.print = print
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def init_distributed_mode(args):
|
| 446 |
+
# launched with torch.distributed.launch
|
| 447 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
| 448 |
+
args.rank = int(os.environ["RANK"])
|
| 449 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
| 450 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
| 451 |
+
# launched with submitit on a slurm cluster
|
| 452 |
+
elif 'SLURM_PROCID' in os.environ:
|
| 453 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
| 454 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
| 455 |
+
elif torch.cuda.is_available():
|
| 456 |
+
print('Will run the code on one GPU.')
|
| 457 |
+
args.rank, args.gpu, args.world_size = 0, 0, 1
|
| 458 |
+
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
| 459 |
+
os.environ['MASTER_PORT'] = '29500'
|
| 460 |
+
else:
|
| 461 |
+
print('Does not support training without GPU.')
|
| 462 |
+
sys.exit(1)
|
| 463 |
+
|
| 464 |
+
args.distributed = True
|
| 465 |
+
dist.init_process_group(
|
| 466 |
+
backend="nccl",
|
| 467 |
+
init_method=args.dist_url,
|
| 468 |
+
world_size=args.world_size,
|
| 469 |
+
rank=args.rank,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
torch.cuda.set_device(args.gpu)
|
| 473 |
+
print('| distributed init (rank {}): {}'.format(
|
| 474 |
+
args.rank, args.dist_url), flush=True)
|
| 475 |
+
dist.barrier()
|
| 476 |
+
setup_for_distributed(args.rank == 0)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def accuracy(output, target, topk=(1,)):
|
| 480 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
| 481 |
+
maxk = max(topk)
|
| 482 |
+
batch_size = target.size(0)
|
| 483 |
+
_, pred = output.topk(maxk, 1, True, True)
|
| 484 |
+
pred = pred.t()
|
| 485 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
| 486 |
+
return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 490 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 491 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 492 |
+
def norm_cdf(x):
|
| 493 |
+
# Computes standard normal cumulative distribution function
|
| 494 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 495 |
+
|
| 496 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 497 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 498 |
+
"The distribution of values may be incorrect.",
|
| 499 |
+
stacklevel=2)
|
| 500 |
+
|
| 501 |
+
with torch.no_grad():
|
| 502 |
+
# Values are generated by using a truncated uniform distribution and
|
| 503 |
+
# then using the inverse CDF for the normal distribution.
|
| 504 |
+
# Get upper and lower cdf values
|
| 505 |
+
l = norm_cdf((a - mean) / std)
|
| 506 |
+
u = norm_cdf((b - mean) / std)
|
| 507 |
+
|
| 508 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 509 |
+
# [2l-1, 2u-1].
|
| 510 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 511 |
+
|
| 512 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 513 |
+
# standard normal
|
| 514 |
+
tensor.erfinv_()
|
| 515 |
+
|
| 516 |
+
# Transform to proper mean, std
|
| 517 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 518 |
+
tensor.add_(mean)
|
| 519 |
+
|
| 520 |
+
# Clamp to ensure it's in the proper range
|
| 521 |
+
tensor.clamp_(min=a, max=b)
|
| 522 |
+
return tensor
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 526 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def get_params_groups(model):
|
| 530 |
+
regularized = []
|
| 531 |
+
not_regularized = []
|
| 532 |
+
for name, param in model.named_parameters():
|
| 533 |
+
if not param.requires_grad:
|
| 534 |
+
continue
|
| 535 |
+
# we do not regularize biases nor Norm parameters
|
| 536 |
+
if name.endswith(".bias") or len(param.shape) == 1:
|
| 537 |
+
not_regularized.append(param)
|
| 538 |
+
else:
|
| 539 |
+
regularized.append(param)
|
| 540 |
+
return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
|
models/asit/vision_transformer.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from models.asit.utils import trunc_normal_
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
| 10 |
+
if drop_prob == 0. or not training:
|
| 11 |
+
return x
|
| 12 |
+
keep_prob = 1 - drop_prob
|
| 13 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 14 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 15 |
+
random_tensor.floor_() # binarize
|
| 16 |
+
output = x.div(keep_prob) * random_tensor
|
| 17 |
+
return output
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DropPath(nn.Module):
|
| 21 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, drop_prob=None):
|
| 25 |
+
super(DropPath, self).__init__()
|
| 26 |
+
self.drop_prob = drop_prob
|
| 27 |
+
|
| 28 |
+
def forward(self, x):
|
| 29 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Mlp(nn.Module):
|
| 33 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 34 |
+
super().__init__()
|
| 35 |
+
out_features = out_features or in_features
|
| 36 |
+
hidden_features = hidden_features or in_features
|
| 37 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 38 |
+
self.act = act_layer()
|
| 39 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 40 |
+
self.drop = nn.Dropout(drop)
|
| 41 |
+
|
| 42 |
+
def forward(self, x):
|
| 43 |
+
x = self.fc1(x)
|
| 44 |
+
x = self.act(x)
|
| 45 |
+
x = self.drop(x)
|
| 46 |
+
x = self.fc2(x)
|
| 47 |
+
x = self.drop(x)
|
| 48 |
+
return x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Attention(nn.Module):
|
| 52 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.num_heads = num_heads
|
| 55 |
+
head_dim = dim // num_heads
|
| 56 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 57 |
+
|
| 58 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 59 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 60 |
+
self.proj = nn.Linear(dim, dim)
|
| 61 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
B, N, C = x.shape
|
| 65 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 66 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 67 |
+
|
| 68 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 69 |
+
attn = attn.softmax(dim=-1)
|
| 70 |
+
attn = self.attn_drop(attn)
|
| 71 |
+
|
| 72 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 73 |
+
x = self.proj(x)
|
| 74 |
+
x = self.proj_drop(x)
|
| 75 |
+
return x, attn
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Block(nn.Module):
|
| 79 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 80 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.norm1 = norm_layer(dim)
|
| 83 |
+
self.attn = Attention(
|
| 84 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 85 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 86 |
+
self.norm2 = norm_layer(dim)
|
| 87 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 88 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 89 |
+
|
| 90 |
+
def forward(self, x, return_attention=False):
|
| 91 |
+
y, attn = self.attn(self.norm1(x))
|
| 92 |
+
if return_attention:
|
| 93 |
+
return attn
|
| 94 |
+
x = x + self.drop_path(y)
|
| 95 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 96 |
+
return x
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class PatchEmbed(nn.Module):
|
| 100 |
+
""" Image to Patch Embedding
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(self, img_size=[1024, 128], patch_size=[16, 16], in_chans=3, embed_dim=768):
|
| 104 |
+
super().__init__()
|
| 105 |
+
num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
|
| 106 |
+
self.img_size = img_size
|
| 107 |
+
self.patch_size = patch_size
|
| 108 |
+
self.num_patches = num_patches
|
| 109 |
+
|
| 110 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
B, C, H, W = x.shape
|
| 114 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
| 115 |
+
return x
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class VisionTransformer(nn.Module):
|
| 119 |
+
""" Vision Transformer """
|
| 120 |
+
|
| 121 |
+
def __init__(self, audio_size=[1024, 128], patch_size=[16, 16], in_chans=3, num_classes=0, embed_dim=768, depth=12,
|
| 122 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
| 123 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.num_features = self.embed_dim = embed_dim
|
| 126 |
+
self.audio_size = audio_size
|
| 127 |
+
self.patch_size = patch_size
|
| 128 |
+
self.patch_embed = PatchEmbed(
|
| 129 |
+
img_size=audio_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 130 |
+
num_patches = self.patch_embed.num_patches
|
| 131 |
+
|
| 132 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 133 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 134 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 135 |
+
|
| 136 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 137 |
+
self.blocks = nn.ModuleList([
|
| 138 |
+
Block(
|
| 139 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 140 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
| 141 |
+
for i in range(depth)])
|
| 142 |
+
self.norm = norm_layer(embed_dim)
|
| 143 |
+
|
| 144 |
+
# Classifier head
|
| 145 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 146 |
+
|
| 147 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 148 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 149 |
+
self.apply(self._init_weights)
|
| 150 |
+
|
| 151 |
+
def _init_weights(self, m):
|
| 152 |
+
if isinstance(m, nn.Linear):
|
| 153 |
+
trunc_normal_(m.weight, std=.02)
|
| 154 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 155 |
+
nn.init.constant_(m.bias, 0)
|
| 156 |
+
elif isinstance(m, nn.LayerNorm):
|
| 157 |
+
nn.init.constant_(m.bias, 0)
|
| 158 |
+
nn.init.constant_(m.weight, 1.0)
|
| 159 |
+
|
| 160 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 161 |
+
npatch = (w / 16) * (h / 16)
|
| 162 |
+
N = self.pos_embed.shape[1] - 1
|
| 163 |
+
if npatch == N:
|
| 164 |
+
return self.pos_embed
|
| 165 |
+
|
| 166 |
+
class_pos_embed = self.pos_embed[:, 0]
|
| 167 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
| 168 |
+
|
| 169 |
+
sz1 = w // self.patch_size[0]
|
| 170 |
+
sz2 = h // self.patch_size[0]
|
| 171 |
+
|
| 172 |
+
prev_sz1 = self.audio_size[0] // self.patch_size[0]
|
| 173 |
+
prev_sz2 = self.audio_size[1] // self.patch_size[1]
|
| 174 |
+
patch_pos_embed = torch.nn.functional.interpolate(
|
| 175 |
+
patch_pos_embed.transpose(1, 2).reshape(1, self.embed_dim, prev_sz1, prev_sz2), size=(sz1, sz2),
|
| 176 |
+
mode='bicubic', align_corners=False)
|
| 177 |
+
|
| 178 |
+
patch_pos_embed = patch_pos_embed.reshape(1, self.embed_dim, sz1 * sz2).transpose(1, 2)
|
| 179 |
+
|
| 180 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
| 181 |
+
|
| 182 |
+
def prepare_tokens(self, x):
|
| 183 |
+
B, nc, w, h = x.shape
|
| 184 |
+
x = self.patch_embed(x) # patch linear embedding
|
| 185 |
+
|
| 186 |
+
# add the [CLS] token to the embed patch tokens
|
| 187 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
| 188 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 189 |
+
|
| 190 |
+
# add positional encoding to each token
|
| 191 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 192 |
+
# x = x + self.pos_embed
|
| 193 |
+
return self.pos_drop(x)
|
| 194 |
+
|
| 195 |
+
def forward(self, x, classify=False):
|
| 196 |
+
x = x.permute(0, 1, 3, 2)
|
| 197 |
+
x = self.prepare_tokens(x)
|
| 198 |
+
for blk in self.blocks:
|
| 199 |
+
x = blk(x)
|
| 200 |
+
x = self.norm(x)
|
| 201 |
+
if classify == True:
|
| 202 |
+
return self.head(x[:, 0])
|
| 203 |
+
return x
|
| 204 |
+
|
| 205 |
+
def get_last_selfattention(self, x):
|
| 206 |
+
x = self.prepare_tokens(x)
|
| 207 |
+
for i, blk in enumerate(self.blocks):
|
| 208 |
+
if i < len(self.blocks) - 1:
|
| 209 |
+
x = blk(x)
|
| 210 |
+
else:
|
| 211 |
+
# return attention of the last block
|
| 212 |
+
return blk(x, return_attention=True)
|
| 213 |
+
|
| 214 |
+
def get_intermediate_layers(self, x, n=1):
|
| 215 |
+
x = self.prepare_tokens(x)
|
| 216 |
+
# we return the output tokens from the `n` last blocks
|
| 217 |
+
output = []
|
| 218 |
+
for i, blk in enumerate(self.blocks):
|
| 219 |
+
x = blk(x)
|
| 220 |
+
if len(self.blocks) - i <= n:
|
| 221 |
+
output.append(self.norm(x))
|
| 222 |
+
return output
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def vit_tiny(patch_size=16, **kwargs):
|
| 226 |
+
model = VisionTransformer(
|
| 227 |
+
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
| 228 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 229 |
+
return model
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def vit_small(patch_size=[16, 16], audio_size=[1024, 128], stride=[16, 16], **kwargs):
|
| 233 |
+
model = VisionTransformer(
|
| 234 |
+
patch_size=patch_size, audio_size=audio_size, stride=stride, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
| 235 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 236 |
+
return model
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def vit_base(patch_size=[16, 16], audio_size=[1024, 128], stride=[16, 16], **kwargs):
|
| 240 |
+
model = VisionTransformer(
|
| 241 |
+
patch_size=patch_size, audio_size=audio_size, stride=stride, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
| 242 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 243 |
+
return model
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class CLSHead(nn.Module):
|
| 247 |
+
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
|
| 248 |
+
bottleneck_dim=256):
|
| 249 |
+
super().__init__()
|
| 250 |
+
nlayers = max(nlayers, 1)
|
| 251 |
+
if nlayers == 1:
|
| 252 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
| 253 |
+
else:
|
| 254 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
| 255 |
+
if use_bn:
|
| 256 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 257 |
+
layers.append(nn.GELU())
|
| 258 |
+
for _ in range(nlayers - 2):
|
| 259 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
| 260 |
+
if use_bn:
|
| 261 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 262 |
+
layers.append(nn.GELU())
|
| 263 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
| 264 |
+
self.mlp = nn.Sequential(*layers)
|
| 265 |
+
self.apply(self._init_weights)
|
| 266 |
+
self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)
|
| 267 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
| 268 |
+
self.last_layer.weight_g.data.fill_(1)
|
| 269 |
+
|
| 270 |
+
def _init_weights(self, m):
|
| 271 |
+
if isinstance(m, nn.Linear):
|
| 272 |
+
trunc_normal_(m.weight, std=.02)
|
| 273 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 274 |
+
nn.init.constant_(m.bias, 0)
|
| 275 |
+
|
| 276 |
+
def forward(self, x):
|
| 277 |
+
x = self.mlp(x)
|
| 278 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
| 279 |
+
return self.last_layer(x)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class RECHead(nn.Module):
|
| 283 |
+
def __init__(self, in_dim, audio_size, in_chans=3, patch_size=16):
|
| 284 |
+
super().__init__()
|
| 285 |
+
|
| 286 |
+
self.audio_size = audio_size
|
| 287 |
+
self.patch_size = patch_size
|
| 288 |
+
|
| 289 |
+
layers = [nn.Linear(in_dim, in_dim)]
|
| 290 |
+
layers.append(nn.GELU())
|
| 291 |
+
layers.append(nn.Linear(in_dim, in_dim))
|
| 292 |
+
layers.append(nn.GELU())
|
| 293 |
+
layers.append(nn.Linear(in_dim, in_dim))
|
| 294 |
+
layers.append(nn.GELU())
|
| 295 |
+
|
| 296 |
+
self.mlp = nn.Sequential(*layers)
|
| 297 |
+
self.apply(self._init_weights)
|
| 298 |
+
|
| 299 |
+
self.convTrans = nn.ConvTranspose2d(in_dim, in_chans, kernel_size=(patch_size, patch_size),
|
| 300 |
+
stride=(patch_size, patch_size))
|
| 301 |
+
|
| 302 |
+
def _init_weights(self, m):
|
| 303 |
+
if isinstance(m, nn.Linear):
|
| 304 |
+
trunc_normal_(m.weight, std=.02)
|
| 305 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 306 |
+
nn.init.constant_(m.bias, 0)
|
| 307 |
+
|
| 308 |
+
def forward(self, x):
|
| 309 |
+
x = self.mlp(x)
|
| 310 |
+
|
| 311 |
+
x_rec = x.transpose(1, 2)
|
| 312 |
+
out_sz = (self.audio_size[0] // self.patch_size, self.audio_size[
|
| 313 |
+
1] // self.patch_size) # tuple( ( int(math.sqrt(x_rec.size()[2])) , int(math.sqrt(x_rec.size()[2])) ) )
|
| 314 |
+
x_rec = self.convTrans(x_rec.unflatten(2, out_sz))
|
| 315 |
+
|
| 316 |
+
return x_rec
|
models/atstframe/ATSTF_wrapper.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
|
| 3 |
+
|
| 4 |
+
from models.atstframe.audio_transformer import FrameASTModel
|
| 5 |
+
from models.transformer_wrapper import BaseModelWrapper
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ATSTWrapper(BaseModelWrapper):
|
| 9 |
+
def __init__(self, atst_dropout=0.0) -> None:
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.atst_mel = ATSTMel()
|
| 12 |
+
self.atst = FrameASTModel(atst_dropout=atst_dropout)
|
| 13 |
+
self.fake_length = torch.tensor([1001])
|
| 14 |
+
self.cls_embed = None
|
| 15 |
+
|
| 16 |
+
def mel_forward(self, x):
|
| 17 |
+
return self.atst_mel(x)
|
| 18 |
+
|
| 19 |
+
def forward(self, spec):
|
| 20 |
+
atst_x = self.atst.get_intermediate_layers(
|
| 21 |
+
spec,
|
| 22 |
+
self.fake_length.to(spec).repeat(len(spec)),
|
| 23 |
+
1,
|
| 24 |
+
scene=False
|
| 25 |
+
)
|
| 26 |
+
return atst_x
|
| 27 |
+
|
| 28 |
+
def separate_params(self):
|
| 29 |
+
pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
|
| 30 |
+
for k, p in self.named_parameters():
|
| 31 |
+
if k in ['atst.mask_embed', 'atst.pos_embed', 'atst.patch_embed.patch_embed.weight',
|
| 32 |
+
'atst.patch_embed.patch_embed.bias'] or "blocks.0." in k:
|
| 33 |
+
pt_params[0].append(p)
|
| 34 |
+
elif "blocks.1." in k:
|
| 35 |
+
pt_params[1].append(p)
|
| 36 |
+
elif "blocks.2." in k:
|
| 37 |
+
pt_params[2].append(p)
|
| 38 |
+
elif "blocks.3." in k:
|
| 39 |
+
pt_params[3].append(p)
|
| 40 |
+
elif "blocks.4." in k:
|
| 41 |
+
pt_params[4].append(p)
|
| 42 |
+
elif "blocks.5." in k:
|
| 43 |
+
pt_params[5].append(p)
|
| 44 |
+
elif "blocks.6." in k:
|
| 45 |
+
pt_params[6].append(p)
|
| 46 |
+
elif "blocks.7." in k:
|
| 47 |
+
pt_params[7].append(p)
|
| 48 |
+
elif "blocks.8" in k:
|
| 49 |
+
pt_params[8].append(p)
|
| 50 |
+
elif "blocks.9." in k:
|
| 51 |
+
pt_params[9].append(p)
|
| 52 |
+
elif "blocks.10." in k:
|
| 53 |
+
pt_params[10].append(p)
|
| 54 |
+
elif "blocks.11." in k or ".norm_frame." in k:
|
| 55 |
+
pt_params[11].append(p)
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(f"Check separate params for ATST! Unknown key: {k}")
|
| 58 |
+
return list(reversed(pt_params))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ATSTMel(torch.nn.Module):
|
| 62 |
+
def __init__(self) -> None:
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.mel_transform = MelSpectrogram(
|
| 65 |
+
16000,
|
| 66 |
+
f_min=60,
|
| 67 |
+
f_max=7800,
|
| 68 |
+
hop_length=160,
|
| 69 |
+
win_length=1024,
|
| 70 |
+
n_fft=1024,
|
| 71 |
+
n_mels=64
|
| 72 |
+
)
|
| 73 |
+
self.amp_to_db = AmplitudeToDB(stype="power", top_db=80)
|
| 74 |
+
self.scaler = MinMax(min=-79.6482, max=50.6842)
|
| 75 |
+
|
| 76 |
+
def amp2db(self, spec):
|
| 77 |
+
return self.amp_to_db(spec).clamp(min=-50, max=80)
|
| 78 |
+
|
| 79 |
+
def forward(self, audio):
|
| 80 |
+
with torch.autocast(device_type="cuda", enabled=False):
|
| 81 |
+
spec = self.mel_transform(audio)
|
| 82 |
+
spec = self.scaler(self.amp2db(spec))
|
| 83 |
+
spec = spec.unsqueeze(1)
|
| 84 |
+
return spec
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class CustomAudioTransform:
|
| 88 |
+
def __repr__(self):
|
| 89 |
+
return self.__class__.__name__ + '()'
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class MinMax(CustomAudioTransform):
|
| 93 |
+
def __init__(self, min, max):
|
| 94 |
+
self.min = min
|
| 95 |
+
self.max = max
|
| 96 |
+
|
| 97 |
+
def __call__(self, input):
|
| 98 |
+
if self.min is None:
|
| 99 |
+
min_ = torch.min(input)
|
| 100 |
+
max_ = torch.max(input)
|
| 101 |
+
else:
|
| 102 |
+
min_ = self.min
|
| 103 |
+
max_ = self.max
|
| 104 |
+
input = (input - min_) / (max_ - min_) * 2. - 1.
|
| 105 |
+
return input
|
models/atstframe/audio_transformer.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import warnings
|
| 3 |
+
from functools import partial
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from .transformer import Block
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 11 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 12 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 13 |
+
def norm_cdf(x):
|
| 14 |
+
# Computes standard normal cumulative distribution function
|
| 15 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 16 |
+
|
| 17 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 18 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 19 |
+
"The distribution of values may be incorrect.",
|
| 20 |
+
stacklevel=2)
|
| 21 |
+
|
| 22 |
+
with torch.no_grad():
|
| 23 |
+
# Values are generated by using a truncated uniform distribution and
|
| 24 |
+
# then using the inverse CDF for the normal distribution.
|
| 25 |
+
# Get upper and lower cdf values
|
| 26 |
+
l = norm_cdf((a - mean) / std)
|
| 27 |
+
u = norm_cdf((b - mean) / std)
|
| 28 |
+
|
| 29 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 30 |
+
# [2l-1, 2u-1].
|
| 31 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 32 |
+
|
| 33 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 34 |
+
# standard normal
|
| 35 |
+
tensor.erfinv_()
|
| 36 |
+
|
| 37 |
+
# Transform to proper mean, std
|
| 38 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 39 |
+
tensor.add_(mean)
|
| 40 |
+
|
| 41 |
+
# Clamp to ensure it's in the proper range
|
| 42 |
+
tensor.clamp_(min=a, max=b)
|
| 43 |
+
return tensor
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 47 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_num_patches(height=64, width=1001, patch_height=16, patch_width=16):
|
| 51 |
+
return (height // patch_height) * (width // patch_width)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
from einops.layers.torch import Rearrange
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class PatchEmbed_v2(nn.Module):
|
| 58 |
+
def __init__(self, patch_height=64, patch_width=4, embed_dim=768, input_dim=1):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.patch_height = patch_height
|
| 61 |
+
self.patch_width = patch_width
|
| 62 |
+
self.patch_maker = Rearrange('b c (h p1) (w p2) -> b (w h) (p1 p2 c)', p1=patch_height, p2=patch_width)
|
| 63 |
+
self.patch_embed = nn.Linear(patch_height * patch_width * input_dim, embed_dim)
|
| 64 |
+
|
| 65 |
+
def forward(self, melspec, length=None):
|
| 66 |
+
height = melspec.shape[2] - melspec.shape[2] % self.patch_height
|
| 67 |
+
width = melspec.shape[3] - melspec.shape[3] % self.patch_width
|
| 68 |
+
patch = self.patch_maker(melspec[:, :, :height, :width])
|
| 69 |
+
patch_embed = self.patch_embed(patch)
|
| 70 |
+
|
| 71 |
+
if length is not None:
|
| 72 |
+
patch_length = (torch.div(height, self.patch_height, rounding_mode='trunc')) * torch.div(
|
| 73 |
+
(length - length % self.patch_width), self.patch_width, rounding_mode='trunc')
|
| 74 |
+
else:
|
| 75 |
+
patch_length = None
|
| 76 |
+
|
| 77 |
+
return patch, patch_embed, patch_length
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class FrameAST(nn.Module):
|
| 81 |
+
""" Vision Transformer """
|
| 82 |
+
|
| 83 |
+
def __init__(self, nprompt=0, spec_h=64, spec_w=1001, patch_w=16, patch_h=16, pos_type="cut", in_chans=1,
|
| 84 |
+
num_classes=0, embed_dim=768, depth=12,
|
| 85 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.,
|
| 86 |
+
drop_path_rate=0.0, norm_layer=nn.LayerNorm, **kwargs):
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.num_features = self.embed_dim = embed_dim
|
| 89 |
+
self.spec_w = spec_w
|
| 90 |
+
self.spec_h = spec_h
|
| 91 |
+
self.embed_dim = embed_dim
|
| 92 |
+
self.patch_w = patch_w
|
| 93 |
+
self.patch_h = patch_h
|
| 94 |
+
|
| 95 |
+
self.pos_type = pos_type
|
| 96 |
+
|
| 97 |
+
self.patch_embed = PatchEmbed_v2(patch_h, patch_w, embed_dim)
|
| 98 |
+
self.mask_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
| 99 |
+
|
| 100 |
+
# hack
|
| 101 |
+
self.nprompt = nprompt
|
| 102 |
+
if self.nprompt > 0:
|
| 103 |
+
self.prompt_embed = nn.Parameter(torch.zeros(1, self.nprompt, self.embed_dim))
|
| 104 |
+
trunc_normal_(self.prompt_embed, std=.02)
|
| 105 |
+
|
| 106 |
+
num_patches = get_num_patches(spec_h, spec_w, patch_h, patch_w)
|
| 107 |
+
self.num_patches = num_patches
|
| 108 |
+
|
| 109 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 110 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 111 |
+
|
| 112 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 113 |
+
self.blocks = nn.ModuleList([
|
| 114 |
+
Block(
|
| 115 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 116 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
| 117 |
+
for i in range(depth)])
|
| 118 |
+
self.norm_frame = norm_layer(embed_dim)
|
| 119 |
+
|
| 120 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 121 |
+
trunc_normal_(self.mask_embed, std=.02)
|
| 122 |
+
self.apply(self._init_weights)
|
| 123 |
+
|
| 124 |
+
def _init_weights(self, m):
|
| 125 |
+
if isinstance(m, nn.Linear):
|
| 126 |
+
trunc_normal_(m.weight, std=.02)
|
| 127 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 128 |
+
nn.init.constant_(m.bias, 0)
|
| 129 |
+
elif isinstance(m, nn.LayerNorm):
|
| 130 |
+
nn.init.constant_(m.bias, 0)
|
| 131 |
+
nn.init.constant_(m.weight, 1.0)
|
| 132 |
+
|
| 133 |
+
def prepare_tokens(self, x, mask_index, length, mask=True):
|
| 134 |
+
B, nc, h, w = x.shape
|
| 135 |
+
mel_patches, x, patch_length = self.patch_embed(x, length) # patch linear embedding
|
| 136 |
+
B, T, C = x.shape
|
| 137 |
+
|
| 138 |
+
if (mask_index is not None) and mask:
|
| 139 |
+
mask_index_expand = mask_index.unsqueeze(2).expand(B, T, self.embed_dim).float()
|
| 140 |
+
x = (1 - mask_index_expand) * x + mask_index_expand * self.mask_embed.expand(B, T, C)
|
| 141 |
+
|
| 142 |
+
# add positional encoding to each token
|
| 143 |
+
if self.pos_type == "cut":
|
| 144 |
+
pos = self.pos_embed[:, 1:T + 1, :].expand(B, -1, -1)
|
| 145 |
+
x = x + pos
|
| 146 |
+
else:
|
| 147 |
+
pos = self.interpolate_pos_encoding(x, h, w)
|
| 148 |
+
x = x + pos[:, 1:]
|
| 149 |
+
|
| 150 |
+
# pos = self.pos_embed[:,1:T+1,:].expand(B,-1,-1)
|
| 151 |
+
# x = x + pos
|
| 152 |
+
|
| 153 |
+
return self.pos_drop(x), pos, mel_patches, h, w, patch_length
|
| 154 |
+
|
| 155 |
+
def forward(self, x, mask_index=None, mask_input=True, length=None):
|
| 156 |
+
x, pos, mel_patches, h, w, patch_length = self.prepare_tokens(x, mask_index, length, mask_input)
|
| 157 |
+
|
| 158 |
+
length_mask = torch.arange(mel_patches.shape[1]).to(x.device) < patch_length.unsqueeze(1)
|
| 159 |
+
length_mask = length_mask.to(x.device)
|
| 160 |
+
mask_index = mask_index & length_mask
|
| 161 |
+
|
| 162 |
+
if self.nprompt > 0:
|
| 163 |
+
x = torch.cat([self.prompt_embed.expand(x.shape[0], -1, -1), x], dim=1)
|
| 164 |
+
|
| 165 |
+
for i, blk in enumerate(self.blocks):
|
| 166 |
+
x = blk(x, patch_length + self.nprompt)
|
| 167 |
+
|
| 168 |
+
frame_repr = self.norm_frame(x)
|
| 169 |
+
|
| 170 |
+
return frame_repr[:, self.nprompt:][mask_index]
|
| 171 |
+
|
| 172 |
+
def interpolate_pos_encoding(self, x, h, w):
|
| 173 |
+
npatch = x.shape[1] - 1
|
| 174 |
+
N = self.pos_embed.shape[1] - 1
|
| 175 |
+
if npatch == N and w == self.spec_w and h == self.spec_h:
|
| 176 |
+
return self.pos_embed
|
| 177 |
+
class_pos_embed = self.pos_embed[:, 0]
|
| 178 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
| 179 |
+
dim = x.shape[-1]
|
| 180 |
+
w0 = w // self.patch_embed.patch_width
|
| 181 |
+
h0 = h // self.patch_embed.patch_height
|
| 182 |
+
# we add a small number to avoid floating point error in the interpolation
|
| 183 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
| 184 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
| 185 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 186 |
+
patch_pos_embed.reshape(1, self.spec_h // self.patch_h, self.spec_w // self.patch_w, dim).permute(0, 3, 1,
|
| 187 |
+
2),
|
| 188 |
+
scale_factor=(h0 / (self.spec_h // self.patch_h), w0 / (self.spec_w // self.patch_w)),
|
| 189 |
+
mode='bicubic',
|
| 190 |
+
)
|
| 191 |
+
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
| 192 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 193 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
| 194 |
+
|
| 195 |
+
def get_last_selfattention(self, x):
|
| 196 |
+
x, _, _, _, _, _ = self.prepare_tokens(x, mask_index=None, length=None, mask=False)
|
| 197 |
+
atts = []
|
| 198 |
+
for i, blk in enumerate(self.blocks):
|
| 199 |
+
if i < len(self.blocks) - 1:
|
| 200 |
+
x, att = blk(x, return_attention=True)
|
| 201 |
+
atts.append(att)
|
| 202 |
+
else:
|
| 203 |
+
x, att = blk(x, return_attention=True)
|
| 204 |
+
atts.append(att)
|
| 205 |
+
return atts
|
| 206 |
+
# return attention of the last block
|
| 207 |
+
|
| 208 |
+
def get_intermediate_layers(self, x, length, n=1, scene=True, other_emb=None):
|
| 209 |
+
x, _, _, _, _, patch_length = self.prepare_tokens(x, mask_index=None, length=length, mask=False)
|
| 210 |
+
# we return the output tokens from the `n` last blocks
|
| 211 |
+
if other_emb is not None:
|
| 212 |
+
x = torch.cat([other_emb, x], dim=1)
|
| 213 |
+
output = []
|
| 214 |
+
if self.nprompt > 0:
|
| 215 |
+
x = torch.cat([self.prompt_embed.expand(x.shape[0], -1, -1), x], dim=1)
|
| 216 |
+
for i, blk in enumerate(self.blocks):
|
| 217 |
+
x = blk(x, patch_length + self.nprompt)
|
| 218 |
+
if len(self.blocks) - i <= n:
|
| 219 |
+
norm_x = self.norm_frame(x)
|
| 220 |
+
if scene:
|
| 221 |
+
length_mask = torch.arange(x.shape[1] - self.nprompt).to(x.device) < patch_length.unsqueeze(1)
|
| 222 |
+
avg = torch.sum(norm_x[:, self.nprompt:] * length_mask.unsqueeze(-1), dim=1) / (
|
| 223 |
+
patch_length.unsqueeze(-1) + 1e-6)
|
| 224 |
+
negative = (~length_mask) * -1e10
|
| 225 |
+
# max = torch.max(norm_x[:,self.nprompt:]+negative.unsqueeze(-1),1).values
|
| 226 |
+
output.append(avg)
|
| 227 |
+
if self.nprompt > 0:
|
| 228 |
+
output.append(torch.mean(norm_x[:, :self.nprompt], dim=1))
|
| 229 |
+
else:
|
| 230 |
+
output.append(norm_x[:, self.nprompt:])
|
| 231 |
+
|
| 232 |
+
return torch.cat(output, dim=-1)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def get_cls_avg(output_i, cur_len, use_cls):
|
| 236 |
+
length_mask = torch.arange(output_i[0].shape[1]).to(output_i[0].device) < cur_len.unsqueeze(1)
|
| 237 |
+
cls = [torch.zeros_like(x[:, 0]) for x in output_i]
|
| 238 |
+
avg = [torch.sum(x * length_mask.unsqueeze(-1), dim=1) / (cur_len.unsqueeze(1) + 1e-6) for x in output_i]
|
| 239 |
+
return cls, avg
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def FrameASTModel(patch_h=64, patch_w=4, atst_dropout=0.1, **kwargs):
|
| 243 |
+
return FrameAST(
|
| 244 |
+
patch_h=patch_h,
|
| 245 |
+
patch_w=patch_w,
|
| 246 |
+
embed_dim=768,
|
| 247 |
+
depth=12,
|
| 248 |
+
num_heads=12,
|
| 249 |
+
qkv_bias=False,
|
| 250 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 251 |
+
drop_path_rate=atst_dropout,
|
| 252 |
+
drop_rate=atst_dropout,
|
| 253 |
+
**kwargs)
|
models/atstframe/transformer.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
| 6 |
+
if drop_prob == 0. or not training:
|
| 7 |
+
return x
|
| 8 |
+
keep_prob = 1 - drop_prob
|
| 9 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 10 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 11 |
+
random_tensor.floor_() # binarize
|
| 12 |
+
output = x.div(keep_prob) * random_tensor
|
| 13 |
+
return output
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DropPath(nn.Module):
|
| 17 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, drop_prob=None):
|
| 21 |
+
super(DropPath, self).__init__()
|
| 22 |
+
self.drop_prob = drop_prob
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Mlp(nn.Module):
|
| 29 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 30 |
+
super().__init__()
|
| 31 |
+
out_features = out_features or in_features
|
| 32 |
+
hidden_features = hidden_features or in_features
|
| 33 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 34 |
+
self.act = act_layer()
|
| 35 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 36 |
+
self.drop = nn.Dropout(drop)
|
| 37 |
+
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
x = self.fc1(x)
|
| 40 |
+
x = self.act(x)
|
| 41 |
+
x = self.drop(x)
|
| 42 |
+
x = self.fc2(x)
|
| 43 |
+
x = self.drop(x)
|
| 44 |
+
return x
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Attention(nn.Module):
|
| 48 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.num_heads = num_heads
|
| 51 |
+
head_dim = dim // num_heads
|
| 52 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 53 |
+
|
| 54 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 55 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 56 |
+
self.proj = nn.Linear(dim, dim)
|
| 57 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 58 |
+
|
| 59 |
+
def forward(self, x, mask):
|
| 60 |
+
B, N, C = x.shape
|
| 61 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 62 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 63 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 64 |
+
if mask is not None:
|
| 65 |
+
attn += mask
|
| 66 |
+
|
| 67 |
+
attn = attn.softmax(dim=-1)
|
| 68 |
+
attn = self.attn_drop(attn)
|
| 69 |
+
|
| 70 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 71 |
+
x = self.proj(x)
|
| 72 |
+
x = self.proj_drop(x)
|
| 73 |
+
return x, attn
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Block(nn.Module):
|
| 77 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 78 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.norm1 = norm_layer(dim)
|
| 81 |
+
self.attn = Attention(
|
| 82 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 83 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 84 |
+
self.norm2 = norm_layer(dim)
|
| 85 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 86 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 87 |
+
|
| 88 |
+
def forward(self, x, length=None, return_attention=False):
|
| 89 |
+
|
| 90 |
+
# if length is not None:
|
| 91 |
+
# print(length)
|
| 92 |
+
# mask_att = get_attention_mask(x,length)
|
| 93 |
+
# else:
|
| 94 |
+
mask_att = None
|
| 95 |
+
|
| 96 |
+
y, attn = self.attn(self.norm1(x), mask_att)
|
| 97 |
+
x = x + self.drop_path(y)
|
| 98 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 99 |
+
if return_attention:
|
| 100 |
+
return x, attn
|
| 101 |
+
else:
|
| 102 |
+
return x
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_attention_mask(x, length):
|
| 106 |
+
batch_size, max_len, _ = x.shape
|
| 107 |
+
# create mask for padded elements and zero-out them
|
| 108 |
+
mask = torch.arange(max_len, device=length.device).expand(batch_size, max_len) >= length[:, None]
|
| 109 |
+
# extend the mask to attention shape and set weight
|
| 110 |
+
mask = -10000.0 * mask[:, None, None, :]
|
| 111 |
+
mask = mask.expand(batch_size, 1, max_len, max_len).to(x.device)
|
| 112 |
+
return mask
|
models/beats/BEATs.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import LayerNorm
|
| 14 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
| 15 |
+
|
| 16 |
+
from models.beats.backbone import (
|
| 17 |
+
TransformerEncoder,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BEATsConfig:
|
| 27 |
+
def __init__(self, cfg=None):
|
| 28 |
+
self.input_patch_size: int = 16 # path size of patch embedding
|
| 29 |
+
self.embed_dim: int = 512 # patch embedding dimension
|
| 30 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
| 31 |
+
|
| 32 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
| 33 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
| 34 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
| 35 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
| 36 |
+
self.activation_fn: str = "gelu" # activation function to use
|
| 37 |
+
|
| 38 |
+
self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
|
| 39 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
| 40 |
+
self.deep_norm: bool = True # apply deep_norm first in the transformer
|
| 41 |
+
|
| 42 |
+
# dropouts
|
| 43 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
| 44 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
| 45 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
| 46 |
+
self.encoder_layerdrop: float = 0.05 # probability of dropping a tarnsformer layer
|
| 47 |
+
self.dropout_input: float = 0.1 # dropout to apply to the input (after feat extr)
|
| 48 |
+
|
| 49 |
+
# positional embeddings
|
| 50 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
| 51 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
| 52 |
+
|
| 53 |
+
# relative position embedding
|
| 54 |
+
self.relative_position_embedding: bool = True # apply relative position embedding
|
| 55 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
| 56 |
+
self.max_distance: int = 800 # maximum distance for relative position embedding
|
| 57 |
+
self.gru_rel_pos: bool = True # apply gated relative position embedding
|
| 58 |
+
|
| 59 |
+
# label predictor
|
| 60 |
+
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
|
| 61 |
+
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
|
| 62 |
+
self.predictor_class: int = 527 # target class number for the predictor
|
| 63 |
+
|
| 64 |
+
if cfg is not None:
|
| 65 |
+
self.update(cfg)
|
| 66 |
+
|
| 67 |
+
def update(self, cfg: dict):
|
| 68 |
+
self.__dict__.update(cfg)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class BEATs(nn.Module):
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
cfg: BEATsConfig,
|
| 75 |
+
) -> None:
|
| 76 |
+
super().__init__()
|
| 77 |
+
logger.info(f"BEATs Config: {cfg.__dict__}")
|
| 78 |
+
|
| 79 |
+
self.cfg = cfg
|
| 80 |
+
|
| 81 |
+
self.embed = cfg.embed_dim
|
| 82 |
+
self.post_extract_proj = (
|
| 83 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
| 84 |
+
if self.embed != cfg.encoder_embed_dim
|
| 85 |
+
else None
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.input_patch_size = cfg.input_patch_size
|
| 89 |
+
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
| 90 |
+
bias=cfg.conv_bias)
|
| 91 |
+
|
| 92 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
| 93 |
+
|
| 94 |
+
assert not cfg.deep_norm or not cfg.layer_norm_first
|
| 95 |
+
self.encoder = TransformerEncoder(cfg)
|
| 96 |
+
self.layer_norm = LayerNorm(self.embed)
|
| 97 |
+
|
| 98 |
+
if cfg.finetuned_model:
|
| 99 |
+
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
|
| 100 |
+
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
|
| 101 |
+
else:
|
| 102 |
+
self.predictor = None
|
| 103 |
+
|
| 104 |
+
def forward_padding_mask(
|
| 105 |
+
self,
|
| 106 |
+
features: torch.Tensor,
|
| 107 |
+
padding_mask: torch.Tensor,
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
+
extra = padding_mask.size(1) % features.size(1)
|
| 110 |
+
if extra > 0:
|
| 111 |
+
padding_mask = padding_mask[:, :-extra]
|
| 112 |
+
padding_mask = padding_mask.view(
|
| 113 |
+
padding_mask.size(0), features.size(1), -1
|
| 114 |
+
)
|
| 115 |
+
padding_mask = padding_mask.all(-1)
|
| 116 |
+
return padding_mask
|
| 117 |
+
|
| 118 |
+
def preprocess(
|
| 119 |
+
self,
|
| 120 |
+
source: torch.Tensor,
|
| 121 |
+
fbank_mean: float = 15.41663,
|
| 122 |
+
fbank_std: float = 6.55582,
|
| 123 |
+
) -> torch.Tensor:
|
| 124 |
+
fbanks = []
|
| 125 |
+
for waveform in source:
|
| 126 |
+
waveform = waveform.unsqueeze(0) * 2 ** 15
|
| 127 |
+
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
| 128 |
+
fbanks.append(fbank)
|
| 129 |
+
fbank = torch.stack(fbanks, dim=0)
|
| 130 |
+
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
| 131 |
+
return fbank
|
| 132 |
+
|
| 133 |
+
def extract_features(
|
| 134 |
+
self,
|
| 135 |
+
source: torch.Tensor,
|
| 136 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 137 |
+
fbank_mean: float = 15.41663,
|
| 138 |
+
fbank_std: float = 6.55582,
|
| 139 |
+
do_preprocess: bool = True,
|
| 140 |
+
):
|
| 141 |
+
if do_preprocess:
|
| 142 |
+
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
| 143 |
+
|
| 144 |
+
if padding_mask is not None:
|
| 145 |
+
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
| 146 |
+
|
| 147 |
+
fbank = fbank.unsqueeze(1)
|
| 148 |
+
else:
|
| 149 |
+
fbank = source
|
| 150 |
+
features = self.patch_embedding(fbank)
|
| 151 |
+
features = features.reshape(features.shape[0], features.shape[1], -1)
|
| 152 |
+
features = features.transpose(1, 2)
|
| 153 |
+
features = self.layer_norm(features)
|
| 154 |
+
|
| 155 |
+
if padding_mask is not None:
|
| 156 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
| 157 |
+
|
| 158 |
+
if self.post_extract_proj is not None:
|
| 159 |
+
features = self.post_extract_proj(features)
|
| 160 |
+
|
| 161 |
+
x = self.dropout_input(features)
|
| 162 |
+
|
| 163 |
+
x, layer_results = self.encoder(
|
| 164 |
+
x,
|
| 165 |
+
padding_mask=padding_mask,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if self.predictor is not None:
|
| 169 |
+
x = self.predictor_dropout(x)
|
| 170 |
+
logits = self.predictor(x)
|
| 171 |
+
|
| 172 |
+
if padding_mask is not None and padding_mask.any():
|
| 173 |
+
logits[padding_mask] = 0
|
| 174 |
+
logits = logits.sum(dim=1)
|
| 175 |
+
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
|
| 176 |
+
else:
|
| 177 |
+
logits = logits.mean(dim=1)
|
| 178 |
+
|
| 179 |
+
lprobs = torch.sigmoid(logits)
|
| 180 |
+
|
| 181 |
+
return lprobs, padding_mask
|
| 182 |
+
else:
|
| 183 |
+
return x, padding_mask
|
models/beats/BEATs_wrapper.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from models.beats.BEATs import BEATsConfig, BEATs
|
| 4 |
+
from models.transformer_wrapper import BaseModelWrapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BEATsWrapper(BaseModelWrapper):
|
| 8 |
+
def __init__(self):
|
| 9 |
+
super().__init__()
|
| 10 |
+
cfg = BEATsConfig()
|
| 11 |
+
self.beats = BEATs(cfg)
|
| 12 |
+
|
| 13 |
+
def mel_forward(self, x):
|
| 14 |
+
with torch.autocast(device_type="cuda", enabled=False):
|
| 15 |
+
mel = self.beats.preprocess(x)
|
| 16 |
+
mel = mel.unsqueeze(1).transpose(2, 3)
|
| 17 |
+
return mel
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
x = x.transpose(2, 3)
|
| 21 |
+
features = self.beats.extract_features(x, do_preprocess=False)[0]
|
| 22 |
+
return features
|
| 23 |
+
|
| 24 |
+
def separate_params(self):
|
| 25 |
+
pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
|
| 26 |
+
for k, p in self.named_parameters():
|
| 27 |
+
if ".layers.0." in k:
|
| 28 |
+
pt_params[0].append(p)
|
| 29 |
+
elif ".layers.1." in k:
|
| 30 |
+
pt_params[1].append(p)
|
| 31 |
+
elif ".layers.2." in k:
|
| 32 |
+
pt_params[2].append(p)
|
| 33 |
+
elif ".layers.3." in k:
|
| 34 |
+
pt_params[3].append(p)
|
| 35 |
+
elif ".layers.4." in k:
|
| 36 |
+
pt_params[4].append(p)
|
| 37 |
+
elif ".layers.5." in k:
|
| 38 |
+
pt_params[5].append(p)
|
| 39 |
+
elif ".layers.6." in k:
|
| 40 |
+
pt_params[6].append(p)
|
| 41 |
+
elif ".layers.7." in k:
|
| 42 |
+
pt_params[7].append(p)
|
| 43 |
+
elif ".layers.8." in k:
|
| 44 |
+
pt_params[8].append(p)
|
| 45 |
+
elif ".layers.9." in k:
|
| 46 |
+
pt_params[9].append(p)
|
| 47 |
+
elif ".layers.10." in k:
|
| 48 |
+
pt_params[10].append(p)
|
| 49 |
+
elif ".layers.11." in k:
|
| 50 |
+
pt_params[11].append(p)
|
| 51 |
+
elif (".post_extract_proj." in k or ".patch_embedding." in k or '.pos_conv.' in k
|
| 52 |
+
or 'beats.layer_norm.' in k or "beats.encoder.layer_norm." in k):
|
| 53 |
+
pt_params[0].append(p)
|
| 54 |
+
else:
|
| 55 |
+
raise ValueError(f"Check separate params for BEATs! Unknown key: {k}")
|
| 56 |
+
return list(reversed(pt_params))
|
models/beats/Tokenizers.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import LayerNorm
|
| 14 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
| 15 |
+
|
| 16 |
+
from backbone import (
|
| 17 |
+
TransformerEncoder,
|
| 18 |
+
)
|
| 19 |
+
from quantizer import (
|
| 20 |
+
NormEMAVectorQuantizer,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import logging
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TokenizersConfig:
|
| 30 |
+
def __init__(self, cfg=None):
|
| 31 |
+
self.input_patch_size: int = -1 # path size of patch embedding
|
| 32 |
+
self.embed_dim: int = 512 # patch embedding dimension
|
| 33 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
| 34 |
+
|
| 35 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
| 36 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
| 37 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
| 38 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
| 39 |
+
self.activation_fn: str = "gelu" # activation function to use
|
| 40 |
+
|
| 41 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
| 42 |
+
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
| 43 |
+
|
| 44 |
+
# dropouts
|
| 45 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
| 46 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
| 47 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
| 48 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
| 49 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
| 50 |
+
|
| 51 |
+
# positional embeddings
|
| 52 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
| 53 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
| 54 |
+
|
| 55 |
+
# relative position embedding
|
| 56 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
| 57 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
| 58 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
| 59 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
| 60 |
+
|
| 61 |
+
# quantizer
|
| 62 |
+
self.quant_n: int = 1024 # codebook number in quantizer
|
| 63 |
+
self.quant_dim: int = 256 # codebook dimension in quantizer
|
| 64 |
+
|
| 65 |
+
if cfg is not None:
|
| 66 |
+
self.update(cfg)
|
| 67 |
+
|
| 68 |
+
def update(self, cfg: dict):
|
| 69 |
+
self.__dict__.update(cfg)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Tokenizers(nn.Module):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
cfg: TokenizersConfig,
|
| 76 |
+
) -> None:
|
| 77 |
+
super().__init__()
|
| 78 |
+
logger.info(f"Tokenizers Config: {cfg.__dict__}")
|
| 79 |
+
|
| 80 |
+
self.cfg = cfg
|
| 81 |
+
|
| 82 |
+
self.embed = cfg.embed_dim
|
| 83 |
+
self.post_extract_proj = (
|
| 84 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
| 85 |
+
if self.embed != cfg.encoder_embed_dim
|
| 86 |
+
else None
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.input_patch_size = cfg.input_patch_size
|
| 90 |
+
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
| 91 |
+
bias=cfg.conv_bias)
|
| 92 |
+
|
| 93 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
| 94 |
+
|
| 95 |
+
assert not cfg.deep_norm or not cfg.layer_norm_first
|
| 96 |
+
self.encoder = TransformerEncoder(cfg)
|
| 97 |
+
self.layer_norm = LayerNorm(self.embed)
|
| 98 |
+
|
| 99 |
+
self.quantize = NormEMAVectorQuantizer(
|
| 100 |
+
n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
|
| 101 |
+
)
|
| 102 |
+
self.quant_n = cfg.quant_n
|
| 103 |
+
self.quantize_layer = nn.Sequential(
|
| 104 |
+
nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
|
| 105 |
+
nn.Tanh(),
|
| 106 |
+
nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
def forward_padding_mask(
|
| 110 |
+
self,
|
| 111 |
+
features: torch.Tensor,
|
| 112 |
+
padding_mask: torch.Tensor,
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
extra = padding_mask.size(1) % features.size(1)
|
| 115 |
+
if extra > 0:
|
| 116 |
+
padding_mask = padding_mask[:, :-extra]
|
| 117 |
+
padding_mask = padding_mask.view(
|
| 118 |
+
padding_mask.size(0), features.size(1), -1
|
| 119 |
+
)
|
| 120 |
+
padding_mask = padding_mask.all(-1)
|
| 121 |
+
return padding_mask
|
| 122 |
+
|
| 123 |
+
def preprocess(
|
| 124 |
+
self,
|
| 125 |
+
source: torch.Tensor,
|
| 126 |
+
fbank_mean: float = 15.41663,
|
| 127 |
+
fbank_std: float = 6.55582,
|
| 128 |
+
) -> torch.Tensor:
|
| 129 |
+
fbanks = []
|
| 130 |
+
for waveform in source:
|
| 131 |
+
waveform = waveform.unsqueeze(0) * 2 ** 15
|
| 132 |
+
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
| 133 |
+
fbanks.append(fbank)
|
| 134 |
+
fbank = torch.stack(fbanks, dim=0)
|
| 135 |
+
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
| 136 |
+
return fbank
|
| 137 |
+
|
| 138 |
+
def extract_labels(
|
| 139 |
+
self,
|
| 140 |
+
source: torch.Tensor,
|
| 141 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 142 |
+
fbank_mean: float = 15.41663,
|
| 143 |
+
fbank_std: float = 6.55582,
|
| 144 |
+
):
|
| 145 |
+
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
| 146 |
+
|
| 147 |
+
if padding_mask is not None:
|
| 148 |
+
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
| 149 |
+
|
| 150 |
+
fbank = fbank.unsqueeze(1)
|
| 151 |
+
features = self.patch_embedding(fbank)
|
| 152 |
+
features = features.reshape(features.shape[0], features.shape[1], -1)
|
| 153 |
+
features = features.transpose(1, 2)
|
| 154 |
+
features = self.layer_norm(features)
|
| 155 |
+
|
| 156 |
+
if padding_mask is not None:
|
| 157 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
| 158 |
+
|
| 159 |
+
if self.post_extract_proj is not None:
|
| 160 |
+
features = self.post_extract_proj(features)
|
| 161 |
+
|
| 162 |
+
x = self.dropout_input(features)
|
| 163 |
+
|
| 164 |
+
x, layer_results = self.encoder(
|
| 165 |
+
x,
|
| 166 |
+
padding_mask=padding_mask,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
quantize_input = self.quantize_layer(x)
|
| 170 |
+
quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
|
| 171 |
+
|
| 172 |
+
return embed_ind
|
models/beats/backbone.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import numpy as np
|
| 12 |
+
from typing import Dict, Optional, Tuple
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch.nn import LayerNorm, Parameter
|
| 17 |
+
from models.beats.modules import (
|
| 18 |
+
GradMultiply,
|
| 19 |
+
SamePad,
|
| 20 |
+
get_activation_fn,
|
| 21 |
+
GLU_Linear,
|
| 22 |
+
quant_noise,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TransformerEncoder(nn.Module):
|
| 27 |
+
def __init__(self, args):
|
| 28 |
+
super().__init__()
|
| 29 |
+
|
| 30 |
+
self.dropout = args.dropout
|
| 31 |
+
self.embedding_dim = args.encoder_embed_dim
|
| 32 |
+
|
| 33 |
+
self.pos_conv = nn.Conv1d(
|
| 34 |
+
self.embedding_dim,
|
| 35 |
+
self.embedding_dim,
|
| 36 |
+
kernel_size=args.conv_pos,
|
| 37 |
+
padding=args.conv_pos // 2,
|
| 38 |
+
groups=args.conv_pos_groups,
|
| 39 |
+
)
|
| 40 |
+
dropout = 0
|
| 41 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
| 42 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
| 43 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
| 44 |
+
|
| 45 |
+
self.pos_conv = torch.nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
|
| 46 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
| 47 |
+
|
| 48 |
+
if hasattr(args, "relative_position_embedding"):
|
| 49 |
+
self.relative_position_embedding = args.relative_position_embedding
|
| 50 |
+
self.num_buckets = args.num_buckets
|
| 51 |
+
self.max_distance = args.max_distance
|
| 52 |
+
else:
|
| 53 |
+
self.relative_position_embedding = False
|
| 54 |
+
self.num_buckets = 0
|
| 55 |
+
self.max_distance = 0
|
| 56 |
+
|
| 57 |
+
self.layers = nn.ModuleList(
|
| 58 |
+
[
|
| 59 |
+
TransformerSentenceEncoderLayer(
|
| 60 |
+
embedding_dim=self.embedding_dim,
|
| 61 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
| 62 |
+
num_attention_heads=args.encoder_attention_heads,
|
| 63 |
+
dropout=self.dropout,
|
| 64 |
+
attention_dropout=args.attention_dropout,
|
| 65 |
+
activation_dropout=args.activation_dropout,
|
| 66 |
+
activation_fn=args.activation_fn,
|
| 67 |
+
layer_norm_first=args.layer_norm_first,
|
| 68 |
+
deep_norm=args.deep_norm,
|
| 69 |
+
has_relative_attention_bias=self.relative_position_embedding,
|
| 70 |
+
num_buckets=self.num_buckets,
|
| 71 |
+
max_distance=self.max_distance,
|
| 72 |
+
gru_rel_pos=args.gru_rel_pos,
|
| 73 |
+
encoder_layers=args.encoder_layers,
|
| 74 |
+
)
|
| 75 |
+
for i in range(args.encoder_layers)
|
| 76 |
+
]
|
| 77 |
+
)
|
| 78 |
+
if self.relative_position_embedding:
|
| 79 |
+
for i in range(1, args.encoder_layers):
|
| 80 |
+
del self.layers[i].self_attn.relative_attention_bias
|
| 81 |
+
self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
|
| 82 |
+
|
| 83 |
+
self.layer_norm_first = args.layer_norm_first
|
| 84 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
| 85 |
+
self.layerdrop = args.encoder_layerdrop
|
| 86 |
+
|
| 87 |
+
self.apply(init_bert_params)
|
| 88 |
+
|
| 89 |
+
if args.deep_norm:
|
| 90 |
+
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
|
| 91 |
+
for i in range(args.encoder_layers):
|
| 92 |
+
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
|
| 93 |
+
nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
|
| 94 |
+
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
|
| 95 |
+
nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
|
| 96 |
+
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
|
| 97 |
+
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
|
| 98 |
+
|
| 99 |
+
self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
|
| 100 |
+
|
| 101 |
+
def forward(self, x, padding_mask=None, layer=None):
|
| 102 |
+
x, layer_results = self.extract_features(x, padding_mask, layer)
|
| 103 |
+
|
| 104 |
+
if self.layer_norm_first and layer is None:
|
| 105 |
+
x = self.layer_norm(x)
|
| 106 |
+
|
| 107 |
+
return x, layer_results
|
| 108 |
+
|
| 109 |
+
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
| 110 |
+
|
| 111 |
+
if padding_mask is not None:
|
| 112 |
+
x[padding_mask] = 0
|
| 113 |
+
|
| 114 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
| 115 |
+
x_conv = x_conv.transpose(1, 2)
|
| 116 |
+
x = x + x_conv
|
| 117 |
+
|
| 118 |
+
if not self.layer_norm_first:
|
| 119 |
+
x = self.layer_norm(x)
|
| 120 |
+
|
| 121 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 122 |
+
|
| 123 |
+
# B x T x C -> T x B x C
|
| 124 |
+
x = x.transpose(0, 1)
|
| 125 |
+
|
| 126 |
+
layer_results = []
|
| 127 |
+
z = None
|
| 128 |
+
if tgt_layer is not None:
|
| 129 |
+
layer_results.append((x, z))
|
| 130 |
+
r = None
|
| 131 |
+
pos_bias = None
|
| 132 |
+
for i, layer in enumerate(self.layers):
|
| 133 |
+
if self.layer_wise_gradient_decay_ratio != 1.0:
|
| 134 |
+
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
|
| 135 |
+
dropout_probability = np.random.random()
|
| 136 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
| 137 |
+
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
|
| 138 |
+
if tgt_layer is not None:
|
| 139 |
+
layer_results.append((x, z))
|
| 140 |
+
if i == tgt_layer:
|
| 141 |
+
r = x
|
| 142 |
+
break
|
| 143 |
+
|
| 144 |
+
if r is not None:
|
| 145 |
+
x = r
|
| 146 |
+
|
| 147 |
+
# T x B x C -> B x T x C
|
| 148 |
+
x = x.transpose(0, 1)
|
| 149 |
+
|
| 150 |
+
return x, layer_results
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
embedding_dim: float = 768,
|
| 157 |
+
ffn_embedding_dim: float = 3072,
|
| 158 |
+
num_attention_heads: float = 8,
|
| 159 |
+
dropout: float = 0.1,
|
| 160 |
+
attention_dropout: float = 0.1,
|
| 161 |
+
activation_dropout: float = 0.1,
|
| 162 |
+
activation_fn: str = "relu",
|
| 163 |
+
layer_norm_first: bool = False,
|
| 164 |
+
deep_norm: bool = False,
|
| 165 |
+
has_relative_attention_bias: bool = False,
|
| 166 |
+
num_buckets: int = 0,
|
| 167 |
+
max_distance: int = 0,
|
| 168 |
+
rescale_init: bool = False,
|
| 169 |
+
gru_rel_pos: bool = False,
|
| 170 |
+
encoder_layers: int = 0,
|
| 171 |
+
) -> None:
|
| 172 |
+
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.embedding_dim = embedding_dim
|
| 175 |
+
self.dropout = dropout
|
| 176 |
+
self.activation_dropout = activation_dropout
|
| 177 |
+
|
| 178 |
+
self.activation_name = activation_fn
|
| 179 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
| 180 |
+
self.self_attn = MultiheadAttention(
|
| 181 |
+
self.embedding_dim,
|
| 182 |
+
num_attention_heads,
|
| 183 |
+
dropout=attention_dropout,
|
| 184 |
+
self_attention=True,
|
| 185 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
| 186 |
+
num_buckets=num_buckets,
|
| 187 |
+
max_distance=max_distance,
|
| 188 |
+
rescale_init=rescale_init,
|
| 189 |
+
gru_rel_pos=gru_rel_pos,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 193 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
| 194 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 195 |
+
|
| 196 |
+
self.layer_norm_first = layer_norm_first
|
| 197 |
+
|
| 198 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
| 199 |
+
|
| 200 |
+
if self.activation_name == "glu":
|
| 201 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
| 202 |
+
else:
|
| 203 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
| 204 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
| 205 |
+
|
| 206 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
| 207 |
+
|
| 208 |
+
self.deep_norm = deep_norm
|
| 209 |
+
if self.deep_norm:
|
| 210 |
+
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
|
| 211 |
+
else:
|
| 212 |
+
self.deep_norm_alpha = 1
|
| 213 |
+
|
| 214 |
+
def forward(
|
| 215 |
+
self,
|
| 216 |
+
x: torch.Tensor,
|
| 217 |
+
self_attn_mask: torch.Tensor = None,
|
| 218 |
+
self_attn_padding_mask: torch.Tensor = None,
|
| 219 |
+
need_weights: bool = False,
|
| 220 |
+
pos_bias=None
|
| 221 |
+
):
|
| 222 |
+
residual = x
|
| 223 |
+
|
| 224 |
+
if self.layer_norm_first:
|
| 225 |
+
x = self.self_attn_layer_norm(x)
|
| 226 |
+
x, attn, pos_bias = self.self_attn(
|
| 227 |
+
query=x,
|
| 228 |
+
key=x,
|
| 229 |
+
value=x,
|
| 230 |
+
key_padding_mask=self_attn_padding_mask,
|
| 231 |
+
need_weights=False,
|
| 232 |
+
attn_mask=self_attn_mask,
|
| 233 |
+
position_bias=pos_bias
|
| 234 |
+
)
|
| 235 |
+
x = self.dropout1(x)
|
| 236 |
+
x = residual + x
|
| 237 |
+
|
| 238 |
+
residual = x
|
| 239 |
+
x = self.final_layer_norm(x)
|
| 240 |
+
if self.activation_name == "glu":
|
| 241 |
+
x = self.fc1(x)
|
| 242 |
+
else:
|
| 243 |
+
x = self.activation_fn(self.fc1(x))
|
| 244 |
+
x = self.dropout2(x)
|
| 245 |
+
x = self.fc2(x)
|
| 246 |
+
x = self.dropout3(x)
|
| 247 |
+
x = residual + x
|
| 248 |
+
else:
|
| 249 |
+
x, attn, pos_bias = self.self_attn(
|
| 250 |
+
query=x,
|
| 251 |
+
key=x,
|
| 252 |
+
value=x,
|
| 253 |
+
key_padding_mask=self_attn_padding_mask,
|
| 254 |
+
need_weights=need_weights,
|
| 255 |
+
attn_mask=self_attn_mask,
|
| 256 |
+
position_bias=pos_bias
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
x = self.dropout1(x)
|
| 260 |
+
x = residual * self.deep_norm_alpha + x
|
| 261 |
+
|
| 262 |
+
x = self.self_attn_layer_norm(x)
|
| 263 |
+
|
| 264 |
+
residual = x
|
| 265 |
+
if self.activation_name == "glu":
|
| 266 |
+
x = self.fc1(x)
|
| 267 |
+
else:
|
| 268 |
+
x = self.activation_fn(self.fc1(x))
|
| 269 |
+
x = self.dropout2(x)
|
| 270 |
+
x = self.fc2(x)
|
| 271 |
+
x = self.dropout3(x)
|
| 272 |
+
x = residual * self.deep_norm_alpha + x
|
| 273 |
+
x = self.final_layer_norm(x)
|
| 274 |
+
|
| 275 |
+
return x, attn, pos_bias
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class MultiheadAttention(nn.Module):
|
| 279 |
+
"""Multi-headed attention.
|
| 280 |
+
|
| 281 |
+
See "Attention Is All You Need" for more details.
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
def __init__(
|
| 285 |
+
self,
|
| 286 |
+
embed_dim,
|
| 287 |
+
num_heads,
|
| 288 |
+
kdim=None,
|
| 289 |
+
vdim=None,
|
| 290 |
+
dropout=0.0,
|
| 291 |
+
bias=True,
|
| 292 |
+
add_bias_kv=False,
|
| 293 |
+
add_zero_attn=False,
|
| 294 |
+
self_attention=False,
|
| 295 |
+
encoder_decoder_attention=False,
|
| 296 |
+
q_noise=0.0,
|
| 297 |
+
qn_block_size=8,
|
| 298 |
+
has_relative_attention_bias=False,
|
| 299 |
+
num_buckets=32,
|
| 300 |
+
max_distance=128,
|
| 301 |
+
gru_rel_pos=False,
|
| 302 |
+
rescale_init=False,
|
| 303 |
+
):
|
| 304 |
+
super().__init__()
|
| 305 |
+
self.embed_dim = embed_dim
|
| 306 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 307 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 308 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 309 |
+
|
| 310 |
+
self.num_heads = num_heads
|
| 311 |
+
self.dropout_module = nn.Dropout(dropout)
|
| 312 |
+
|
| 313 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
| 314 |
+
self.num_buckets = num_buckets
|
| 315 |
+
self.max_distance = max_distance
|
| 316 |
+
if self.has_relative_attention_bias:
|
| 317 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
| 318 |
+
|
| 319 |
+
self.head_dim = embed_dim // num_heads
|
| 320 |
+
self.q_head_dim = self.head_dim
|
| 321 |
+
self.k_head_dim = self.head_dim
|
| 322 |
+
assert (
|
| 323 |
+
self.head_dim * num_heads == self.embed_dim
|
| 324 |
+
), "embed_dim must be divisible by num_heads"
|
| 325 |
+
self.scaling = self.head_dim ** -0.5
|
| 326 |
+
|
| 327 |
+
self.self_attention = self_attention
|
| 328 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 329 |
+
|
| 330 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
| 331 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
k_bias = True
|
| 335 |
+
if rescale_init:
|
| 336 |
+
k_bias = False
|
| 337 |
+
|
| 338 |
+
k_embed_dim = embed_dim
|
| 339 |
+
q_embed_dim = embed_dim
|
| 340 |
+
|
| 341 |
+
self.k_proj = quant_noise(
|
| 342 |
+
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
| 343 |
+
)
|
| 344 |
+
self.v_proj = quant_noise(
|
| 345 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 346 |
+
)
|
| 347 |
+
self.q_proj = quant_noise(
|
| 348 |
+
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
self.out_proj = quant_noise(
|
| 352 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
if add_bias_kv:
|
| 356 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 357 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 358 |
+
else:
|
| 359 |
+
self.bias_k = self.bias_v = None
|
| 360 |
+
|
| 361 |
+
self.add_zero_attn = add_zero_attn
|
| 362 |
+
|
| 363 |
+
self.gru_rel_pos = gru_rel_pos
|
| 364 |
+
if self.gru_rel_pos:
|
| 365 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
| 366 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
| 367 |
+
|
| 368 |
+
self.reset_parameters()
|
| 369 |
+
|
| 370 |
+
def reset_parameters(self):
|
| 371 |
+
if self.qkv_same_dim:
|
| 372 |
+
# Empirically observed the convergence to be much better with
|
| 373 |
+
# the scaled initialization
|
| 374 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
| 375 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
| 376 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
| 377 |
+
else:
|
| 378 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
| 379 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
| 380 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 381 |
+
|
| 382 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 383 |
+
if self.out_proj.bias is not None:
|
| 384 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
| 385 |
+
if self.bias_k is not None:
|
| 386 |
+
nn.init.xavier_normal_(self.bias_k)
|
| 387 |
+
if self.bias_v is not None:
|
| 388 |
+
nn.init.xavier_normal_(self.bias_v)
|
| 389 |
+
if self.has_relative_attention_bias:
|
| 390 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
| 391 |
+
|
| 392 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
| 393 |
+
num_buckets = self.num_buckets
|
| 394 |
+
max_distance = self.max_distance
|
| 395 |
+
relative_buckets = 0
|
| 396 |
+
|
| 397 |
+
if bidirectional:
|
| 398 |
+
num_buckets = num_buckets // 2
|
| 399 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
| 400 |
+
relative_positions = torch.abs(relative_positions)
|
| 401 |
+
else:
|
| 402 |
+
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
| 403 |
+
|
| 404 |
+
max_exact = num_buckets // 2
|
| 405 |
+
is_small = relative_positions < max_exact
|
| 406 |
+
|
| 407 |
+
relative_postion_if_large = max_exact + (
|
| 408 |
+
torch.log(relative_positions.float() / max_exact)
|
| 409 |
+
/ math.log(max_distance / max_exact)
|
| 410 |
+
* (num_buckets - max_exact)
|
| 411 |
+
).to(torch.long)
|
| 412 |
+
relative_postion_if_large = torch.min(
|
| 413 |
+
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
| 417 |
+
return relative_buckets
|
| 418 |
+
|
| 419 |
+
def compute_bias(self, query_length, key_length):
|
| 420 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
| 421 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
| 422 |
+
relative_position = memory_position - context_position
|
| 423 |
+
relative_position_bucket = self._relative_positions_bucket(
|
| 424 |
+
relative_position,
|
| 425 |
+
bidirectional=True
|
| 426 |
+
)
|
| 427 |
+
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
| 428 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
| 429 |
+
values = values.permute([2, 0, 1])
|
| 430 |
+
return values
|
| 431 |
+
|
| 432 |
+
def forward(
|
| 433 |
+
self,
|
| 434 |
+
query,
|
| 435 |
+
key: Optional[Tensor],
|
| 436 |
+
value: Optional[Tensor],
|
| 437 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 438 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 439 |
+
need_weights: bool = True,
|
| 440 |
+
static_kv: bool = False,
|
| 441 |
+
attn_mask: Optional[Tensor] = None,
|
| 442 |
+
before_softmax: bool = False,
|
| 443 |
+
need_head_weights: bool = False,
|
| 444 |
+
position_bias: Optional[Tensor] = None
|
| 445 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
| 446 |
+
"""Input shape: Time x Batch x Channel
|
| 447 |
+
|
| 448 |
+
Args:
|
| 449 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
| 450 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
| 451 |
+
padding elements are indicated by 1s.
|
| 452 |
+
need_weights (bool, optional): return the attention weights,
|
| 453 |
+
averaged over heads (default: False).
|
| 454 |
+
attn_mask (ByteTensor, optional): typically used to
|
| 455 |
+
implement causal attention, where the mask prevents the
|
| 456 |
+
attention from looking forward in time (default: None).
|
| 457 |
+
before_softmax (bool, optional): return the raw attention
|
| 458 |
+
weights and values before the attention softmax.
|
| 459 |
+
need_head_weights (bool, optional): return the attention
|
| 460 |
+
weights for each head. Implies *need_weights*. Default:
|
| 461 |
+
return the average attention weights over all heads.
|
| 462 |
+
"""
|
| 463 |
+
if need_head_weights:
|
| 464 |
+
need_weights = True
|
| 465 |
+
|
| 466 |
+
is_tpu = query.device.type == "xla"
|
| 467 |
+
|
| 468 |
+
tgt_len, bsz, embed_dim = query.size()
|
| 469 |
+
src_len = tgt_len
|
| 470 |
+
assert embed_dim == self.embed_dim
|
| 471 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 472 |
+
if key is not None:
|
| 473 |
+
src_len, key_bsz, _ = key.size()
|
| 474 |
+
if not torch.jit.is_scripting():
|
| 475 |
+
assert key_bsz == bsz
|
| 476 |
+
assert value is not None
|
| 477 |
+
assert src_len, bsz == value.shape[:2]
|
| 478 |
+
|
| 479 |
+
if self.has_relative_attention_bias and position_bias is None:
|
| 480 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
| 481 |
+
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
| 482 |
+
|
| 483 |
+
if incremental_state is not None:
|
| 484 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 485 |
+
if saved_state is not None and "prev_key" in saved_state:
|
| 486 |
+
# previous time steps are cached - no need to recompute
|
| 487 |
+
# key and value if they are static
|
| 488 |
+
if static_kv:
|
| 489 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
| 490 |
+
key = value = None
|
| 491 |
+
else:
|
| 492 |
+
saved_state = None
|
| 493 |
+
|
| 494 |
+
if self.self_attention:
|
| 495 |
+
q = self.q_proj(query)
|
| 496 |
+
k = self.k_proj(query)
|
| 497 |
+
v = self.v_proj(query)
|
| 498 |
+
elif self.encoder_decoder_attention:
|
| 499 |
+
# encoder-decoder attention
|
| 500 |
+
q = self.q_proj(query)
|
| 501 |
+
if key is None:
|
| 502 |
+
assert value is None
|
| 503 |
+
k = v = None
|
| 504 |
+
else:
|
| 505 |
+
k = self.k_proj(key)
|
| 506 |
+
v = self.v_proj(key)
|
| 507 |
+
|
| 508 |
+
else:
|
| 509 |
+
assert key is not None and value is not None
|
| 510 |
+
q = self.q_proj(query)
|
| 511 |
+
k = self.k_proj(key)
|
| 512 |
+
v = self.v_proj(value)
|
| 513 |
+
q *= self.scaling
|
| 514 |
+
alpha = 32
|
| 515 |
+
q *= 1 / alpha
|
| 516 |
+
|
| 517 |
+
if self.bias_k is not None:
|
| 518 |
+
assert self.bias_v is not None
|
| 519 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
| 520 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
| 521 |
+
if attn_mask is not None:
|
| 522 |
+
attn_mask = torch.cat(
|
| 523 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 524 |
+
)
|
| 525 |
+
if key_padding_mask is not None:
|
| 526 |
+
key_padding_mask = torch.cat(
|
| 527 |
+
[
|
| 528 |
+
key_padding_mask,
|
| 529 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
| 530 |
+
],
|
| 531 |
+
dim=1,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
q = (
|
| 535 |
+
q.contiguous()
|
| 536 |
+
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
| 537 |
+
.transpose(0, 1)
|
| 538 |
+
)
|
| 539 |
+
if k is not None:
|
| 540 |
+
k = (
|
| 541 |
+
k.contiguous()
|
| 542 |
+
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
| 543 |
+
.transpose(0, 1)
|
| 544 |
+
)
|
| 545 |
+
if v is not None:
|
| 546 |
+
v = (
|
| 547 |
+
v.contiguous()
|
| 548 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
| 549 |
+
.transpose(0, 1)
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
if saved_state is not None:
|
| 553 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
| 554 |
+
if "prev_key" in saved_state:
|
| 555 |
+
_prev_key = saved_state["prev_key"]
|
| 556 |
+
assert _prev_key is not None
|
| 557 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
| 558 |
+
if static_kv:
|
| 559 |
+
k = prev_key
|
| 560 |
+
else:
|
| 561 |
+
assert k is not None
|
| 562 |
+
k = torch.cat([prev_key, k], dim=1)
|
| 563 |
+
src_len = k.size(1)
|
| 564 |
+
if "prev_value" in saved_state:
|
| 565 |
+
_prev_value = saved_state["prev_value"]
|
| 566 |
+
assert _prev_value is not None
|
| 567 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
| 568 |
+
if static_kv:
|
| 569 |
+
v = prev_value
|
| 570 |
+
else:
|
| 571 |
+
assert v is not None
|
| 572 |
+
v = torch.cat([prev_value, v], dim=1)
|
| 573 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
| 574 |
+
if "prev_key_padding_mask" in saved_state:
|
| 575 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
| 576 |
+
assert k is not None and v is not None
|
| 577 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
| 578 |
+
key_padding_mask=key_padding_mask,
|
| 579 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
| 580 |
+
batch_size=bsz,
|
| 581 |
+
src_len=k.size(1),
|
| 582 |
+
static_kv=static_kv,
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
| 586 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
| 587 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
| 588 |
+
# In this branch incremental_state is never None
|
| 589 |
+
assert incremental_state is not None
|
| 590 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
| 591 |
+
assert k is not None
|
| 592 |
+
assert k.size(1) == src_len
|
| 593 |
+
|
| 594 |
+
# This is part of a workaround to get around fork/join parallelism
|
| 595 |
+
# not supporting Optional types.
|
| 596 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
| 597 |
+
key_padding_mask = None
|
| 598 |
+
|
| 599 |
+
if key_padding_mask is not None:
|
| 600 |
+
assert key_padding_mask.size(0) == bsz
|
| 601 |
+
assert key_padding_mask.size(1) == src_len
|
| 602 |
+
|
| 603 |
+
if self.add_zero_attn:
|
| 604 |
+
assert v is not None
|
| 605 |
+
src_len += 1
|
| 606 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
| 607 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
| 608 |
+
if attn_mask is not None:
|
| 609 |
+
attn_mask = torch.cat(
|
| 610 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 611 |
+
)
|
| 612 |
+
if key_padding_mask is not None:
|
| 613 |
+
key_padding_mask = torch.cat(
|
| 614 |
+
[
|
| 615 |
+
key_padding_mask,
|
| 616 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
| 617 |
+
key_padding_mask
|
| 618 |
+
),
|
| 619 |
+
],
|
| 620 |
+
dim=1,
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 624 |
+
attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
|
| 625 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 626 |
+
|
| 627 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 628 |
+
|
| 629 |
+
if attn_mask is not None:
|
| 630 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 631 |
+
attn_weights += attn_mask
|
| 632 |
+
|
| 633 |
+
if key_padding_mask is not None:
|
| 634 |
+
# don't attend to padding symbols
|
| 635 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 636 |
+
if not is_tpu:
|
| 637 |
+
attn_weights = attn_weights.masked_fill(
|
| 638 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
| 639 |
+
float("-inf"),
|
| 640 |
+
)
|
| 641 |
+
else:
|
| 642 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 643 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
| 644 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 645 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 646 |
+
|
| 647 |
+
if before_softmax:
|
| 648 |
+
return attn_weights, v, position_bias
|
| 649 |
+
|
| 650 |
+
if position_bias is not None:
|
| 651 |
+
attn_mask_rel_pos = position_bias
|
| 652 |
+
if self.gru_rel_pos == 1:
|
| 653 |
+
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
|
| 654 |
+
_B, _H, _L, __ = query_layer.size()
|
| 655 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
| 656 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
| 657 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
| 658 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
|
| 659 |
+
|
| 660 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
| 661 |
+
|
| 662 |
+
attn_weights = attn_weights + attn_mask_rel_pos
|
| 663 |
+
|
| 664 |
+
attn_weights_float = F.softmax(
|
| 665 |
+
attn_weights, dim=-1
|
| 666 |
+
)
|
| 667 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 668 |
+
attn_probs = self.dropout_module(attn_weights)
|
| 669 |
+
|
| 670 |
+
assert v is not None
|
| 671 |
+
attn = torch.bmm(attn_probs, v)
|
| 672 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 673 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 674 |
+
attn = self.out_proj(attn)
|
| 675 |
+
attn_weights: Optional[Tensor] = None
|
| 676 |
+
if need_weights:
|
| 677 |
+
attn_weights = attn_weights_float.view(
|
| 678 |
+
bsz, self.num_heads, tgt_len, src_len
|
| 679 |
+
).transpose(1, 0)
|
| 680 |
+
if not need_head_weights:
|
| 681 |
+
# average attention weights over heads
|
| 682 |
+
attn_weights = attn_weights.mean(dim=0)
|
| 683 |
+
|
| 684 |
+
return attn, attn_weights, position_bias
|
| 685 |
+
|
| 686 |
+
@staticmethod
|
| 687 |
+
def _append_prev_key_padding_mask(
|
| 688 |
+
key_padding_mask: Optional[Tensor],
|
| 689 |
+
prev_key_padding_mask: Optional[Tensor],
|
| 690 |
+
batch_size: int,
|
| 691 |
+
src_len: int,
|
| 692 |
+
static_kv: bool,
|
| 693 |
+
) -> Optional[Tensor]:
|
| 694 |
+
# saved key padding masks have shape (bsz, seq_len)
|
| 695 |
+
if prev_key_padding_mask is not None and static_kv:
|
| 696 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 697 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
| 698 |
+
new_key_padding_mask = torch.cat(
|
| 699 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
| 700 |
+
)
|
| 701 |
+
# During incremental decoding, as the padding token enters and
|
| 702 |
+
# leaves the frame, there will be a time when prev or current
|
| 703 |
+
# is None
|
| 704 |
+
elif prev_key_padding_mask is not None:
|
| 705 |
+
if src_len > prev_key_padding_mask.size(1):
|
| 706 |
+
filler = torch.zeros(
|
| 707 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
| 708 |
+
device=prev_key_padding_mask.device,
|
| 709 |
+
)
|
| 710 |
+
new_key_padding_mask = torch.cat(
|
| 711 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
| 712 |
+
)
|
| 713 |
+
else:
|
| 714 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
| 715 |
+
elif key_padding_mask is not None:
|
| 716 |
+
if src_len > key_padding_mask.size(1):
|
| 717 |
+
filler = torch.zeros(
|
| 718 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
| 719 |
+
device=key_padding_mask.device,
|
| 720 |
+
)
|
| 721 |
+
new_key_padding_mask = torch.cat(
|
| 722 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
| 723 |
+
)
|
| 724 |
+
else:
|
| 725 |
+
new_key_padding_mask = key_padding_mask.float()
|
| 726 |
+
else:
|
| 727 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 728 |
+
return new_key_padding_mask
|
| 729 |
+
|
| 730 |
+
def _get_input_buffer(
|
| 731 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
| 732 |
+
) -> Dict[str, Optional[Tensor]]:
|
| 733 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
| 734 |
+
if result is not None:
|
| 735 |
+
return result
|
| 736 |
+
else:
|
| 737 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
| 738 |
+
return empty_result
|
| 739 |
+
|
| 740 |
+
def _set_input_buffer(
|
| 741 |
+
self,
|
| 742 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
| 743 |
+
buffer: Dict[str, Optional[Tensor]],
|
| 744 |
+
):
|
| 745 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 746 |
+
|
| 747 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
| 748 |
+
return attn_weights
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def init_bert_params(module):
|
| 752 |
+
"""
|
| 753 |
+
Initialize the weights specific to the BERT Model.
|
| 754 |
+
This overrides the default initializations depending on the specified arguments.
|
| 755 |
+
1. If normal_init_linear_weights is set then weights of linear
|
| 756 |
+
layer will be initialized using the normal distribution and
|
| 757 |
+
bais will be set to the specified value.
|
| 758 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
| 759 |
+
layer will be initialized using the normal distribution.
|
| 760 |
+
3. If normal_init_proj_weights is set then weights of
|
| 761 |
+
in_project_weight for MultiHeadAttention initialized using
|
| 762 |
+
the normal distribution (to be validated).
|
| 763 |
+
"""
|
| 764 |
+
|
| 765 |
+
def normal_(data):
|
| 766 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
| 767 |
+
# so that the RNG is consistent with and without FSDP
|
| 768 |
+
data.copy_(
|
| 769 |
+
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
if isinstance(module, nn.Linear):
|
| 773 |
+
normal_(module.weight.data)
|
| 774 |
+
if module.bias is not None:
|
| 775 |
+
module.bias.data.zero_()
|
| 776 |
+
if isinstance(module, nn.Embedding):
|
| 777 |
+
normal_(module.weight.data)
|
| 778 |
+
if module.padding_idx is not None:
|
| 779 |
+
module.weight.data[module.padding_idx].zero_()
|
| 780 |
+
if isinstance(module, MultiheadAttention):
|
| 781 |
+
normal_(module.q_proj.weight.data)
|
| 782 |
+
normal_(module.k_proj.weight.data)
|
| 783 |
+
normal_(module.v_proj.weight.data)
|
models/beats/modules.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import warnings
|
| 12 |
+
import torch
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GradMultiply(torch.autograd.Function):
|
| 18 |
+
@staticmethod
|
| 19 |
+
def forward(ctx, x, scale):
|
| 20 |
+
ctx.scale = scale
|
| 21 |
+
res = x.new(x)
|
| 22 |
+
return res
|
| 23 |
+
|
| 24 |
+
@staticmethod
|
| 25 |
+
def backward(ctx, grad):
|
| 26 |
+
return grad * ctx.scale, None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SamePad(nn.Module):
|
| 30 |
+
def __init__(self, kernel_size, causal=False):
|
| 31 |
+
super().__init__()
|
| 32 |
+
if causal:
|
| 33 |
+
self.remove = kernel_size - 1
|
| 34 |
+
else:
|
| 35 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
if self.remove > 0:
|
| 39 |
+
x = x[:, :, : -self.remove]
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Swish(nn.Module):
|
| 44 |
+
def __init__(self):
|
| 45 |
+
super(Swish, self).__init__()
|
| 46 |
+
self.act = torch.nn.Sigmoid()
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
return x * self.act(x)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class GLU_Linear(nn.Module):
|
| 53 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
| 54 |
+
super(GLU_Linear, self).__init__()
|
| 55 |
+
|
| 56 |
+
self.glu_type = glu_type
|
| 57 |
+
self.output_dim = output_dim
|
| 58 |
+
|
| 59 |
+
if glu_type == "sigmoid":
|
| 60 |
+
self.glu_act = torch.nn.Sigmoid()
|
| 61 |
+
elif glu_type == "swish":
|
| 62 |
+
self.glu_act = Swish()
|
| 63 |
+
elif glu_type == "relu":
|
| 64 |
+
self.glu_act = torch.nn.ReLU()
|
| 65 |
+
elif glu_type == "gelu":
|
| 66 |
+
self.glu_act = torch.nn.GELU()
|
| 67 |
+
|
| 68 |
+
if bias_in_glu:
|
| 69 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
| 70 |
+
else:
|
| 71 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
| 75 |
+
x = self.linear(x)
|
| 76 |
+
|
| 77 |
+
if self.glu_type == "bilinear":
|
| 78 |
+
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
| 79 |
+
else:
|
| 80 |
+
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
| 81 |
+
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def gelu_accurate(x):
|
| 86 |
+
if not hasattr(gelu_accurate, "_a"):
|
| 87 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
| 88 |
+
return (
|
| 89 |
+
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
| 94 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_activation_fn(activation: str):
|
| 98 |
+
"""Returns the activation function corresponding to `activation`"""
|
| 99 |
+
|
| 100 |
+
if activation == "relu":
|
| 101 |
+
return F.relu
|
| 102 |
+
elif activation == "gelu":
|
| 103 |
+
return gelu
|
| 104 |
+
elif activation == "gelu_fast":
|
| 105 |
+
warnings.warn(
|
| 106 |
+
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
| 107 |
+
)
|
| 108 |
+
return gelu_accurate
|
| 109 |
+
elif activation == "gelu_accurate":
|
| 110 |
+
return gelu_accurate
|
| 111 |
+
elif activation == "tanh":
|
| 112 |
+
return torch.tanh
|
| 113 |
+
elif activation == "linear":
|
| 114 |
+
return lambda x: x
|
| 115 |
+
elif activation == "glu":
|
| 116 |
+
return lambda x: x
|
| 117 |
+
else:
|
| 118 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def quant_noise(module, p, block_size):
|
| 122 |
+
"""
|
| 123 |
+
Wraps modules and applies quantization noise to the weights for
|
| 124 |
+
subsequent quantization with Iterative Product Quantization as
|
| 125 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
- module: nn.Module
|
| 129 |
+
- p: amount of Quantization Noise
|
| 130 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
| 131 |
+
|
| 132 |
+
Remarks:
|
| 133 |
+
- Module weights must have the right sizes wrt the block size
|
| 134 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
| 135 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
| 136 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
| 137 |
+
- We implement the simplest form of noise here as stated in the paper
|
| 138 |
+
which consists in randomly dropping blocks
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
# if no quantization noise, don't register hook
|
| 142 |
+
if p <= 0:
|
| 143 |
+
return module
|
| 144 |
+
|
| 145 |
+
# supported modules
|
| 146 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
| 147 |
+
|
| 148 |
+
# test whether module.weight has the right sizes wrt block_size
|
| 149 |
+
is_conv = module.weight.ndim == 4
|
| 150 |
+
|
| 151 |
+
# 2D matrix
|
| 152 |
+
if not is_conv:
|
| 153 |
+
assert (
|
| 154 |
+
module.weight.size(1) % block_size == 0
|
| 155 |
+
), "Input features must be a multiple of block sizes"
|
| 156 |
+
|
| 157 |
+
# 4D matrix
|
| 158 |
+
else:
|
| 159 |
+
# 1x1 convolutions
|
| 160 |
+
if module.kernel_size == (1, 1):
|
| 161 |
+
assert (
|
| 162 |
+
module.in_channels % block_size == 0
|
| 163 |
+
), "Input channels must be a multiple of block sizes"
|
| 164 |
+
# regular convolutions
|
| 165 |
+
else:
|
| 166 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
| 167 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
| 168 |
+
|
| 169 |
+
def _forward_pre_hook(mod, input):
|
| 170 |
+
# no noise for evaluation
|
| 171 |
+
if mod.training:
|
| 172 |
+
if not is_conv:
|
| 173 |
+
# gather weight and sizes
|
| 174 |
+
weight = mod.weight
|
| 175 |
+
in_features = weight.size(1)
|
| 176 |
+
out_features = weight.size(0)
|
| 177 |
+
|
| 178 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 179 |
+
mask = torch.zeros(
|
| 180 |
+
in_features // block_size * out_features, device=weight.device
|
| 181 |
+
)
|
| 182 |
+
mask.bernoulli_(p)
|
| 183 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
| 184 |
+
|
| 185 |
+
else:
|
| 186 |
+
# gather weight and sizes
|
| 187 |
+
weight = mod.weight
|
| 188 |
+
in_channels = mod.in_channels
|
| 189 |
+
out_channels = mod.out_channels
|
| 190 |
+
|
| 191 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 192 |
+
if mod.kernel_size == (1, 1):
|
| 193 |
+
mask = torch.zeros(
|
| 194 |
+
int(in_channels // block_size * out_channels),
|
| 195 |
+
device=weight.device,
|
| 196 |
+
)
|
| 197 |
+
mask.bernoulli_(p)
|
| 198 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
| 199 |
+
else:
|
| 200 |
+
mask = torch.zeros(
|
| 201 |
+
weight.size(0), weight.size(1), device=weight.device
|
| 202 |
+
)
|
| 203 |
+
mask.bernoulli_(p)
|
| 204 |
+
mask = (
|
| 205 |
+
mask.unsqueeze(2)
|
| 206 |
+
.unsqueeze(3)
|
| 207 |
+
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# scale weights and apply mask
|
| 211 |
+
mask = mask.to(
|
| 212 |
+
torch.bool
|
| 213 |
+
) # x.bool() is not currently supported in TorchScript
|
| 214 |
+
s = 1 / (1 - p)
|
| 215 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
| 216 |
+
|
| 217 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
| 218 |
+
return module
|
models/beats/quantizer.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on VQGAN code bases
|
| 7 |
+
# https://github.com/CompVis/taming-transformers
|
| 8 |
+
# --------------------------------------------------------'
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.distributed as distributed
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
except ImportError:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def l2norm(t):
|
| 22 |
+
return F.normalize(t, p=2, dim=-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ema_inplace(moving_avg, new, decay):
|
| 26 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def sample_vectors(samples, num):
|
| 30 |
+
num_samples, device = samples.shape[0], samples.device
|
| 31 |
+
|
| 32 |
+
if num_samples >= num:
|
| 33 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
| 34 |
+
else:
|
| 35 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 36 |
+
|
| 37 |
+
return samples[indices]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
| 41 |
+
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
| 42 |
+
|
| 43 |
+
means = sample_vectors(samples, num_clusters)
|
| 44 |
+
|
| 45 |
+
for _ in range(num_iters):
|
| 46 |
+
if use_cosine_sim:
|
| 47 |
+
dists = samples @ means.t()
|
| 48 |
+
else:
|
| 49 |
+
diffs = rearrange(samples, 'n d -> n () d') \
|
| 50 |
+
- rearrange(means, 'c d -> () c d')
|
| 51 |
+
dists = -(diffs ** 2).sum(dim=-1)
|
| 52 |
+
|
| 53 |
+
buckets = dists.max(dim=-1).indices
|
| 54 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
| 55 |
+
zero_mask = bins == 0
|
| 56 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
| 57 |
+
|
| 58 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
| 59 |
+
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
|
| 60 |
+
new_means = new_means / bins_min_clamped[..., None]
|
| 61 |
+
|
| 62 |
+
if use_cosine_sim:
|
| 63 |
+
new_means = l2norm(new_means)
|
| 64 |
+
|
| 65 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
| 66 |
+
|
| 67 |
+
return means, bins
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class EmbeddingEMA(nn.Module):
|
| 71 |
+
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.num_tokens = num_tokens
|
| 74 |
+
self.codebook_dim = codebook_dim
|
| 75 |
+
self.decay = decay
|
| 76 |
+
self.eps = eps
|
| 77 |
+
if codebook_init_path == '':
|
| 78 |
+
if not kmeans_init:
|
| 79 |
+
weight = torch.randn(num_tokens, codebook_dim)
|
| 80 |
+
weight = l2norm(weight)
|
| 81 |
+
else:
|
| 82 |
+
weight = torch.zeros(num_tokens, codebook_dim)
|
| 83 |
+
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
| 84 |
+
else:
|
| 85 |
+
print(f"load init codebook weight from {codebook_init_path}")
|
| 86 |
+
codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
|
| 87 |
+
weight = codebook_ckpt_weight.clone()
|
| 88 |
+
self.register_buffer('initted', torch.Tensor([True]))
|
| 89 |
+
|
| 90 |
+
self.weight = nn.Parameter(weight, requires_grad=False)
|
| 91 |
+
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
|
| 92 |
+
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
| 93 |
+
# self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
| 94 |
+
self.update = True
|
| 95 |
+
|
| 96 |
+
@torch.jit.ignore
|
| 97 |
+
def init_embed_(self, data):
|
| 98 |
+
if self.initted:
|
| 99 |
+
return
|
| 100 |
+
print("Performing Kemans init for codebook")
|
| 101 |
+
embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
|
| 102 |
+
self.weight.data.copy_(embed)
|
| 103 |
+
self.cluster_size.data.copy_(cluster_size)
|
| 104 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
| 105 |
+
|
| 106 |
+
def forward(self, embed_id):
|
| 107 |
+
return F.embedding(embed_id, self.weight)
|
| 108 |
+
|
| 109 |
+
def cluster_size_ema_update(self, new_cluster_size):
|
| 110 |
+
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
|
| 111 |
+
|
| 112 |
+
def embed_avg_ema_update(self, new_embed_avg):
|
| 113 |
+
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
| 114 |
+
|
| 115 |
+
def weight_update(self, num_tokens):
|
| 116 |
+
n = self.cluster_size.sum()
|
| 117 |
+
smoothed_cluster_size = (
|
| 118 |
+
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
| 119 |
+
)
|
| 120 |
+
# normalize embedding average with smoothed cluster size
|
| 121 |
+
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
| 122 |
+
# embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
|
| 123 |
+
self.weight.data.copy_(embed_normalized)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def norm_ema_inplace(moving_avg, new, decay):
|
| 127 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 128 |
+
moving_avg.data.copy_(l2norm(moving_avg.data))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class NormEMAVectorQuantizer(nn.Module):
|
| 132 |
+
def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
|
| 133 |
+
statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.codebook_dim = embedding_dim
|
| 136 |
+
self.num_tokens = n_embed
|
| 137 |
+
self.beta = beta
|
| 138 |
+
self.decay = decay
|
| 139 |
+
|
| 140 |
+
# learnable = True if orthogonal_reg_weight > 0 else False
|
| 141 |
+
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
|
| 142 |
+
|
| 143 |
+
self.statistic_code_usage = statistic_code_usage
|
| 144 |
+
if statistic_code_usage:
|
| 145 |
+
self.register_buffer('cluster_size', torch.zeros(n_embed))
|
| 146 |
+
if distributed.is_available() and distributed.is_initialized():
|
| 147 |
+
print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
|
| 148 |
+
self.all_reduce_fn = distributed.all_reduce
|
| 149 |
+
else:
|
| 150 |
+
self.all_reduce_fn = nn.Identity()
|
| 151 |
+
|
| 152 |
+
def reset_cluster_size(self, device):
|
| 153 |
+
if self.statistic_code_usage:
|
| 154 |
+
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
|
| 155 |
+
self.cluster_size = self.cluster_size.to(device)
|
| 156 |
+
|
| 157 |
+
def forward(self, z):
|
| 158 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 159 |
+
# z, 'b c h w -> b h w c'
|
| 160 |
+
# z = rearrange(z, 'b c h w -> b h w c')
|
| 161 |
+
# z = z.transpose(1, 2)
|
| 162 |
+
z = l2norm(z)
|
| 163 |
+
z_flattened = z.reshape(-1, self.codebook_dim)
|
| 164 |
+
|
| 165 |
+
self.embedding.init_embed_(z_flattened)
|
| 166 |
+
|
| 167 |
+
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
|
| 168 |
+
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
|
| 169 |
+
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
|
| 170 |
+
|
| 171 |
+
encoding_indices = torch.argmin(d, dim=1)
|
| 172 |
+
|
| 173 |
+
z_q = self.embedding(encoding_indices).view(z.shape)
|
| 174 |
+
|
| 175 |
+
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
| 176 |
+
|
| 177 |
+
if not self.training:
|
| 178 |
+
with torch.no_grad():
|
| 179 |
+
cluster_size = encodings.sum(0)
|
| 180 |
+
self.all_reduce_fn(cluster_size)
|
| 181 |
+
ema_inplace(self.cluster_size, cluster_size, self.decay)
|
| 182 |
+
|
| 183 |
+
if self.training and self.embedding.update:
|
| 184 |
+
# EMA cluster size
|
| 185 |
+
|
| 186 |
+
bins = encodings.sum(0)
|
| 187 |
+
self.all_reduce_fn(bins)
|
| 188 |
+
|
| 189 |
+
# self.embedding.cluster_size_ema_update(bins)
|
| 190 |
+
ema_inplace(self.cluster_size, bins, self.decay)
|
| 191 |
+
|
| 192 |
+
zero_mask = (bins == 0)
|
| 193 |
+
bins = bins.masked_fill(zero_mask, 1.)
|
| 194 |
+
|
| 195 |
+
embed_sum = z_flattened.t() @ encodings
|
| 196 |
+
self.all_reduce_fn(embed_sum)
|
| 197 |
+
|
| 198 |
+
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
|
| 199 |
+
embed_normalized = l2norm(embed_normalized)
|
| 200 |
+
|
| 201 |
+
embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
|
| 202 |
+
embed_normalized)
|
| 203 |
+
norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
|
| 204 |
+
|
| 205 |
+
# compute loss for embedding
|
| 206 |
+
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
| 207 |
+
|
| 208 |
+
# preserve gradients
|
| 209 |
+
z_q = z + (z_q - z).detach()
|
| 210 |
+
|
| 211 |
+
# reshape back to match original input shape
|
| 212 |
+
# z_q, 'b h w c -> b c h w'
|
| 213 |
+
# z_q = rearrange(z_q, 'b h w c -> b c h w')
|
| 214 |
+
# z_q = z_q.transpose(1, 2)
|
| 215 |
+
return z_q, loss, encoding_indices
|
models/frame_mn/Frame_MN_wrapper.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.frame_passt.preprocess import AugmentMelSTFT
|
| 2 |
+
from models.transformer_wrapper import BaseModelWrapper
|
| 3 |
+
from models.frame_mn.model import get_model
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FrameMNWrapper(BaseModelWrapper):
|
| 7 |
+
def __init__(self, width_mult=1.0) -> None:
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.mel = AugmentMelSTFT(
|
| 10 |
+
n_mels=128,
|
| 11 |
+
sr=16_000,
|
| 12 |
+
win_length=400,
|
| 13 |
+
hopsize=160,
|
| 14 |
+
n_fft=512,
|
| 15 |
+
freqm=0,
|
| 16 |
+
timem=0,
|
| 17 |
+
htk=False,
|
| 18 |
+
fmin=0.0,
|
| 19 |
+
fmax=None,
|
| 20 |
+
norm=1,
|
| 21 |
+
fmin_aug_range=10,
|
| 22 |
+
fmax_aug_range=2000,
|
| 23 |
+
fast_norm=True,
|
| 24 |
+
preamp=True,
|
| 25 |
+
padding="center",
|
| 26 |
+
periodic_window=False,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.frame_mn = get_model(
|
| 30 |
+
width_mult=width_mult
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
def mel_forward(self, x):
|
| 34 |
+
return self.mel(x)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
return self.frame_mn(x)
|
| 38 |
+
|
| 39 |
+
def separate_params(self):
|
| 40 |
+
pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
|
| 41 |
+
for k, p in self.named_parameters():
|
| 42 |
+
if any(['cls_token' in k,
|
| 43 |
+
'pos_embed' in k,
|
| 44 |
+
'norm_stats' in k,
|
| 45 |
+
'patch_embed' in k]):
|
| 46 |
+
pt_params[0].append(p)
|
| 47 |
+
elif 'blocks.0.' in k:
|
| 48 |
+
pt_params[0].append(p)
|
| 49 |
+
elif 'blocks.1.' in k:
|
| 50 |
+
pt_params[1].append(p)
|
| 51 |
+
elif 'blocks.2.' in k:
|
| 52 |
+
pt_params[2].append(p)
|
| 53 |
+
elif 'blocks.3.' in k:
|
| 54 |
+
pt_params[3].append(p)
|
| 55 |
+
elif 'blocks.4.' in k:
|
| 56 |
+
pt_params[4].append(p)
|
| 57 |
+
elif 'blocks.5.' in k:
|
| 58 |
+
pt_params[5].append(p)
|
| 59 |
+
elif 'blocks.6.' in k:
|
| 60 |
+
pt_params[6].append(p)
|
| 61 |
+
elif 'blocks.7.' in k:
|
| 62 |
+
pt_params[7].append(p)
|
| 63 |
+
elif 'blocks.8.' in k:
|
| 64 |
+
pt_params[8].append(p)
|
| 65 |
+
elif 'blocks.9.' in k:
|
| 66 |
+
pt_params[9].append(p)
|
| 67 |
+
elif 'blocks.10.' in k:
|
| 68 |
+
pt_params[10].append(p)
|
| 69 |
+
elif 'blocks.11.' in k:
|
| 70 |
+
pt_params[11].append(p)
|
| 71 |
+
elif 'asit.norm.weight' in k or 'asit.norm.bias' in k:
|
| 72 |
+
pt_params[11].append(p)
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError(f"Check separate params for ASiT! Unknown key: {k}")
|
| 75 |
+
return list(reversed(pt_params))
|
models/frame_mn/block_types.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Callable, List
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torchvision.ops.misc import ConvNormActivation
|
| 6 |
+
|
| 7 |
+
from models.frame_mn.utils import make_divisible, cnn_out_size
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ConcurrentSEBlock(torch.nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
c_dim: int,
|
| 14 |
+
f_dim: int,
|
| 15 |
+
t_dim: int,
|
| 16 |
+
se_cnf: Dict
|
| 17 |
+
) -> None:
|
| 18 |
+
super().__init__()
|
| 19 |
+
dims = [c_dim, f_dim, t_dim]
|
| 20 |
+
self.conc_se_layers = nn.ModuleList()
|
| 21 |
+
for d in se_cnf['se_dims']:
|
| 22 |
+
input_dim = dims[d-1]
|
| 23 |
+
squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8)
|
| 24 |
+
self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
|
| 25 |
+
if se_cnf['se_agg'] == "max":
|
| 26 |
+
self.agg_op = lambda x: torch.max(x, dim=0)[0]
|
| 27 |
+
elif se_cnf['se_agg'] == "avg":
|
| 28 |
+
self.agg_op = lambda x: torch.mean(x, dim=0)
|
| 29 |
+
elif se_cnf['se_agg'] == "add":
|
| 30 |
+
self.agg_op = lambda x: torch.sum(x, dim=0)
|
| 31 |
+
elif se_cnf['se_agg'] == "min":
|
| 32 |
+
self.agg_op = lambda x: torch.min(x, dim=0)[0]
|
| 33 |
+
else:
|
| 34 |
+
raise NotImplementedError(f"SE aggregation operation '{self.agg_op}' not implemented")
|
| 35 |
+
|
| 36 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 37 |
+
# apply all concurrent se layers
|
| 38 |
+
se_outs = []
|
| 39 |
+
for se_layer in self.conc_se_layers:
|
| 40 |
+
se_outs.append(se_layer(input))
|
| 41 |
+
out = self.agg_op(torch.stack(se_outs, dim=0))
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SqueezeExcitation(torch.nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507.
|
| 48 |
+
Args:
|
| 49 |
+
input_dim (int): Input dimension
|
| 50 |
+
squeeze_dim (int): Size of Bottleneck
|
| 51 |
+
activation (Callable): activation applied to bottleneck
|
| 52 |
+
scale_activation (Callable): activation applied to the output
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
input_dim: int,
|
| 58 |
+
squeeze_dim: int,
|
| 59 |
+
se_dim: int,
|
| 60 |
+
activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
|
| 61 |
+
scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
|
| 62 |
+
) -> None:
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
|
| 65 |
+
self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
|
| 66 |
+
assert se_dim in [1, 2, 3]
|
| 67 |
+
self.se_dim = [1, 2, 3]
|
| 68 |
+
self.se_dim.remove(se_dim)
|
| 69 |
+
self.activation = activation()
|
| 70 |
+
self.scale_activation = scale_activation()
|
| 71 |
+
|
| 72 |
+
def _scale(self, input: Tensor) -> Tensor:
|
| 73 |
+
scale = torch.mean(input, self.se_dim, keepdim=True)
|
| 74 |
+
shape = scale.size()
|
| 75 |
+
scale = self.fc1(scale.squeeze(2).squeeze(2))
|
| 76 |
+
scale = self.activation(scale)
|
| 77 |
+
scale = self.fc2(scale)
|
| 78 |
+
scale = scale
|
| 79 |
+
return self.scale_activation(scale).view(shape)
|
| 80 |
+
|
| 81 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 82 |
+
scale = self._scale(input)
|
| 83 |
+
return scale * input
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class InvertedResidualConfig:
|
| 87 |
+
# Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
input_channels: int,
|
| 91 |
+
kernel: int,
|
| 92 |
+
expanded_channels: int,
|
| 93 |
+
out_channels: int,
|
| 94 |
+
use_se: bool,
|
| 95 |
+
activation: str,
|
| 96 |
+
stride: tuple[int],
|
| 97 |
+
dilation: tuple[int],
|
| 98 |
+
width_mult: float,
|
| 99 |
+
):
|
| 100 |
+
self.input_channels = self.adjust_channels(input_channels, width_mult)
|
| 101 |
+
self.kernel = kernel
|
| 102 |
+
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
|
| 103 |
+
self.out_channels = self.adjust_channels(out_channels, width_mult)
|
| 104 |
+
self.use_se = use_se
|
| 105 |
+
self.use_hs = activation == "HS"
|
| 106 |
+
self.stride = stride
|
| 107 |
+
self.dilation = dilation
|
| 108 |
+
self.f_dim = None
|
| 109 |
+
self.t_dim = None
|
| 110 |
+
|
| 111 |
+
@staticmethod
|
| 112 |
+
def adjust_channels(channels: int, width_mult: float):
|
| 113 |
+
return make_divisible(channels * width_mult, 8)
|
| 114 |
+
|
| 115 |
+
def out_size(self, in_size, idx=None):
|
| 116 |
+
dilation = self.dilation if idx is None else self.dilation[idx]
|
| 117 |
+
padding = (self.kernel - 1) // 2 * dilation
|
| 118 |
+
stride = self.stride if idx is None else self.stride[idx]
|
| 119 |
+
return cnn_out_size(in_size, padding, dilation, self.kernel, stride)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class InvertedResidual(nn.Module):
|
| 123 |
+
def __init__(
|
| 124 |
+
self,
|
| 125 |
+
cnf: InvertedResidualConfig,
|
| 126 |
+
se_cnf: Dict,
|
| 127 |
+
norm_layer: Callable[..., nn.Module],
|
| 128 |
+
depthwise_norm_layer: Callable[..., nn.Module]
|
| 129 |
+
):
|
| 130 |
+
super().__init__()
|
| 131 |
+
|
| 132 |
+
if not (1 <= cnf.stride[0] <= 2 or 1 <= cnf.stride[1] <= 2):
|
| 133 |
+
raise ValueError("illegal stride value")
|
| 134 |
+
|
| 135 |
+
self.use_res_connect = cnf.stride[0] == 1 and cnf.stride[1] == 1 and cnf.input_channels == cnf.out_channels
|
| 136 |
+
|
| 137 |
+
layers: List[nn.Module] = []
|
| 138 |
+
activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
|
| 139 |
+
|
| 140 |
+
# expand
|
| 141 |
+
if cnf.expanded_channels != cnf.input_channels:
|
| 142 |
+
layers.append(
|
| 143 |
+
ConvNormActivation(
|
| 144 |
+
cnf.input_channels,
|
| 145 |
+
cnf.expanded_channels,
|
| 146 |
+
kernel_size=1,
|
| 147 |
+
norm_layer=norm_layer,
|
| 148 |
+
activation_layer=activation_layer,
|
| 149 |
+
)
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# depthwise
|
| 153 |
+
d = cnf.dilation > 1 if isinstance(cnf.dilation, int) else cnf.dilation[1] > 1
|
| 154 |
+
stride = [cnf.stride, cnf.stride] if isinstance(cnf.stride, int) else list(cnf.stride)
|
| 155 |
+
|
| 156 |
+
if d:
|
| 157 |
+
stride[1] = 1
|
| 158 |
+
|
| 159 |
+
layers.append(
|
| 160 |
+
ConvNormActivation(
|
| 161 |
+
cnf.expanded_channels,
|
| 162 |
+
cnf.expanded_channels,
|
| 163 |
+
kernel_size=cnf.kernel,
|
| 164 |
+
stride=tuple(stride),
|
| 165 |
+
dilation=cnf.dilation,
|
| 166 |
+
groups=cnf.expanded_channels,
|
| 167 |
+
norm_layer=depthwise_norm_layer,
|
| 168 |
+
activation_layer=activation_layer,
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
if cnf.use_se and se_cnf['se_dims'] is not None:
|
| 172 |
+
layers.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))
|
| 173 |
+
|
| 174 |
+
# project
|
| 175 |
+
layers.append(
|
| 176 |
+
ConvNormActivation(
|
| 177 |
+
cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
|
| 178 |
+
)
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
self.block = nn.Sequential(*layers)
|
| 182 |
+
self.out_channels = cnf.out_channels
|
| 183 |
+
# self._is_cn = cnf.stride[0] > 1 and cnf.stride[1] > 1
|
| 184 |
+
|
| 185 |
+
def forward(self, inp: Tensor) -> Tensor:
|
| 186 |
+
result = self.block(inp)
|
| 187 |
+
if self.use_res_connect:
|
| 188 |
+
result += inp
|
| 189 |
+
return result
|
models/frame_mn/model.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import urllib.parse
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn, Tensor
|
| 8 |
+
from torch.hub import load_state_dict_from_url
|
| 9 |
+
from torchvision.ops.misc import ConvNormActivation
|
| 10 |
+
|
| 11 |
+
from models.frame_mn.block_types import InvertedResidualConfig, InvertedResidual
|
| 12 |
+
from models.frame_mn.utils import cnn_out_size
|
| 13 |
+
|
| 14 |
+
# Adapted version of MobileNetV3 pytorch implementation
|
| 15 |
+
# https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py
|
| 16 |
+
|
| 17 |
+
# points to github releases
|
| 18 |
+
model_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/"
|
| 19 |
+
# folder to store downloaded models to
|
| 20 |
+
model_dir = "resources"
|
| 21 |
+
|
| 22 |
+
pretrained_models = {
|
| 23 |
+
# pytorch ImageNet pre-trained model
|
| 24 |
+
# own ImageNet pre-trained models will follow
|
| 25 |
+
# NOTE: for easy loading we provide the adapted state dict ready for AudioSet training (1 input channel,
|
| 26 |
+
# 527 output classes)
|
| 27 |
+
# NOTE: the classifier is just a random initialization, feature extractor (conv layers) is pre-trained
|
| 28 |
+
"mn10_im_pytorch": urllib.parse.urljoin(model_url, "mn10_im_pytorch.pt"),
|
| 29 |
+
# self-trained models on ImageNet
|
| 30 |
+
"mn01_im": urllib.parse.urljoin(model_url, "mn01_im.pt"),
|
| 31 |
+
"mn02_im": urllib.parse.urljoin(model_url, "mn02_im.pt"),
|
| 32 |
+
"mn04_im": urllib.parse.urljoin(model_url, "mn04_im.pt"),
|
| 33 |
+
"mn05_im": urllib.parse.urljoin(model_url, "mn05_im.pt"),
|
| 34 |
+
"mn06_im": urllib.parse.urljoin(model_url, "mn06_im.pt"),
|
| 35 |
+
"mn10_im": urllib.parse.urljoin(model_url, "mn10_im.pt"),
|
| 36 |
+
"mn20_im": urllib.parse.urljoin(model_url, "mn20_im.pt"),
|
| 37 |
+
"mn30_im": urllib.parse.urljoin(model_url, "mn30_im.pt"),
|
| 38 |
+
"mn40_im": urllib.parse.urljoin(model_url, "mn40_im.pt"),
|
| 39 |
+
# Models trained on AudioSet
|
| 40 |
+
"mn01_as": urllib.parse.urljoin(model_url, "mn01_as_mAP_298.pt"),
|
| 41 |
+
"mn02_as": urllib.parse.urljoin(model_url, "mn02_as_mAP_378.pt"),
|
| 42 |
+
"mn04_as": urllib.parse.urljoin(model_url, "mn04_as_mAP_432.pt"),
|
| 43 |
+
"mn05_as": urllib.parse.urljoin(model_url, "mn05_as_mAP_443.pt"),
|
| 44 |
+
"mn10_as": urllib.parse.urljoin(model_url, "mn10_as_mAP_471.pt"),
|
| 45 |
+
"mn20_as": urllib.parse.urljoin(model_url, "mn20_as_mAP_478.pt"),
|
| 46 |
+
"mn30_as": urllib.parse.urljoin(model_url, "mn30_as_mAP_482.pt"),
|
| 47 |
+
"mn40_as": urllib.parse.urljoin(model_url, "mn40_as_mAP_484.pt"),
|
| 48 |
+
"mn40_as(2)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483.pt"),
|
| 49 |
+
"mn40_as(3)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483(2).pt"),
|
| 50 |
+
"mn40_as_no_im_pre": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483.pt"),
|
| 51 |
+
"mn40_as_no_im_pre(2)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483(2).pt"),
|
| 52 |
+
"mn40_as_no_im_pre(3)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_482.pt"),
|
| 53 |
+
"mn40_as_ext": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_487.pt"),
|
| 54 |
+
"mn40_as_ext(2)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_486.pt"),
|
| 55 |
+
"mn40_as_ext(3)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_485.pt"),
|
| 56 |
+
# varying hop size (time resolution)
|
| 57 |
+
"mn10_as_hop_5": urllib.parse.urljoin(model_url, "mn10_as_hop_5_mAP_475.pt"),
|
| 58 |
+
"mn10_as_hop_15": urllib.parse.urljoin(model_url, "mn10_as_hop_15_mAP_463.pt"),
|
| 59 |
+
"mn10_as_hop_20": urllib.parse.urljoin(model_url, "mn10_as_hop_20_mAP_456.pt"),
|
| 60 |
+
"mn10_as_hop_25": urllib.parse.urljoin(model_url, "mn10_as_hop_25_mAP_447.pt"),
|
| 61 |
+
# varying n_mels (frequency resolution)
|
| 62 |
+
"mn10_as_mels_40": urllib.parse.urljoin(model_url, "mn10_as_mels_40_mAP_453.pt"),
|
| 63 |
+
"mn10_as_mels_64": urllib.parse.urljoin(model_url, "mn10_as_mels_64_mAP_461.pt"),
|
| 64 |
+
"mn10_as_mels_256": urllib.parse.urljoin(model_url, "mn10_as_mels_256_mAP_474.pt"),
|
| 65 |
+
# fully-convolutional head
|
| 66 |
+
"mn10_as_fc": urllib.parse.urljoin(model_url, "mn10_as_fc_mAP_465.pt"),
|
| 67 |
+
"mn10_as_fc_s2221": urllib.parse.urljoin(model_url, "mn10_as_fc_s2221_mAP_466.pt"),
|
| 68 |
+
"mn10_as_fc_s2211": urllib.parse.urljoin(model_url, "mn10_as_fc_s2211_mAP_466.pt"),
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MN(nn.Module):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
inverted_residual_setting: List[InvertedResidualConfig],
|
| 76 |
+
block: Optional[Callable[..., nn.Module]] = None,
|
| 77 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
| 78 |
+
in_conv_kernel: int = 3,
|
| 79 |
+
in_conv_stride: int = 2,
|
| 80 |
+
in_channels: int = 1,
|
| 81 |
+
**kwargs: Any,
|
| 82 |
+
) -> None:
|
| 83 |
+
"""
|
| 84 |
+
MobileNet V3 main class
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
inverted_residual_setting (List[InvertedResidualConfig]): Network structure
|
| 88 |
+
block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for models
|
| 89 |
+
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
|
| 90 |
+
in_conv_kernel (int): Size of kernel for first convolution
|
| 91 |
+
in_conv_stride (int): Size of stride for first convolution
|
| 92 |
+
in_channels (int): Number of input channels
|
| 93 |
+
"""
|
| 94 |
+
super(MN, self).__init__()
|
| 95 |
+
|
| 96 |
+
if not inverted_residual_setting:
|
| 97 |
+
raise ValueError("The inverted_residual_setting should not be empty")
|
| 98 |
+
elif not (
|
| 99 |
+
isinstance(inverted_residual_setting, Sequence)
|
| 100 |
+
and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
|
| 101 |
+
):
|
| 102 |
+
raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
|
| 103 |
+
|
| 104 |
+
if block is None:
|
| 105 |
+
block = InvertedResidual
|
| 106 |
+
|
| 107 |
+
depthwise_norm_layer = norm_layer = \
|
| 108 |
+
norm_layer if norm_layer is not None else partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
|
| 109 |
+
|
| 110 |
+
layers: List[nn.Module] = []
|
| 111 |
+
|
| 112 |
+
kernel_sizes = [in_conv_kernel]
|
| 113 |
+
strides = [in_conv_stride]
|
| 114 |
+
|
| 115 |
+
# building first layer
|
| 116 |
+
firstconv_output_channels = inverted_residual_setting[0].input_channels
|
| 117 |
+
layers.append(
|
| 118 |
+
ConvNormActivation(
|
| 119 |
+
in_channels,
|
| 120 |
+
firstconv_output_channels,
|
| 121 |
+
kernel_size=in_conv_kernel,
|
| 122 |
+
stride=in_conv_stride,
|
| 123 |
+
norm_layer=norm_layer,
|
| 124 |
+
activation_layer=nn.Hardswish,
|
| 125 |
+
)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# get squeeze excitation config
|
| 129 |
+
se_cnf = kwargs.get('se_conf', None)
|
| 130 |
+
|
| 131 |
+
# building inverted residual blocks
|
| 132 |
+
# - keep track of size of frequency and time dimensions for possible application of Squeeze-and-Excitation
|
| 133 |
+
# on the frequency/time dimension
|
| 134 |
+
# - applying Squeeze-and-Excitation on the time dimension is not recommended as this constrains the network to
|
| 135 |
+
# a particular length of the audio clip, whereas Squeeze-and-Excitation on the frequency bands is fine,
|
| 136 |
+
# as the number of frequency bands is usually not changing
|
| 137 |
+
f_dim, t_dim = kwargs.get('input_dims', (128, 1000))
|
| 138 |
+
# take into account first conv layer
|
| 139 |
+
f_dim = cnn_out_size(f_dim, 1, 1, 3, 2)
|
| 140 |
+
t_dim = cnn_out_size(t_dim, 1, 1, 3, 2)
|
| 141 |
+
for cnf in inverted_residual_setting:
|
| 142 |
+
f_dim = cnf.out_size(f_dim, idx=0)
|
| 143 |
+
t_dim = cnf.out_size(t_dim, idx=1)
|
| 144 |
+
cnf.f_dim, cnf.t_dim = f_dim, t_dim # update dimensions in block config
|
| 145 |
+
layers.append(block(cnf, se_cnf, norm_layer, depthwise_norm_layer))
|
| 146 |
+
kernel_sizes.append(cnf.kernel)
|
| 147 |
+
strides.append(cnf.stride)
|
| 148 |
+
|
| 149 |
+
# building last several layers
|
| 150 |
+
lastconv_input_channels = inverted_residual_setting[-1].out_channels
|
| 151 |
+
lastconv_output_channels = 6 * lastconv_input_channels
|
| 152 |
+
self.lastconv_output_channels = lastconv_output_channels
|
| 153 |
+
layers.append(
|
| 154 |
+
ConvNormActivation(
|
| 155 |
+
lastconv_input_channels,
|
| 156 |
+
lastconv_output_channels,
|
| 157 |
+
kernel_size=1,
|
| 158 |
+
norm_layer=norm_layer,
|
| 159 |
+
activation_layer=nn.Hardswish,
|
| 160 |
+
)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
self.features = nn.Sequential(*layers)
|
| 164 |
+
|
| 165 |
+
# no prediction head needed - we want to use Frame-MobileNet to extract a 3D sequence
|
| 166 |
+
# i.e.: batch size x sequence length x channel dimension
|
| 167 |
+
|
| 168 |
+
for m in self.modules():
|
| 169 |
+
if isinstance(m, nn.Conv2d):
|
| 170 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
| 171 |
+
if m.bias is not None:
|
| 172 |
+
nn.init.zeros_(m.bias)
|
| 173 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
|
| 174 |
+
nn.init.ones_(m.weight)
|
| 175 |
+
nn.init.zeros_(m.bias)
|
| 176 |
+
elif isinstance(m, nn.Linear):
|
| 177 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 178 |
+
if m.bias is not None:
|
| 179 |
+
nn.init.zeros_(m.bias)
|
| 180 |
+
|
| 181 |
+
def _forward_impl(self, x: Tensor, return_fmaps: bool = False) -> Tensor:
|
| 182 |
+
fmaps = []
|
| 183 |
+
|
| 184 |
+
for i, layer in enumerate(self.features):
|
| 185 |
+
x = layer(x)
|
| 186 |
+
if return_fmaps:
|
| 187 |
+
fmaps.append(x)
|
| 188 |
+
|
| 189 |
+
# reshape: batch size x channels x frequency bands x time -> batch size x time x channels
|
| 190 |
+
# works, because frequency dimension is exactly 1
|
| 191 |
+
x = x.squeeze(2).permute(0, 2, 1)
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 195 |
+
return self._forward_impl(x)
|
| 196 |
+
|
| 197 |
+
def load_model(self, path, wandb_id):
|
| 198 |
+
ckpt_path = os.path.join(path, wandb_id + ".ckpt")
|
| 199 |
+
|
| 200 |
+
pretrained_weights = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
| 201 |
+
pretrained_weights = {k[10:]: v for k, v in pretrained_weights.items() if k[:10] == "net.model."}
|
| 202 |
+
self.load_state_dict(pretrained_weights)
|
| 203 |
+
|
| 204 |
+
print("Loaded model successfully. Wandb_id:", wandb_id)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _mobilenet_v3_conf(
|
| 208 |
+
width_mult: float = 1.0,
|
| 209 |
+
reduced_tail: bool = False,
|
| 210 |
+
dilated: bool = False,
|
| 211 |
+
strides: Tuple[int] = None,
|
| 212 |
+
dilation_list_t_dim: Optional[List[int]] = None,
|
| 213 |
+
**kwargs
|
| 214 |
+
):
|
| 215 |
+
reduce_divider = 2 if reduced_tail else 1
|
| 216 |
+
if dilation_list_t_dim is None:
|
| 217 |
+
dilation_list_t_dim = [1] * 15
|
| 218 |
+
if dilated:
|
| 219 |
+
dilation_list_t_dim[-3:] = [2] * 3
|
| 220 |
+
|
| 221 |
+
print("dilation_list_t_dim: ")
|
| 222 |
+
print(dilation_list_t_dim)
|
| 223 |
+
|
| 224 |
+
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
|
| 225 |
+
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
|
| 226 |
+
|
| 227 |
+
if strides is None:
|
| 228 |
+
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
|
| 229 |
+
f_strides = (1, 2, 2, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2)
|
| 230 |
+
t_strides = (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
|
| 231 |
+
|
| 232 |
+
strides = tuple(zip(f_strides, t_strides))
|
| 233 |
+
|
| 234 |
+
# InvertedResidualConfig:
|
| 235 |
+
# input_channels, kernel, expanded_channels, out_channels, use_se, activation, stride, dilation
|
| 236 |
+
inverted_residual_setting = [
|
| 237 |
+
bneck_conf(16, 3, 16, 16, False, "RE", strides[0], (1, dilation_list_t_dim[0])), # 0
|
| 238 |
+
bneck_conf(16, 3, 64, 24, False, "RE", strides[1], (1, dilation_list_t_dim[1])), # 1 - C1
|
| 239 |
+
bneck_conf(24, 3, 72, 24, False, "RE", strides[2], (1, dilation_list_t_dim[2])), # 2
|
| 240 |
+
bneck_conf(24, 5, 72, 40, True, "RE", strides[3], (1, dilation_list_t_dim[3])), # 3 - C2
|
| 241 |
+
bneck_conf(40, 5, 120, 40, True, "RE", strides[4], (1, dilation_list_t_dim[4])), # 4
|
| 242 |
+
bneck_conf(40, 5, 120, 40, True, "RE", strides[5], (1, dilation_list_t_dim[5])), # 5
|
| 243 |
+
bneck_conf(40, 3, 240, 80, False, "HS", strides[6], (1, dilation_list_t_dim[6])), # 6 - C3
|
| 244 |
+
bneck_conf(80, 3, 200, 80, False, "HS", strides[7], (1, dilation_list_t_dim[7])), # 7
|
| 245 |
+
bneck_conf(80, 3, 184, 80, False, "HS", strides[8], (1, dilation_list_t_dim[8])), # 8
|
| 246 |
+
bneck_conf(80, 3, 184, 80, False, "HS", strides[9], (1, dilation_list_t_dim[9])), # 9
|
| 247 |
+
bneck_conf(80, 3, 480, 112, True, "HS", strides[10], (1, dilation_list_t_dim[10])), # 10
|
| 248 |
+
bneck_conf(112, 3, 672, 112, True, "HS", strides[11], (1, dilation_list_t_dim[11])), # 11
|
| 249 |
+
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", strides[12], (1, dilation_list_t_dim[12])),
|
| 250 |
+
# 12 - C4 # dilation
|
| 251 |
+
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", strides[13],
|
| 252 |
+
(1, dilation_list_t_dim[13])), # 13 # dilation
|
| 253 |
+
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", strides[14],
|
| 254 |
+
(1, dilation_list_t_dim[14])), # 14 # dilation
|
| 255 |
+
]
|
| 256 |
+
last_channel = adjust_channels(1280 // reduce_divider)
|
| 257 |
+
|
| 258 |
+
return inverted_residual_setting, last_channel
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def _mobilenet_v3(
|
| 262 |
+
inverted_residual_setting: List[InvertedResidualConfig],
|
| 263 |
+
pretrained_name: str,
|
| 264 |
+
**kwargs: Any,
|
| 265 |
+
):
|
| 266 |
+
model = MN(inverted_residual_setting, **kwargs)
|
| 267 |
+
|
| 268 |
+
if pretrained_name in pretrained_models:
|
| 269 |
+
model_url = pretrained_models.get(pretrained_name)
|
| 270 |
+
state_dict = load_state_dict_from_url(model_url, model_dir=model_dir, map_location="cpu")
|
| 271 |
+
if kwargs['head_type'] == "mlp":
|
| 272 |
+
num_classes = state_dict['classifier.5.bias'].size(0)
|
| 273 |
+
elif kwargs['head_type'] == "fully_convolutional":
|
| 274 |
+
num_classes = state_dict['classifier.1.bias'].size(0)
|
| 275 |
+
else:
|
| 276 |
+
print("Loading weights for classifier only implemented for head types 'mlp' and 'fully_convolutional'")
|
| 277 |
+
num_classes = -1
|
| 278 |
+
if kwargs['num_classes'] != num_classes:
|
| 279 |
+
# if the number of logits is not matching the state dict,
|
| 280 |
+
# drop the corresponding pre-trained part
|
| 281 |
+
pretrain_logits = state_dict['classifier.5.bias'].size(0) if kwargs['head_type'] == "mlp" \
|
| 282 |
+
else state_dict['classifier.1.bias'].size(0)
|
| 283 |
+
print(f"Number of classes defined: {kwargs['num_classes']}, "
|
| 284 |
+
f"but try to load pre-trained layer with logits: {pretrain_logits}\n"
|
| 285 |
+
"Dropping last layer.")
|
| 286 |
+
if kwargs['head_type'] == "mlp":
|
| 287 |
+
del state_dict['classifier.5.weight']
|
| 288 |
+
del state_dict['classifier.5.bias']
|
| 289 |
+
else:
|
| 290 |
+
state_dict = {k: v for k, v in state_dict.items() if not k.startswith('classifier')}
|
| 291 |
+
try:
|
| 292 |
+
model.load_state_dict(state_dict)
|
| 293 |
+
except RuntimeError as e:
|
| 294 |
+
print(str(e))
|
| 295 |
+
print("Loading weights pre-trained weights in a non-strict manner.")
|
| 296 |
+
model.load_state_dict(state_dict, strict=False)
|
| 297 |
+
elif pretrained_name:
|
| 298 |
+
raise NotImplementedError(f"Model name '{pretrained_name}' unknown.")
|
| 299 |
+
return model
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def mobilenet_v3(pretrained_name: str = None, **kwargs: Any) \
|
| 303 |
+
-> MN:
|
| 304 |
+
"""
|
| 305 |
+
Constructs a MobileNetV3 architecture from
|
| 306 |
+
"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>".
|
| 307 |
+
"""
|
| 308 |
+
inverted_residual_setting, last_channel = _mobilenet_v3_conf(**kwargs)
|
| 309 |
+
return _mobilenet_v3(inverted_residual_setting, pretrained_name, **kwargs)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def get_model(pretrained_name: str = None, width_mult: float = 1.0,
|
| 313 |
+
reduced_tail: bool = False, dilated: bool = False, dilation_list_t_dim=None,
|
| 314 |
+
strides: Tuple[int, int, int, int] = None,
|
| 315 |
+
head_type: str = "mlp", multihead_attention_heads: int = 4, input_dim_f: int = 128,
|
| 316 |
+
input_dim_t: int = 1000, se_dims: str = 'c', se_agg: str = "max", se_r: int = 4):
|
| 317 |
+
"""
|
| 318 |
+
Arguments to modify the instantiation of a MobileNetv3
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
pretrained_name (str): Specifies name of pre-trained model to load
|
| 322 |
+
width_mult (float): Scales width of network
|
| 323 |
+
reduced_tail (bool): Scales down network tail
|
| 324 |
+
dilated (bool): Applies dilated convolution to network tail
|
| 325 |
+
dilation_list_t_dim (List): List of dilation factors to apply to network tail
|
| 326 |
+
strides (Tuple): Strides that are set to '2' in original implementation;
|
| 327 |
+
might be changed to modify the size of receptive field and the downsampling factor in
|
| 328 |
+
time and frequency dimension
|
| 329 |
+
head_type (str): decides which classification head to use
|
| 330 |
+
multihead_attention_heads (int): number of heads in case 'multihead_attention_heads' is used
|
| 331 |
+
input_dim_f (int): number of frequency bands
|
| 332 |
+
input_dim_t (int): number of time frames
|
| 333 |
+
se_dims (Tuple): choose dimension to apply squeeze-excitation on, if multiple dimensions are chosen, then
|
| 334 |
+
squeeze-excitation is applied concurrently and se layer outputs are fused by se_agg operation
|
| 335 |
+
se_agg (str): operation to fuse output of concurrent se layers
|
| 336 |
+
se_r (int): squeeze excitation bottleneck size
|
| 337 |
+
se_dims (str): contains letters corresponding to dimensions 'c' - channel, 'f' - frequency, 't' - time
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
dim_map = {'c': 1, 'f': 2, 't': 3}
|
| 341 |
+
assert len(se_dims) <= 3 and all([s in dim_map.keys() for s in se_dims]) or se_dims == 'none'
|
| 342 |
+
input_dims = (input_dim_f, input_dim_t)
|
| 343 |
+
if se_dims == 'none':
|
| 344 |
+
se_dims = None
|
| 345 |
+
else:
|
| 346 |
+
se_dims = [dim_map[s] for s in se_dims]
|
| 347 |
+
se_conf = dict(se_dims=se_dims, se_agg=se_agg, se_r=se_r)
|
| 348 |
+
m = mobilenet_v3(pretrained_name=pretrained_name,
|
| 349 |
+
width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated,
|
| 350 |
+
dilation_list_t_dim=dilation_list_t_dim,
|
| 351 |
+
strides=strides,
|
| 352 |
+
head_type=head_type, multihead_attention_heads=multihead_attention_heads,
|
| 353 |
+
input_dims=input_dims, se_conf=se_conf
|
| 354 |
+
)
|
| 355 |
+
print(m)
|
| 356 |
+
return m
|
models/frame_mn/utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Callable
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def NAME_TO_WIDTH(name):
|
| 9 |
+
frame_mn_map = {
|
| 10 |
+
'frame_mn01': 0.1,
|
| 11 |
+
'frame_mn02': 0.2,
|
| 12 |
+
'frame_mn04': 0.4,
|
| 13 |
+
'frame_mn05': 0.5,
|
| 14 |
+
'frame_mn06': 0.6,
|
| 15 |
+
'frame_mn08': 0.8,
|
| 16 |
+
'frame_mn10': 1.0,
|
| 17 |
+
'frame_mn12': 1.2,
|
| 18 |
+
'frame_mn14': 1.4,
|
| 19 |
+
'frame_mn16': 1.6,
|
| 20 |
+
'frame_mn20': 2.0,
|
| 21 |
+
'frame_mn30': 3.0,
|
| 22 |
+
'frame_mn40': 4.0,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
frame_dymn_map = {
|
| 26 |
+
'frame_dymn04': 0.4,
|
| 27 |
+
'frame_dymn10': 1.0,
|
| 28 |
+
'frame_dymn20': 2.0,
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
if name.startswith('frame_dymn'):
|
| 33 |
+
w = frame_dymn_map[name[:len('frame_dymnxx')]]
|
| 34 |
+
else:
|
| 35 |
+
w = frame_mn_map[name[:len('frame_mnxx')]]
|
| 36 |
+
except:
|
| 37 |
+
w = 1.0
|
| 38 |
+
|
| 39 |
+
return w
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
|
| 43 |
+
"""
|
| 44 |
+
This function is taken from the original tf repo.
|
| 45 |
+
It ensures that all layers have a channel number that is divisible by 8
|
| 46 |
+
It can be seen here:
|
| 47 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
| 48 |
+
"""
|
| 49 |
+
if min_value is None:
|
| 50 |
+
min_value = divisor
|
| 51 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
| 52 |
+
# Make sure that round down does not go down by more than 10%.
|
| 53 |
+
if new_v < 0.9 * v:
|
| 54 |
+
new_v += divisor
|
| 55 |
+
return new_v
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def cnn_out_size(in_size, padding, dilation, kernel, stride):
|
| 59 |
+
s = in_size + 2 * padding - dilation * (kernel - 1) - 1
|
| 60 |
+
return math.floor(s / stride + 1)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def collapse_dim(x: Tensor, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
|
| 64 |
+
combine_dim: int = None):
|
| 65 |
+
"""
|
| 66 |
+
Collapses dimension of multi-dimensional tensor by pooling or combining dimensions
|
| 67 |
+
:param x: input Tensor
|
| 68 |
+
:param dim: dimension to collapse
|
| 69 |
+
:param mode: 'pool' or 'combine'
|
| 70 |
+
:param pool_fn: function to be applied in case of pooling
|
| 71 |
+
:param combine_dim: dimension to join 'dim' to
|
| 72 |
+
:return: collapsed tensor
|
| 73 |
+
"""
|
| 74 |
+
if mode == "pool":
|
| 75 |
+
return pool_fn(x, dim)
|
| 76 |
+
elif mode == "combine":
|
| 77 |
+
s = list(x.size())
|
| 78 |
+
s[combine_dim] *= dim
|
| 79 |
+
s[dim] //= dim
|
| 80 |
+
return x.view(s)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class CollapseDim(nn.Module):
|
| 84 |
+
def __init__(self, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
|
| 85 |
+
combine_dim: int = None):
|
| 86 |
+
super(CollapseDim, self).__init__()
|
| 87 |
+
self.dim = dim
|
| 88 |
+
self.mode = mode
|
| 89 |
+
self.pool_fn = pool_fn
|
| 90 |
+
self.combine_dim = combine_dim
|
| 91 |
+
|
| 92 |
+
def forward(self, x):
|
| 93 |
+
return collapse_dim(x, dim=self.dim, mode=self.mode, pool_fn=self.pool_fn, combine_dim=self.combine_dim)
|
models/frame_passt/fpasst.py
ADDED
|
@@ -0,0 +1,963 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Most of this code comes from the timm library.
|
| 3 |
+
We tried to disentangle from the timm library version.
|
| 4 |
+
|
| 5 |
+
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
import collections
|
| 9 |
+
import logging
|
| 10 |
+
import math
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from functools import partial
|
| 15 |
+
from itertools import repeat
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
from models.frame_passt.vit_helpers import (DropPath, trunc_normal_,
|
| 21 |
+
build_model_with_cfg, adapt_input_conv)
|
| 22 |
+
|
| 23 |
+
_logger = logging.getLogger()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# From PyTorch internals
|
| 27 |
+
def _ntuple(n):
|
| 28 |
+
def parse(x):
|
| 29 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 30 |
+
return tuple(x)
|
| 31 |
+
return tuple(repeat(x, n))
|
| 32 |
+
|
| 33 |
+
return parse
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
to_2tuple = _ntuple(2)
|
| 37 |
+
|
| 38 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
| 39 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
| 40 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
| 41 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _cfg(url='', **kwargs):
|
| 45 |
+
return {
|
| 46 |
+
'url': url,
|
| 47 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
| 48 |
+
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
| 49 |
+
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
| 50 |
+
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
| 51 |
+
**kwargs
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
default_cfgs = {
|
| 56 |
+
# patch models (weights from official Google JAX impl)
|
| 57 |
+
'vit_tiny_patch16_224': _cfg(
|
| 58 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 59 |
+
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
| 60 |
+
'vit_tiny_patch16_384': _cfg(
|
| 61 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 62 |
+
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
| 63 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
| 64 |
+
'vit_small_patch32_224': _cfg(
|
| 65 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 66 |
+
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
| 67 |
+
'vit_small_patch32_384': _cfg(
|
| 68 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 69 |
+
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
| 70 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
| 71 |
+
'vit_small_patch16_224': _cfg(
|
| 72 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 73 |
+
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
| 74 |
+
'vit_small_patch16_384': _cfg(
|
| 75 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 76 |
+
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
| 77 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
| 78 |
+
'vit_base_patch32_224': _cfg(
|
| 79 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 80 |
+
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
| 81 |
+
'vit_base_patch32_384': _cfg(
|
| 82 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 83 |
+
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
| 84 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
| 85 |
+
'vit_base_patch16_224': _cfg(
|
| 86 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 87 |
+
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
|
| 88 |
+
'vit_base_patch16_384': _cfg(
|
| 89 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 90 |
+
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
|
| 91 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
| 92 |
+
'vit_large_patch32_224': _cfg(
|
| 93 |
+
url='', # no official model weights for this combo, only for in21k
|
| 94 |
+
),
|
| 95 |
+
'vit_large_patch32_384': _cfg(
|
| 96 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
|
| 97 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
| 98 |
+
'vit_large_patch16_224': _cfg(
|
| 99 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 100 |
+
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
|
| 101 |
+
'vit_large_patch16_384': _cfg(
|
| 102 |
+
url='https://storage.googleapis.com/vit_models/augreg/'
|
| 103 |
+
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
|
| 104 |
+
input_size=(3, 384, 384), crop_pct=1.0),
|
| 105 |
+
|
| 106 |
+
# patch models, imagenet21k (weights from official Google JAX impl)
|
| 107 |
+
'vit_tiny_patch16_224_in21k': _cfg(
|
| 108 |
+
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
|
| 109 |
+
num_classes=21843),
|
| 110 |
+
'vit_small_patch32_224_in21k': _cfg(
|
| 111 |
+
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
|
| 112 |
+
num_classes=21843),
|
| 113 |
+
'vit_small_patch16_224_in21k': _cfg(
|
| 114 |
+
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
|
| 115 |
+
num_classes=21843),
|
| 116 |
+
'vit_base_patch32_224_in21k': _cfg(
|
| 117 |
+
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
|
| 118 |
+
num_classes=21843),
|
| 119 |
+
'vit_base_patch16_224_in21k': _cfg(
|
| 120 |
+
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
|
| 121 |
+
num_classes=21843),
|
| 122 |
+
'vit_large_patch32_224_in21k': _cfg(
|
| 123 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
|
| 124 |
+
num_classes=21843),
|
| 125 |
+
'vit_large_patch16_224_in21k': _cfg(
|
| 126 |
+
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
|
| 127 |
+
num_classes=21843),
|
| 128 |
+
'vit_huge_patch14_224_in21k': _cfg(
|
| 129 |
+
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
|
| 130 |
+
hf_hub='timm/vit_huge_patch14_224_in21k',
|
| 131 |
+
num_classes=21843),
|
| 132 |
+
|
| 133 |
+
# SAM trained models (https://arxiv.org/abs/2106.01548)
|
| 134 |
+
'vit_base_patch32_sam_224': _cfg(
|
| 135 |
+
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
|
| 136 |
+
'vit_base_patch16_sam_224': _cfg(
|
| 137 |
+
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
|
| 138 |
+
|
| 139 |
+
# deit models (FB weights)
|
| 140 |
+
'deit_tiny_patch16_224': _cfg(
|
| 141 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
|
| 142 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
| 143 |
+
'deit_small_patch16_224': _cfg(
|
| 144 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
|
| 145 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
| 146 |
+
'deit_base_patch16_224': _cfg(
|
| 147 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
|
| 148 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
| 149 |
+
'deit_base_patch16_384': _cfg(
|
| 150 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
| 151 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
|
| 152 |
+
'deit_tiny_distilled_patch16_224': _cfg(
|
| 153 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
|
| 154 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
| 155 |
+
'deit_small_distilled_patch16_224': _cfg(
|
| 156 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
|
| 157 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
| 158 |
+
'deit_base_distilled_patch16_224': _cfg(
|
| 159 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
| 160 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
| 161 |
+
'deit_base_distilled_patch16_384': _cfg(
|
| 162 |
+
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
| 163 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
|
| 164 |
+
classifier=('head', 'head_dist')),
|
| 165 |
+
|
| 166 |
+
# ViT ImageNet-21K-P pretraining by MILL
|
| 167 |
+
'vit_base_patch16_224_miil_in21k': _cfg(
|
| 168 |
+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
|
| 169 |
+
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
|
| 170 |
+
),
|
| 171 |
+
'vit_base_patch16_224_miil': _cfg(
|
| 172 |
+
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
|
| 173 |
+
'/vit_base_patch16_224_1k_miil_84_4.pth',
|
| 174 |
+
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
|
| 175 |
+
),
|
| 176 |
+
# PaSST
|
| 177 |
+
'passt_s_swa_p16_128_ap476': _cfg(
|
| 178 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt',
|
| 179 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 180 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 181 |
+
'passt_s_kd_p16_128_ap486': _cfg(
|
| 182 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',
|
| 183 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 184 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 185 |
+
'passt_l_kd_p16_128_ap47': _cfg(
|
| 186 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt',
|
| 187 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 188 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 189 |
+
'passt_s_swa_p16_128_ap4761': _cfg(
|
| 190 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.4761-swa.pt',
|
| 191 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 192 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 193 |
+
'passt_s_p16_128_ap472': _cfg(
|
| 194 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.472.pt',
|
| 195 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 196 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 197 |
+
'passt_s_p16_s16_128_ap468': _cfg(
|
| 198 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt',
|
| 199 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 200 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 201 |
+
'passt_s_swa_p16_s16_128_ap473': _cfg(
|
| 202 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.473-swa.pt',
|
| 203 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 204 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 205 |
+
'passt_s_swa_p16_s14_128_ap471': _cfg(
|
| 206 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.471-swa.pt',
|
| 207 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 208 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 209 |
+
'passt_s_p16_s14_128_ap469': _cfg(
|
| 210 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.469.pt',
|
| 211 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 212 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 213 |
+
'passt_s_swa_p16_s12_128_ap473': _cfg(
|
| 214 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.473-swa.pt',
|
| 215 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 216 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 217 |
+
'passt_s_p16_s12_128_ap470': _cfg(
|
| 218 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.470.pt',
|
| 219 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
|
| 220 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 221 |
+
'passt_s_swa_f128_stfthop100_p16_s10_ap473': _cfg(
|
| 222 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop100-p16-s10-ap.473-swa.pt',
|
| 223 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
|
| 224 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 225 |
+
'passt_s_swa_f128_stfthop160_p16_s10_ap473': _cfg(
|
| 226 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop160-p16-s10-ap.473-swa.pt',
|
| 227 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
|
| 228 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 229 |
+
'passt-s-f128-20sec-p16-s10-ap474-swa': _cfg(
|
| 230 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-20sec-p16-s10-ap.474-swa.pt',
|
| 231 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
|
| 232 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 233 |
+
'passt-s-f128-30sec-p16-s10-ap473-swa': _cfg(
|
| 234 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-30sec-p16-s10-ap.473-swa.pt',
|
| 235 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3000), crop_pct=1.0,
|
| 236 |
+
classifier=('head.1', 'head_dist'), num_classes=527),
|
| 237 |
+
'openmic2008_passt_u_f128_p16_s10_ap85_swa': _cfg(
|
| 238 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85-swa.pt',
|
| 239 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
|
| 240 |
+
classifier=('head.1', 'head_dist'), num_classes=20),
|
| 241 |
+
'openmic2008_passt_u_f128_p16_s10_ap85 ': _cfg(
|
| 242 |
+
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85.pt',
|
| 243 |
+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
|
| 244 |
+
classifier=('head.1', 'head_dist'), num_classes=20),
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class Mlp(nn.Module):
|
| 249 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 253 |
+
super().__init__()
|
| 254 |
+
out_features = out_features or in_features
|
| 255 |
+
hidden_features = hidden_features or in_features
|
| 256 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 257 |
+
self.act = act_layer()
|
| 258 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 259 |
+
self.drop = nn.Dropout(drop)
|
| 260 |
+
|
| 261 |
+
def forward(self, x):
|
| 262 |
+
x = self.fc1(x)
|
| 263 |
+
x = self.act(x)
|
| 264 |
+
x = self.drop(x)
|
| 265 |
+
x = self.fc2(x)
|
| 266 |
+
x = self.drop(x)
|
| 267 |
+
return x
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
first_RUN = True
|
| 271 |
+
|
| 272 |
+
PLUS1_TRICK = False
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class PatchEmbed(nn.Module):
|
| 276 |
+
""" 2D Image to Patch Embedding
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def __init__(self, img_size=224, in_chans=1, frame_nr=1, stride=1, overlap=1, embed_dim=768, norm_layer=None):
|
| 280 |
+
super().__init__()
|
| 281 |
+
img_size = to_2tuple(img_size)
|
| 282 |
+
frame_nr = frame_nr
|
| 283 |
+
stride = stride
|
| 284 |
+
self.img_size = img_size
|
| 285 |
+
self.frame_nr = frame_nr
|
| 286 |
+
self.stride = stride
|
| 287 |
+
self.seq_len = int(img_size[1]) // frame_nr
|
| 288 |
+
self.num_patches = self.seq_len // stride
|
| 289 |
+
self.embed_dim = embed_dim
|
| 290 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=(int(img_size[0]), stride + overlap),
|
| 291 |
+
stride=stride, padding=(0, 1)) # 128 x 2 kernel
|
| 292 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 293 |
+
|
| 294 |
+
def forward(self, x):
|
| 295 |
+
B, C, F, T = x.shape
|
| 296 |
+
if not (F == self.img_size[0] and abs(T - self.img_size[1]) <= 1): # allows for a difference of 1
|
| 297 |
+
warnings.warn(f"Input image size ({F}*{T}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
| 298 |
+
x = self.proj(x)[:, :, :, 1:] # B embed_dim 1 T (F=1)
|
| 299 |
+
x = self.norm(x)
|
| 300 |
+
if first_RUN: print("self.norm(x)", x.size())
|
| 301 |
+
return x
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class Attention(nn.Module):
|
| 305 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
| 306 |
+
super().__init__()
|
| 307 |
+
self.num_heads = num_heads
|
| 308 |
+
head_dim = dim // num_heads
|
| 309 |
+
self.scale = head_dim ** -0.5
|
| 310 |
+
|
| 311 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 312 |
+
self.attn_drop = attn_drop
|
| 313 |
+
self.proj = nn.Linear(dim, dim)
|
| 314 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 315 |
+
|
| 316 |
+
def forward(self, x):
|
| 317 |
+
B, N, C = x.shape
|
| 318 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 319 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 320 |
+
|
| 321 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_drop,
|
| 322 |
+
is_causal=False, scale=self.scale)
|
| 323 |
+
|
| 324 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 325 |
+
x = self.proj(x)
|
| 326 |
+
x = self.proj_drop(x)
|
| 327 |
+
return x
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class Block(nn.Module):
|
| 331 |
+
|
| 332 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
| 333 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 334 |
+
super().__init__()
|
| 335 |
+
self.norm1 = norm_layer(dim)
|
| 336 |
+
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
| 337 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 338 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 339 |
+
self.norm2 = norm_layer(dim)
|
| 340 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 341 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 342 |
+
|
| 343 |
+
def forward(self, x):
|
| 344 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
| 345 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 346 |
+
return x
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class PaSST(nn.Module):
|
| 350 |
+
"""
|
| 351 |
+
|
| 352 |
+
Based on the implementation of Vision Transformer in timm library.
|
| 353 |
+
Take a look at the get_model function, adapting the weights of pretrained imagenet models.
|
| 354 |
+
|
| 355 |
+
"""
|
| 356 |
+
|
| 357 |
+
def __init__(self, img_size=(128, 998),
|
| 358 |
+
in_chans=1, num_classes=527, embed_dim=768, depth=12,
|
| 359 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
|
| 360 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
|
| 361 |
+
act_layer=None, weight_init='',
|
| 362 |
+
frame_patchout=300, frame_nr=1, pos_embed_length=1000):
|
| 363 |
+
"""
|
| 364 |
+
Args:
|
| 365 |
+
img_size (int, tuple): input image size
|
| 366 |
+
in_chans (int): number of input channels
|
| 367 |
+
num_classes (int): number of classes for classification head
|
| 368 |
+
embed_dim (int): embedding dimension
|
| 369 |
+
depth (int): depth of transformer
|
| 370 |
+
num_heads (int): number of attention heads
|
| 371 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 372 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 373 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
| 374 |
+
distilled (bool): model includes a distillation token and head as in DeiT models
|
| 375 |
+
drop_rate (float): dropout rate
|
| 376 |
+
attn_drop_rate (float): attention dropout rate
|
| 377 |
+
drop_path_rate (float): stochastic depth rate
|
| 378 |
+
embed_layer (nn.Module): patch embedding layer
|
| 379 |
+
norm_layer: (nn.Module): normalization layer
|
| 380 |
+
act_layer: (nn.Module): activation layer
|
| 381 |
+
weight_init: (str): weight init scheme
|
| 382 |
+
frame_patchout (int): number of frames to patch out
|
| 383 |
+
frame_nr (int): the second dimension of the proj-convolution kernel
|
| 384 |
+
pos_embed_length (int): length of the positional embedding
|
| 385 |
+
"""
|
| 386 |
+
super().__init__()
|
| 387 |
+
self.num_classes = num_classes
|
| 388 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 389 |
+
self.num_tokens = 2 if distilled else 1
|
| 390 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 391 |
+
act_layer = act_layer or nn.GELU
|
| 392 |
+
self.act_layer = act_layer()
|
| 393 |
+
self.in_chans = in_chans
|
| 394 |
+
self.frame_patchout = frame_patchout
|
| 395 |
+
self.pos_embed_len = pos_embed_length
|
| 396 |
+
|
| 397 |
+
# these three convolution are different compared to the vanilla passt
|
| 398 |
+
self.conv_in_1 = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
|
| 399 |
+
self.conv_in_2 = nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
| 400 |
+
self.conv_in_3 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # 64 instead of 4
|
| 401 |
+
img_size = (img_size[0], pos_embed_length) # 128, 250
|
| 402 |
+
|
| 403 |
+
self.patch_embed = embed_layer(
|
| 404 |
+
img_size=img_size, in_chans=in_chans, frame_nr=frame_nr, stride=frame_nr, embed_dim=embed_dim)
|
| 405 |
+
num_patches = self.patch_embed.num_patches
|
| 406 |
+
|
| 407 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 408 |
+
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
|
| 409 |
+
# PaSST
|
| 410 |
+
# refer to https://arxiv.org/abs/2110.05069 Section 2
|
| 411 |
+
self.new_pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) # for C and D tokens
|
| 412 |
+
self.freq_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, 1)) # | f
|
| 413 |
+
self.time_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, self.pos_embed_len)) # __ t
|
| 414 |
+
####
|
| 415 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 416 |
+
|
| 417 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 418 |
+
self.blocks = nn.Sequential(*[
|
| 419 |
+
Block(
|
| 420 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
| 421 |
+
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
|
| 422 |
+
for i in range(depth)])
|
| 423 |
+
self.norm = norm_layer(embed_dim)
|
| 424 |
+
|
| 425 |
+
# Representation layer
|
| 426 |
+
if representation_size and not distilled:
|
| 427 |
+
self.num_features = representation_size
|
| 428 |
+
self.pre_logits = nn.Sequential(OrderedDict([
|
| 429 |
+
('fc', nn.Linear(embed_dim, representation_size)),
|
| 430 |
+
('act', nn.Tanh())
|
| 431 |
+
]))
|
| 432 |
+
else:
|
| 433 |
+
self.pre_logits = nn.Identity()
|
| 434 |
+
|
| 435 |
+
self.init_weights(weight_init)
|
| 436 |
+
|
| 437 |
+
def init_weights(self, mode=''):
|
| 438 |
+
assert mode in ('jax', 'jax_nlhb', 'nlhb', ''), f"mode: {mode}"
|
| 439 |
+
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
|
| 440 |
+
trunc_normal_(self.new_pos_embed, std=.02)
|
| 441 |
+
trunc_normal_(self.freq_new_pos_embed, std=.02)
|
| 442 |
+
trunc_normal_(self.time_new_pos_embed, std=.02)
|
| 443 |
+
if self.dist_token is not None:
|
| 444 |
+
trunc_normal_(self.dist_token, std=.02)
|
| 445 |
+
if mode.startswith('jax'):
|
| 446 |
+
# leave cls token as zeros to match jax impl
|
| 447 |
+
raise RuntimeError("Not supported yet")
|
| 448 |
+
else:
|
| 449 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 450 |
+
self.apply(_init_vit_weights)
|
| 451 |
+
|
| 452 |
+
def _init_weights(self, m):
|
| 453 |
+
# this fn left here for compat with downstream users
|
| 454 |
+
_init_vit_weights(m)
|
| 455 |
+
|
| 456 |
+
@torch.jit.ignore
|
| 457 |
+
def no_weight_decay(self):
|
| 458 |
+
return {'new_pos_embed', 'freq_new_pos_embed', 'time_new_pos_embed', 'cls_token', 'dist_token'}
|
| 459 |
+
|
| 460 |
+
def forward_features(self, x):
|
| 461 |
+
global first_RUN # not jit friendly? use trace instead
|
| 462 |
+
|
| 463 |
+
# some 2D convolutions
|
| 464 |
+
f_dim = x.size(2) # 128
|
| 465 |
+
x = self.act_layer(self.conv_in_1(x))
|
| 466 |
+
x = self.act_layer(self.conv_in_2(x))
|
| 467 |
+
x = self.act_layer(self.conv_in_3(x))
|
| 468 |
+
if first_RUN: print("after convs", x.size())
|
| 469 |
+
x = x.reshape(x.shape[0], (x.shape[1] * x.shape[2]) // f_dim, f_dim, x.shape[3])
|
| 470 |
+
if first_RUN: print("after reshape", x.size())
|
| 471 |
+
|
| 472 |
+
x = self.patch_embed(x) # [b, e, f, t]
|
| 473 |
+
B_dim, E_dim, F_dim, T_dim = x.shape # slow
|
| 474 |
+
if first_RUN: print(" patch_embed : ", x.shape)
|
| 475 |
+
# Adding Time/Freq information
|
| 476 |
+
if first_RUN: print(" self.time_new_pos_embed.shape", self.time_new_pos_embed.shape)
|
| 477 |
+
time_new_pos_embed = self.time_new_pos_embed
|
| 478 |
+
if x.shape[-1] < time_new_pos_embed.shape[-1]:
|
| 479 |
+
if self.training:
|
| 480 |
+
toffset = torch.randint(1 + time_new_pos_embed.shape[-1] - x.shape[-1], (1,)).item()
|
| 481 |
+
if first_RUN: print(f" CUT with randomoffset={toffset} time_new_pos_embed.shape",
|
| 482 |
+
time_new_pos_embed.shape)
|
| 483 |
+
time_new_pos_embed = time_new_pos_embed[:, :, :, toffset:toffset + x.shape[-1]]
|
| 484 |
+
else:
|
| 485 |
+
time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]]
|
| 486 |
+
if first_RUN: print(" CUT time_new_pos_embed.shape", time_new_pos_embed.shape)
|
| 487 |
+
else:
|
| 488 |
+
# warnings.warn(
|
| 489 |
+
# f"the patches shape:{x.shape} are larger than the expected time encodings {time_new_pos_embed.shape}, x will be cut")
|
| 490 |
+
x = x[:, :, :, :time_new_pos_embed.shape[-1]]
|
| 491 |
+
x = x + time_new_pos_embed
|
| 492 |
+
if first_RUN: print(" self.freq_new_pos_embed.shape", self.freq_new_pos_embed.shape)
|
| 493 |
+
x = x + self.freq_new_pos_embed
|
| 494 |
+
|
| 495 |
+
# Structured Patchout https://arxiv.org/abs/2110.05069 Section 2.2
|
| 496 |
+
if self.training and self.frame_patchout:
|
| 497 |
+
if first_RUN: print(f"X Before frame Patchout of {self.frame_patchout} ", x.size())
|
| 498 |
+
# ([1, 768, 1, 82])
|
| 499 |
+
random_indices = torch.randperm(T_dim)[:T_dim - self.frame_patchout].sort().values
|
| 500 |
+
x = x[:, :, :, random_indices]
|
| 501 |
+
if first_RUN: print("X after frame Patchout", x.size())
|
| 502 |
+
|
| 503 |
+
x = x.flatten(2).transpose(1, 2)
|
| 504 |
+
|
| 505 |
+
# Add the C/D tokens
|
| 506 |
+
if first_RUN: print(" self.new_pos_embed.shape", self.new_pos_embed.shape)
|
| 507 |
+
cls_tokens = self.cls_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, :1, :]
|
| 508 |
+
if first_RUN: print(" self.cls_tokens.shape", cls_tokens.shape)
|
| 509 |
+
if self.dist_token is None:
|
| 510 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 511 |
+
else:
|
| 512 |
+
dist_token = self.dist_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, 1:, :]
|
| 513 |
+
if first_RUN: print(" self.dist_token.shape", dist_token.shape)
|
| 514 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
| 515 |
+
|
| 516 |
+
if first_RUN: print(" final sequence x", x.shape)
|
| 517 |
+
x = self.pos_drop(x)
|
| 518 |
+
x = self.blocks(x)
|
| 519 |
+
if first_RUN: print(f" after {len(self.blocks)} atten blocks x", x.shape)
|
| 520 |
+
x = self.norm(x)
|
| 521 |
+
return x
|
| 522 |
+
|
| 523 |
+
def forward(self, x):
|
| 524 |
+
global first_RUN
|
| 525 |
+
if first_RUN: print("x", x.size())
|
| 526 |
+
x = self.forward_features(x)
|
| 527 |
+
c, x = x[:, :2].mean(1), x[:, 2:]
|
| 528 |
+
if first_RUN: print("x after forward_features", x.size())
|
| 529 |
+
first_RUN = False
|
| 530 |
+
return x
|
| 531 |
+
|
| 532 |
+
def load_model(self, path, wandb_id):
|
| 533 |
+
ckpt_path = os.path.join(path, wandb_id + ".ckpt")
|
| 534 |
+
|
| 535 |
+
pretrained_weights = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
| 536 |
+
pretrained_weights = {k[10:]: v for k, v in pretrained_weights.items() if k[:10] == "net.model."}
|
| 537 |
+
self.load_state_dict(pretrained_weights)
|
| 538 |
+
|
| 539 |
+
print("Loaded model successfully. Wandb_id:", wandb_id)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
|
| 543 |
+
""" ViT weight initialization
|
| 544 |
+
* When called without n, head_bias, jax_impl args it will behave exactly the same
|
| 545 |
+
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
|
| 546 |
+
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
|
| 547 |
+
"""
|
| 548 |
+
if isinstance(module, nn.Linear):
|
| 549 |
+
if name.startswith('head'):
|
| 550 |
+
nn.init.zeros_(module.weight)
|
| 551 |
+
nn.init.constant_(module.bias, head_bias)
|
| 552 |
+
elif name.startswith('pre_logits'):
|
| 553 |
+
lecun_normal_(module.weight)
|
| 554 |
+
nn.init.zeros_(module.bias)
|
| 555 |
+
else:
|
| 556 |
+
if jax_impl:
|
| 557 |
+
nn.init.xavier_uniform_(module.weight)
|
| 558 |
+
if module.bias is not None:
|
| 559 |
+
if 'mlp' in name:
|
| 560 |
+
nn.init.normal_(module.bias, std=1e-6)
|
| 561 |
+
else:
|
| 562 |
+
nn.init.zeros_(module.bias)
|
| 563 |
+
else:
|
| 564 |
+
trunc_normal_(module.weight, std=.02)
|
| 565 |
+
if module.bias is not None:
|
| 566 |
+
nn.init.zeros_(module.bias)
|
| 567 |
+
elif jax_impl and isinstance(module, nn.Conv2d):
|
| 568 |
+
# NOTE conv was left to pytorch default in my original init
|
| 569 |
+
lecun_normal_(module.weight)
|
| 570 |
+
if module.bias is not None:
|
| 571 |
+
nn.init.zeros_(module.bias)
|
| 572 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
|
| 573 |
+
nn.init.zeros_(module.bias)
|
| 574 |
+
nn.init.ones_(module.weight)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='bicubic'):
|
| 578 |
+
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
| 579 |
+
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
| 580 |
+
_logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, posemb_new.shape,
|
| 581 |
+
num_tokens)
|
| 582 |
+
ntok_new = posemb_new.shape[1]
|
| 583 |
+
if num_tokens:
|
| 584 |
+
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
| 585 |
+
ntok_new -= num_tokens
|
| 586 |
+
else:
|
| 587 |
+
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
| 588 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
| 589 |
+
if not len(gs_new): # backwards compatibility
|
| 590 |
+
gs_new = [int(math.sqrt(ntok_new))] * 2
|
| 591 |
+
assert len(gs_new) >= 2
|
| 592 |
+
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
|
| 593 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
| 594 |
+
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False)
|
| 595 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
| 596 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
| 597 |
+
return posemb
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def adapt_image_pos_embed_to_passt(posemb, num_tokens=1, posemb_len=1000, mode='bicubic'):
|
| 601 |
+
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
| 602 |
+
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
| 603 |
+
if num_tokens:
|
| 604 |
+
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
| 605 |
+
else:
|
| 606 |
+
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
| 607 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
| 608 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
| 609 |
+
posemb_grid = F.interpolate(posemb_grid, size=(1, posemb_len), mode=mode, align_corners=False)
|
| 610 |
+
|
| 611 |
+
freq_new_pos_embed = posemb_grid.mean(dim=3, keepdim=True)
|
| 612 |
+
time_new_pos_embed = posemb_grid.mean(dim=2, keepdim=True)
|
| 613 |
+
_logger.info('New Position cls/dstl embedding %s', posemb_tok.shape)
|
| 614 |
+
_logger.info('New FREQ Position embedding %s', freq_new_pos_embed.shape)
|
| 615 |
+
_logger.info('New TIME Position embedding %s', time_new_pos_embed.shape)
|
| 616 |
+
return posemb_tok, freq_new_pos_embed, time_new_pos_embed
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def checkpoint_filter_fn(state_dict, model):
|
| 620 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
| 621 |
+
out_dict = {}
|
| 622 |
+
if 'model' in state_dict:
|
| 623 |
+
# For deit models
|
| 624 |
+
state_dict = state_dict['model']
|
| 625 |
+
state_dict = {k: v for k, v in state_dict.items()}
|
| 626 |
+
if "time_new_pos_embed" not in state_dict:
|
| 627 |
+
# we are working with ImageNet model
|
| 628 |
+
_logger.info("Adapting pos embedding from ImageNet pretrained model to PaSST.")
|
| 629 |
+
v = state_dict.pop("pos_embed")
|
| 630 |
+
new_pos_embed, freq_new_pos_embed, time_new_pos_embed = adapt_image_pos_embed_to_passt(
|
| 631 |
+
v, getattr(model, 'num_tokens', 1), model.pos_embed_len)
|
| 632 |
+
state_dict["new_pos_embed"] = new_pos_embed
|
| 633 |
+
state_dict["freq_new_pos_embed"] = freq_new_pos_embed
|
| 634 |
+
state_dict["time_new_pos_embed"] = time_new_pos_embed
|
| 635 |
+
|
| 636 |
+
for k, v in state_dict.items():
|
| 637 |
+
if 'patch_embed.proj.weight' in k:
|
| 638 |
+
embed_dim, C, H, W = v.shape
|
| 639 |
+
v = adapt_input_conv(model.in_chans, v, input_conv_name=k)
|
| 640 |
+
k1, k2 = model.patch_embed.proj.kernel_size # 128, 2
|
| 641 |
+
|
| 642 |
+
# clever reshape
|
| 643 |
+
assert H * W == k1 * k2, "Error in the kernel size of the patch embedding"
|
| 644 |
+
|
| 645 |
+
v = v.reshape(embed_dim, model.in_chans, k1, k2) # [embed_dim, 1, k1, k2]
|
| 646 |
+
|
| 647 |
+
out_dict[k] = v
|
| 648 |
+
return out_dict
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
|
| 652 |
+
default_cfg = default_cfg or default_cfgs[variant]
|
| 653 |
+
if kwargs.get('features_only', None):
|
| 654 |
+
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
| 655 |
+
|
| 656 |
+
# NOTE this extra code to support handling of repr size for in21k pretrained models
|
| 657 |
+
default_num_classes = default_cfg['num_classes']
|
| 658 |
+
num_classes = kwargs.get('num_classes', default_num_classes)
|
| 659 |
+
repr_size = kwargs.pop('representation_size', None)
|
| 660 |
+
if repr_size is not None and num_classes != default_num_classes:
|
| 661 |
+
# Remove representation layer if fine-tuning. This may not always be the desired action,
|
| 662 |
+
# but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
|
| 663 |
+
_logger.warning("Removing representation layer for fine-tuning.")
|
| 664 |
+
repr_size = None
|
| 665 |
+
|
| 666 |
+
model = build_model_with_cfg(
|
| 667 |
+
PaSST, variant, pretrained,
|
| 668 |
+
default_cfg=default_cfg,
|
| 669 |
+
representation_size=repr_size,
|
| 670 |
+
pretrained_filter_fn=checkpoint_filter_fn,
|
| 671 |
+
pretrained_custom_load='npz' in default_cfg['url'],
|
| 672 |
+
**kwargs)
|
| 673 |
+
return model
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
|
| 677 |
+
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
|
| 678 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
| 679 |
+
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
|
| 680 |
+
"""
|
| 681 |
+
model_kwargs = dict(
|
| 682 |
+
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
|
| 683 |
+
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
|
| 684 |
+
return model
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
| 688 |
+
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
| 689 |
+
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
| 690 |
+
"""
|
| 691 |
+
|
| 692 |
+
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 693 |
+
model = _create_vision_transformer(
|
| 694 |
+
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 695 |
+
return model
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
|
| 699 |
+
""" PaSST pre-trained on AudioSet
|
| 700 |
+
"""
|
| 701 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=476 SWA \n\n")
|
| 702 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 703 |
+
if model_kwargs.get("stride") != (10, 10):
|
| 704 |
+
warnings.warn(
|
| 705 |
+
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 706 |
+
model = _create_vision_transformer(
|
| 707 |
+
'passt_s_swa_p16_128_ap476', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 708 |
+
return model
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
|
| 712 |
+
""" PaSST pre-trained on AudioSet
|
| 713 |
+
"""
|
| 714 |
+
print("\n\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \n\n")
|
| 715 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 716 |
+
if model_kwargs.get("stride") != (10, 10):
|
| 717 |
+
warnings.warn(
|
| 718 |
+
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 719 |
+
model = _create_vision_transformer(
|
| 720 |
+
'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 721 |
+
return model
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):
|
| 725 |
+
""" PaSST pre-trained on AudioSet
|
| 726 |
+
"""
|
| 727 |
+
print(
|
| 728 |
+
"\n\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \n\n")
|
| 729 |
+
model_kwargs = dict(patch_size=16, embed_dim=768,
|
| 730 |
+
depth=7, num_heads=12, **kwargs)
|
| 731 |
+
if model_kwargs.get("stride") != (10, 10):
|
| 732 |
+
warnings.warn(
|
| 733 |
+
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 734 |
+
model = _create_vision_transformer(
|
| 735 |
+
'passt_l_kd_p16_128_ap47', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 736 |
+
return model
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):
|
| 740 |
+
""" PaSST pre-trained on AudioSet
|
| 741 |
+
"""
|
| 742 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=4763 SWA \n\n")
|
| 743 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 744 |
+
if model_kwargs.get("stride") != (10, 10):
|
| 745 |
+
warnings.warn(
|
| 746 |
+
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 747 |
+
model = _create_vision_transformer(
|
| 748 |
+
'passt_s_swa_p16_128_ap4761', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 749 |
+
return model
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def passt_s_p16_128_ap472(pretrained=False, **kwargs):
|
| 753 |
+
""" PaSST pre-trained on AudioSet
|
| 754 |
+
"""
|
| 755 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=472 \n\n")
|
| 756 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 757 |
+
if model_kwargs.get("stride") != (10, 10):
|
| 758 |
+
warnings.warn(
|
| 759 |
+
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 760 |
+
model = _create_vision_transformer(
|
| 761 |
+
'passt_s_p16_128_ap472', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 762 |
+
return model
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):
|
| 766 |
+
""" PaSST pre-trained on AudioSet
|
| 767 |
+
"""
|
| 768 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \n\n")
|
| 769 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 770 |
+
if model_kwargs.get("stride") != (12, 12):
|
| 771 |
+
warnings.warn(
|
| 772 |
+
f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 773 |
+
model = _create_vision_transformer(
|
| 774 |
+
'passt_s_p16_s12_128_ap470', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 775 |
+
return model
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
|
| 779 |
+
print("\n\n Loading PASST TRAINED ON AUDISET with 20 Second time encodings, with STFT hop of 160 \n\n")
|
| 780 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 781 |
+
model = _create_vision_transformer(
|
| 782 |
+
'passt-s-f128-20sec-p16-s10-ap474-swa', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 783 |
+
return model
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
|
| 787 |
+
print("\n\n Loading PASST TRAINED ON AUDISET with 30 Second time encodings, with STFT hop of 160 \n\n")
|
| 788 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 789 |
+
model = _create_vision_transformer(
|
| 790 |
+
'passt-s-f128-30sec-p16-s10-ap473-swa', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 791 |
+
return model
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
|
| 795 |
+
""" PaSST pre-trained on AudioSet
|
| 796 |
+
"""
|
| 797 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \n\n")
|
| 798 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 799 |
+
if model_kwargs.get("stride") != (12, 12):
|
| 800 |
+
warnings.warn(
|
| 801 |
+
f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 802 |
+
model = _create_vision_transformer(
|
| 803 |
+
'passt_s_swa_p16_s12_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 804 |
+
return model
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):
|
| 808 |
+
""" PaSST pre-trained on AudioSet
|
| 809 |
+
"""
|
| 810 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \n\n")
|
| 811 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 812 |
+
if model_kwargs.get("stride") != (14, 14):
|
| 813 |
+
warnings.warn(
|
| 814 |
+
f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 815 |
+
model = _create_vision_transformer(
|
| 816 |
+
'passt_s_p16_s14_128_ap469', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 817 |
+
return model
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
def passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):
|
| 821 |
+
""" PaSST pre-trained on AudioSet
|
| 822 |
+
"""
|
| 823 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \n\n")
|
| 824 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 825 |
+
if model_kwargs.get("stride") != (14, 14):
|
| 826 |
+
warnings.warn(
|
| 827 |
+
f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 828 |
+
model = _create_vision_transformer(
|
| 829 |
+
'passt_s_swa_p16_s14_128_ap471', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 830 |
+
return model
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
def passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):
|
| 834 |
+
""" PaSST pre-trained on AudioSet
|
| 835 |
+
"""
|
| 836 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \n\n")
|
| 837 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 838 |
+
if model_kwargs.get("stride") != (16, 16):
|
| 839 |
+
warnings.warn(
|
| 840 |
+
f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 841 |
+
model = _create_vision_transformer(
|
| 842 |
+
'passt_s_swa_p16_s16_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 843 |
+
return model
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
|
| 847 |
+
""" PaSST pre-trained on AudioSet
|
| 848 |
+
"""
|
| 849 |
+
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \n\n")
|
| 850 |
+
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
| 851 |
+
if model_kwargs.get("stride") != (16, 16):
|
| 852 |
+
warnings.warn(
|
| 853 |
+
f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
|
| 854 |
+
model = _create_vision_transformer(
|
| 855 |
+
'passt_s_p16_s16_128_ap468', pretrained=pretrained, distilled=True, **model_kwargs)
|
| 856 |
+
return model
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def fix_embedding_layer(model, embed="default"):
|
| 860 |
+
if embed == "default":
|
| 861 |
+
return model
|
| 862 |
+
if embed == "overlap":
|
| 863 |
+
model.patch_embed = PatchEmbedAdaptiveMean(replace=model.patch_embed)
|
| 864 |
+
if embed == "am_keepconv":
|
| 865 |
+
model.patch_embed = PatchEmbedAdaptiveMeanKeepConv(replace=model.patch_embed)
|
| 866 |
+
return model
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
def lighten_model(model, cut_depth=0):
|
| 870 |
+
if cut_depth == 0:
|
| 871 |
+
return model
|
| 872 |
+
if cut_depth:
|
| 873 |
+
if cut_depth < 0:
|
| 874 |
+
print(f"\n Reducing model depth by removing every {-cut_depth} layer \n\n")
|
| 875 |
+
else:
|
| 876 |
+
print(f"\n Reducing model depth by {cut_depth} \n\n")
|
| 877 |
+
if len(model.blocks) < cut_depth + 2:
|
| 878 |
+
raise ValueError(f"Cut depth a VIT with {len(model.blocks)} "
|
| 879 |
+
f"layers should be between 1 and {len(model.blocks) - 2}")
|
| 880 |
+
print(f"\n Before Cutting it was {len(model.blocks)} \n\n")
|
| 881 |
+
|
| 882 |
+
old_blocks = list(model.blocks.children())
|
| 883 |
+
if cut_depth < 0:
|
| 884 |
+
print(f"cut_depth={cut_depth}")
|
| 885 |
+
old_blocks = [old_blocks[0]] + old_blocks[1:-1:-cut_depth] + [old_blocks[-1]]
|
| 886 |
+
else:
|
| 887 |
+
old_blocks = [old_blocks[0]] + old_blocks[cut_depth + 1:]
|
| 888 |
+
model.blocks = nn.Sequential(*old_blocks)
|
| 889 |
+
print(f"\n Atfer Cutting it is {len(model.blocks)} \n\n")
|
| 890 |
+
return model
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, in_channels=1,
|
| 894 |
+
input_fdim=128, input_tdim=998, frame_patchout=300, pos_embed_length=1000
|
| 895 |
+
):
|
| 896 |
+
"""
|
| 897 |
+
:param arch: Base ViT or Deit architecture
|
| 898 |
+
:param pretrained: use pretrained model on imagenet
|
| 899 |
+
:param n_classes: number of classes
|
| 900 |
+
:param in_channels: number of input channels: 1 for mono
|
| 901 |
+
:param input_fdim: the expected input frequency bins.
|
| 902 |
+
:param input_tdim: the expected input time bins.
|
| 903 |
+
:param frame_patchout: the number of frames to be removed from the input
|
| 904 |
+
@param wandb_id: tries to load model with corresponding wandb_id from 'pretrained_path'
|
| 905 |
+
:return:
|
| 906 |
+
|
| 907 |
+
"""
|
| 908 |
+
model_func = None
|
| 909 |
+
input_size = (input_fdim, input_tdim)
|
| 910 |
+
if arch == "passt_deit_bd_p16_384": # base deit
|
| 911 |
+
model_func = deit_base_distilled_patch16_384
|
| 912 |
+
elif arch == "passt_s_kd_p16_128_ap486": # pretrained
|
| 913 |
+
model_func = passt_s_kd_p16_128_ap486
|
| 914 |
+
elif arch == "passt_l_kd_p16_128_ap47": # pretrained passt-L
|
| 915 |
+
model_func = passt_l_kd_p16_128_ap47
|
| 916 |
+
elif arch == "passt_s_swa_p16_128_ap476": # pretrained
|
| 917 |
+
model_func = passt_s_swa_p16_128_ap476
|
| 918 |
+
elif arch == "passt_s_swa_p16_128_ap4761":
|
| 919 |
+
model_func = passt_s_swa_p16_128_ap4761
|
| 920 |
+
elif arch == "passt_s_p16_128_ap472":
|
| 921 |
+
model_func = passt_s_p16_128_ap472
|
| 922 |
+
elif arch == "passt_s_p16_s16_128_ap468":
|
| 923 |
+
model_func = passt_s_p16_s16_128_ap468
|
| 924 |
+
elif arch == "passt_s_swa_p16_s16_128_ap473":
|
| 925 |
+
model_func = passt_s_swa_p16_s16_128_ap473
|
| 926 |
+
elif arch == "passt_s_swa_p16_s14_128_ap471":
|
| 927 |
+
model_func = passt_s_swa_p16_s14_128_ap471
|
| 928 |
+
elif arch == "passt_s_p16_s14_128_ap469":
|
| 929 |
+
model_func = passt_s_p16_s14_128_ap469
|
| 930 |
+
elif arch == "passt_s_swa_p16_s12_128_ap473":
|
| 931 |
+
model_func = passt_s_swa_p16_s12_128_ap473
|
| 932 |
+
elif arch == "passt_s_p16_s12_128_ap470":
|
| 933 |
+
model_func = passt_s_p16_s12_128_ap470
|
| 934 |
+
elif arch == "passt_s_f128_20sec_p16_s10_ap474":
|
| 935 |
+
model_func = passt_s_f128_20sec_p16_s10_ap474_swa
|
| 936 |
+
elif arch == "passt_s_f128_30sec_p16_s10_ap473":
|
| 937 |
+
model_func = passt_s_f128_30sec_p16_s10_ap473_swa
|
| 938 |
+
|
| 939 |
+
if model_func is None:
|
| 940 |
+
raise RuntimeError(f"Unknown model {arch}")
|
| 941 |
+
model = model_func(pretrained=pretrained, num_classes=n_classes, in_chans=in_channels,
|
| 942 |
+
img_size=input_size, frame_patchout=frame_patchout, pos_embed_length=pos_embed_length)
|
| 943 |
+
model = fix_embedding_layer(model)
|
| 944 |
+
model = lighten_model(model)
|
| 945 |
+
return model
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
class EnsembelerModel(nn.Module):
|
| 949 |
+
def __init__(self, models):
|
| 950 |
+
super(EnsembelerModel, self).__init__()
|
| 951 |
+
self.models = nn.ModuleList(models)
|
| 952 |
+
|
| 953 |
+
def forward(self, x):
|
| 954 |
+
# ModuleList can act as an iterable, or be indexed using ints
|
| 955 |
+
all_out = None
|
| 956 |
+
for i, m in enumerate(self.models):
|
| 957 |
+
out, _ = m(x)
|
| 958 |
+
if all_out is None:
|
| 959 |
+
all_out = out
|
| 960 |
+
else:
|
| 961 |
+
all_out = out + all_out
|
| 962 |
+
all_out = all_out / len(self.models)
|
| 963 |
+
return all_out, all_out
|
models/frame_passt/fpasst_wrapper.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.frame_passt.fpasst import get_model
|
| 2 |
+
from models.frame_passt.preprocess import AugmentMelSTFT
|
| 3 |
+
from models.transformer_wrapper import BaseModelWrapper
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FPaSSTWrapper(BaseModelWrapper):
|
| 7 |
+
def __init__(self):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.mel = AugmentMelSTFT(
|
| 10 |
+
n_mels=128,
|
| 11 |
+
sr=16_000,
|
| 12 |
+
win_length=400,
|
| 13 |
+
hopsize=160,
|
| 14 |
+
n_fft=512,
|
| 15 |
+
freqm=0,
|
| 16 |
+
timem=0,
|
| 17 |
+
htk=False,
|
| 18 |
+
fmin=0.0,
|
| 19 |
+
fmax=None,
|
| 20 |
+
norm=1,
|
| 21 |
+
fmin_aug_range=10,
|
| 22 |
+
fmax_aug_range=2000,
|
| 23 |
+
fast_norm=True,
|
| 24 |
+
preamp=True,
|
| 25 |
+
)
|
| 26 |
+
self.fpasst = get_model(
|
| 27 |
+
arch="passt_deit_bd_p16_384",
|
| 28 |
+
n_classes=527,
|
| 29 |
+
pos_embed_length=250,
|
| 30 |
+
frame_patchout=0,
|
| 31 |
+
in_channels=16
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def mel_forward(self, x):
|
| 35 |
+
return self.mel(x)
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
return self.fpasst(x)
|
| 39 |
+
|
| 40 |
+
def separate_params(self):
|
| 41 |
+
pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
|
| 42 |
+
for k, p in self.fpasst.named_parameters():
|
| 43 |
+
if k in ['cls_token',
|
| 44 |
+
'dist_token',
|
| 45 |
+
'new_pos_embed',
|
| 46 |
+
'freq_new_pos_embed',
|
| 47 |
+
'time_new_pos_embed',
|
| 48 |
+
'conv_in_1.weight',
|
| 49 |
+
'conv_in_1.bias',
|
| 50 |
+
'conv_in_2.weight',
|
| 51 |
+
'conv_in_2.bias',
|
| 52 |
+
'conv_in_3.weight',
|
| 53 |
+
'conv_in_3.bias',
|
| 54 |
+
'patch_embed.proj.weight',
|
| 55 |
+
'patch_embed.proj.bias',
|
| 56 |
+
]:
|
| 57 |
+
pt_params[0].append(p)
|
| 58 |
+
elif 'blocks.0.' in k:
|
| 59 |
+
pt_params[0].append(p)
|
| 60 |
+
elif 'blocks.1.' in k:
|
| 61 |
+
pt_params[1].append(p)
|
| 62 |
+
elif 'blocks.2.' in k:
|
| 63 |
+
pt_params[2].append(p)
|
| 64 |
+
elif 'blocks.3.' in k:
|
| 65 |
+
pt_params[3].append(p)
|
| 66 |
+
elif 'blocks.4.' in k:
|
| 67 |
+
pt_params[4].append(p)
|
| 68 |
+
elif 'blocks.5.' in k:
|
| 69 |
+
pt_params[5].append(p)
|
| 70 |
+
elif 'blocks.6.' in k:
|
| 71 |
+
pt_params[6].append(p)
|
| 72 |
+
elif 'blocks.7.' in k:
|
| 73 |
+
pt_params[7].append(p)
|
| 74 |
+
elif 'blocks.8.' in k:
|
| 75 |
+
pt_params[8].append(p)
|
| 76 |
+
elif 'blocks.9.' in k:
|
| 77 |
+
pt_params[9].append(p)
|
| 78 |
+
elif 'blocks.10.' in k:
|
| 79 |
+
pt_params[10].append(p)
|
| 80 |
+
elif 'blocks.11.' in k:
|
| 81 |
+
pt_params[11].append(p)
|
| 82 |
+
elif k in ['norm.weight', 'norm.bias']:
|
| 83 |
+
pt_params[11].append(p)
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError(f"Check separate params for frame-passt! Unexpected key: {k}")
|
| 86 |
+
return list(reversed(pt_params))
|
models/frame_passt/preprocess.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchaudio
|
| 4 |
+
|
| 5 |
+
sz_float = 4 # size of a float
|
| 6 |
+
epsilon = 10e-8 # fudge factor for normalization
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AugmentMelSTFT(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
n_mels=128,
|
| 13 |
+
sr=32000,
|
| 14 |
+
win_length=None,
|
| 15 |
+
hopsize=320,
|
| 16 |
+
n_fft=1024,
|
| 17 |
+
freqm=0,
|
| 18 |
+
timem=0,
|
| 19 |
+
htk=False,
|
| 20 |
+
fmin=0.0,
|
| 21 |
+
fmax=None,
|
| 22 |
+
norm=1,
|
| 23 |
+
fmin_aug_range=1,
|
| 24 |
+
fmax_aug_range=1,
|
| 25 |
+
fast_norm=False,
|
| 26 |
+
preamp=True,
|
| 27 |
+
padding="center",
|
| 28 |
+
periodic_window=True,
|
| 29 |
+
):
|
| 30 |
+
torch.nn.Module.__init__(self)
|
| 31 |
+
# adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
|
| 32 |
+
# Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast
|
| 33 |
+
|
| 34 |
+
if win_length is None:
|
| 35 |
+
win_length = n_fft
|
| 36 |
+
|
| 37 |
+
if isinstance(win_length, list) or isinstance(win_length, tuple):
|
| 38 |
+
assert isinstance(n_fft, list) or isinstance(n_fft, tuple)
|
| 39 |
+
assert len(win_length) == len(n_fft)
|
| 40 |
+
else:
|
| 41 |
+
win_length = [win_length]
|
| 42 |
+
n_fft = [n_fft]
|
| 43 |
+
|
| 44 |
+
self.win_length = win_length
|
| 45 |
+
self.n_mels = n_mels
|
| 46 |
+
self.n_fft = n_fft
|
| 47 |
+
self.sr = sr
|
| 48 |
+
self.htk = htk
|
| 49 |
+
self.fmin = fmin
|
| 50 |
+
if fmax is None:
|
| 51 |
+
fmax = sr // 2 - fmax_aug_range // 2
|
| 52 |
+
self.fmax = fmax
|
| 53 |
+
self.norm = norm
|
| 54 |
+
self.hopsize = hopsize
|
| 55 |
+
self.preamp = preamp
|
| 56 |
+
for win_l in self.win_length:
|
| 57 |
+
self.register_buffer(
|
| 58 |
+
f"window_{win_l}",
|
| 59 |
+
torch.hann_window(win_l, periodic=periodic_window),
|
| 60 |
+
persistent=False,
|
| 61 |
+
)
|
| 62 |
+
assert (
|
| 63 |
+
fmin_aug_range >= 1
|
| 64 |
+
), f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
|
| 65 |
+
assert (
|
| 66 |
+
fmin_aug_range >= 1
|
| 67 |
+
), f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
|
| 68 |
+
self.fmin_aug_range = fmin_aug_range
|
| 69 |
+
self.fmax_aug_range = fmax_aug_range
|
| 70 |
+
|
| 71 |
+
self.register_buffer(
|
| 72 |
+
"preemphasis_coefficient", torch.as_tensor([[[-0.97, 1]]]), persistent=False
|
| 73 |
+
)
|
| 74 |
+
if freqm == 0:
|
| 75 |
+
self.freqm = torch.nn.Identity()
|
| 76 |
+
else:
|
| 77 |
+
self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=False)
|
| 78 |
+
if timem == 0:
|
| 79 |
+
self.timem = torch.nn.Identity()
|
| 80 |
+
else:
|
| 81 |
+
self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=False)
|
| 82 |
+
self.fast_norm = fast_norm
|
| 83 |
+
self.padding = padding
|
| 84 |
+
if padding not in ["center", "same"]:
|
| 85 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 86 |
+
self.iden = nn.Identity()
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
if self.preamp:
|
| 90 |
+
x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient)
|
| 91 |
+
x = x.squeeze(1)
|
| 92 |
+
|
| 93 |
+
fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
|
| 94 |
+
fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
|
| 95 |
+
|
| 96 |
+
# don't augment eval data
|
| 97 |
+
if not self.training:
|
| 98 |
+
fmin = self.fmin
|
| 99 |
+
fmax = self.fmax
|
| 100 |
+
|
| 101 |
+
mels = []
|
| 102 |
+
for n_fft, win_length in zip(self.n_fft, self.win_length):
|
| 103 |
+
x_temp = x
|
| 104 |
+
if self.padding == "same":
|
| 105 |
+
pad = win_length - self.hopsize
|
| 106 |
+
self.iden(x_temp) # printing
|
| 107 |
+
x_temp = torch.nn.functional.pad(x_temp, (pad // 2, pad // 2), mode="reflect")
|
| 108 |
+
self.iden(x_temp) # printing
|
| 109 |
+
|
| 110 |
+
x_temp = torch.stft(
|
| 111 |
+
x_temp,
|
| 112 |
+
n_fft,
|
| 113 |
+
hop_length=self.hopsize,
|
| 114 |
+
win_length=win_length,
|
| 115 |
+
center=self.padding == "center",
|
| 116 |
+
normalized=False,
|
| 117 |
+
window=getattr(self, f"window_{win_length}"),
|
| 118 |
+
return_complex=True
|
| 119 |
+
)
|
| 120 |
+
x_temp = torch.view_as_real(x_temp)
|
| 121 |
+
x_temp = (x_temp ** 2).sum(dim=-1) # power mag
|
| 122 |
+
|
| 123 |
+
mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, n_fft, self.sr,
|
| 124 |
+
fmin, fmax, vtln_low=100.0, vtln_high=-500.,
|
| 125 |
+
vtln_warp_factor=1.0)
|
| 126 |
+
mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
|
| 127 |
+
device=x.device)
|
| 128 |
+
|
| 129 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 130 |
+
x_temp = torch.matmul(mel_basis, x_temp)
|
| 131 |
+
|
| 132 |
+
x_temp = torch.log(torch.clip(x_temp, min=1e-7))
|
| 133 |
+
|
| 134 |
+
mels.append(x_temp)
|
| 135 |
+
|
| 136 |
+
mels = torch.stack(mels, dim=1)
|
| 137 |
+
|
| 138 |
+
if self.training:
|
| 139 |
+
mels = self.freqm(mels)
|
| 140 |
+
mels = self.timem(mels)
|
| 141 |
+
if self.fast_norm:
|
| 142 |
+
mels = (mels + 4.5) / 5.0 # fast normalization
|
| 143 |
+
|
| 144 |
+
return mels
|
| 145 |
+
|
| 146 |
+
def extra_repr(self):
|
| 147 |
+
return "winsize={}, hopsize={}".format(self.win_length, self.hopsize)
|
models/frame_passt/vit_helpers.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 3 |
+
Credit to @leo19941227 for remove timm dependencies here : https://github.com/s3prl/passt_hear21/blob/48a0dc1b824641ca59884ced53f5b86053fed141/hear21passt/models/helpers/vit_helpers.py
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
import warnings
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from timm.models._hub import download_cached_file
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Global variables for rarely used pretrained checkpoint download progress and hash check.
|
| 18 |
+
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
|
| 19 |
+
_DOWNLOAD_PROGRESS = True
|
| 20 |
+
_CHECK_HASH = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
_logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def adapt_input_conv(in_chans, conv_weight, input_conv_name="(name not given)"):
|
| 27 |
+
conv_type = conv_weight.dtype
|
| 28 |
+
conv_weight = (
|
| 29 |
+
conv_weight.float()
|
| 30 |
+
) # Some weights are in torch.half, ensure it's float for sum on CPU
|
| 31 |
+
O, I, J, K = conv_weight.shape
|
| 32 |
+
if in_chans == 1:
|
| 33 |
+
print(f"adapt_input_conv: Converted from {I} to 1 channel")
|
| 34 |
+
if I > 3:
|
| 35 |
+
assert conv_weight.shape[1] % 3 == 0
|
| 36 |
+
# For models with space2depth stems
|
| 37 |
+
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
|
| 38 |
+
conv_weight = conv_weight.sum(dim=2, keepdim=False)
|
| 39 |
+
else:
|
| 40 |
+
conv_weight = conv_weight.sum(dim=1, keepdim=True)
|
| 41 |
+
elif in_chans != 3:
|
| 42 |
+
if I != 3:
|
| 43 |
+
# loading a model pretrained on AudioSet for the downstream-task
|
| 44 |
+
if I == in_chans:
|
| 45 |
+
print(f"adapt_input_conv: Loading pretrained weights for {input_conv_name}, "
|
| 46 |
+
f"Assuming same input-conv and proj-conv configuration (1:1).")
|
| 47 |
+
pass
|
| 48 |
+
else:
|
| 49 |
+
print(f"adapt_input_conv: Converted input conv {input_conv_name} weights from 3 to {in_chans} channel(s)")
|
| 50 |
+
# NOTE this strategy should be better than random init, but there could be other combinations of
|
| 51 |
+
# the original RGB input layer weights that'd work better for specific cases.
|
| 52 |
+
repeat = int(math.ceil(in_chans / 3))
|
| 53 |
+
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
| 54 |
+
conv_weight *= 3 / float(in_chans)
|
| 55 |
+
conv_weight = conv_weight.to(conv_type)
|
| 56 |
+
return conv_weight
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_pretrained(
|
| 60 |
+
model,
|
| 61 |
+
default_cfg=None,
|
| 62 |
+
num_classes=1000,
|
| 63 |
+
in_chans=3,
|
| 64 |
+
filter_fn=None,
|
| 65 |
+
strict=True,
|
| 66 |
+
progress=False,
|
| 67 |
+
):
|
| 68 |
+
"""Load pretrained checkpoint
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
model (nn.Module) : PyTorch model module
|
| 72 |
+
default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
|
| 73 |
+
num_classes (int): num_classes for model
|
| 74 |
+
in_chans (int): in_chans for model
|
| 75 |
+
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
|
| 76 |
+
strict (bool): strict load of checkpoint
|
| 77 |
+
progress (bool): enable progress bar for weight download
|
| 78 |
+
|
| 79 |
+
"""
|
| 80 |
+
default_cfg = default_cfg or getattr(model, "default_cfg", None) or {}
|
| 81 |
+
pretrained_url = default_cfg.get("url", None)
|
| 82 |
+
|
| 83 |
+
if not pretrained_url:
|
| 84 |
+
_logger.warning(
|
| 85 |
+
"No pretrained weights exist for this model. Using random initialization."
|
| 86 |
+
)
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
_logger.info(f"Loading pretrained weights from url ({pretrained_url})")
|
| 90 |
+
pretrained_loc = download_cached_file(
|
| 91 |
+
pretrained_url,
|
| 92 |
+
check_hash=_CHECK_HASH,
|
| 93 |
+
progress=_DOWNLOAD_PROGRESS,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
state_dict = torch.load(pretrained_loc, map_location="cpu")
|
| 97 |
+
|
| 98 |
+
if filter_fn is not None:
|
| 99 |
+
# for backwards compat with filter fn that take one arg, try one first, the two
|
| 100 |
+
try:
|
| 101 |
+
state_dict = filter_fn(state_dict)
|
| 102 |
+
except TypeError:
|
| 103 |
+
state_dict = filter_fn(state_dict, model)
|
| 104 |
+
|
| 105 |
+
input_convs = default_cfg.get("first_conv", None)
|
| 106 |
+
if input_convs is not None and in_chans != 3:
|
| 107 |
+
if isinstance(input_convs, str):
|
| 108 |
+
input_convs = (input_convs,)
|
| 109 |
+
for input_conv_name in input_convs:
|
| 110 |
+
weight_name = input_conv_name + ".weight"
|
| 111 |
+
try:
|
| 112 |
+
state_dict[weight_name] = adapt_input_conv(
|
| 113 |
+
in_chans, state_dict[weight_name], input_conv_name
|
| 114 |
+
)
|
| 115 |
+
# _logger.info(
|
| 116 |
+
# f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)"
|
| 117 |
+
# )
|
| 118 |
+
except (NotImplementedError, KeyError) as e:
|
| 119 |
+
if weight_name in state_dict:
|
| 120 |
+
del state_dict[weight_name]
|
| 121 |
+
strict = False
|
| 122 |
+
_logger.warning(
|
| 123 |
+
f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
classifiers = default_cfg.get("classifier", None)
|
| 127 |
+
label_offset = default_cfg.get("label_offset", 0)
|
| 128 |
+
if classifiers is not None:
|
| 129 |
+
if isinstance(classifiers, str):
|
| 130 |
+
classifiers = (classifiers,)
|
| 131 |
+
if num_classes != default_cfg["num_classes"]:
|
| 132 |
+
for classifier_name in classifiers:
|
| 133 |
+
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
| 134 |
+
del state_dict[classifier_name + ".weight"]
|
| 135 |
+
del state_dict[classifier_name + ".bias"]
|
| 136 |
+
strict = False
|
| 137 |
+
elif label_offset > 0:
|
| 138 |
+
for classifier_name in classifiers:
|
| 139 |
+
# special case for pretrained weights with an extra background class in pretrained weights
|
| 140 |
+
classifier_weight = state_dict[classifier_name + ".weight"]
|
| 141 |
+
state_dict[classifier_name + ".weight"] = classifier_weight[
|
| 142 |
+
label_offset:
|
| 143 |
+
]
|
| 144 |
+
classifier_bias = state_dict[classifier_name + ".bias"]
|
| 145 |
+
state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:]
|
| 146 |
+
|
| 147 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def overlay_external_default_cfg(default_cfg, kwargs):
|
| 151 |
+
"""Overlay 'external_default_cfg' in kwargs on top of default_cfg arg."""
|
| 152 |
+
external_default_cfg = kwargs.pop("external_default_cfg", None)
|
| 153 |
+
if external_default_cfg:
|
| 154 |
+
default_cfg.pop("url", None) # url should come from external cfg
|
| 155 |
+
default_cfg.pop("hf_hub", None) # hf hub id should come from external cfg
|
| 156 |
+
default_cfg.update(external_default_cfg)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def filter_kwargs(kwargs, names):
|
| 160 |
+
if not kwargs or not names:
|
| 161 |
+
return
|
| 162 |
+
for n in names:
|
| 163 |
+
kwargs.pop(n, None)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def set_default_kwargs(kwargs, names, default_cfg):
|
| 167 |
+
for n in names:
|
| 168 |
+
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
|
| 169 |
+
# default_cfg has one input_size=(C, H ,W) entry
|
| 170 |
+
if n == "img_size":
|
| 171 |
+
input_size = default_cfg.get("input_size", None)
|
| 172 |
+
if input_size is not None:
|
| 173 |
+
assert len(input_size) == 3
|
| 174 |
+
kwargs.setdefault(n, input_size[-2:])
|
| 175 |
+
elif n == "in_chans":
|
| 176 |
+
input_size = default_cfg.get("input_size", None)
|
| 177 |
+
if input_size is not None:
|
| 178 |
+
assert len(input_size) == 3
|
| 179 |
+
kwargs.setdefault(n, input_size[0])
|
| 180 |
+
else:
|
| 181 |
+
default_val = default_cfg.get(n, None)
|
| 182 |
+
if default_val is not None:
|
| 183 |
+
kwargs.setdefault(n, default_cfg[n])
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
|
| 187 |
+
"""Update the default_cfg and kwargs before passing to model
|
| 188 |
+
|
| 189 |
+
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
|
| 190 |
+
could/should be replaced by an improved configuration mechanism
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
default_cfg: input default_cfg (updated in-place)
|
| 194 |
+
kwargs: keyword args passed to model build fn (updated in-place)
|
| 195 |
+
kwargs_filter: keyword arg keys that must be removed before model __init__
|
| 196 |
+
"""
|
| 197 |
+
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
|
| 198 |
+
overlay_external_default_cfg(default_cfg, kwargs)
|
| 199 |
+
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
| 200 |
+
default_kwarg_names = ("num_classes", "global_pool", "in_chans")
|
| 201 |
+
if default_cfg.get("fixed_input_size", False):
|
| 202 |
+
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
|
| 203 |
+
default_kwarg_names += ("img_size",)
|
| 204 |
+
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
|
| 205 |
+
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
| 206 |
+
filter_kwargs(kwargs, names=kwargs_filter)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 210 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 211 |
+
|
| 212 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 213 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 214 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 215 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 216 |
+
'survival rate' as the argument.
|
| 217 |
+
|
| 218 |
+
"""
|
| 219 |
+
if drop_prob == 0.0 or not training:
|
| 220 |
+
return x
|
| 221 |
+
keep_prob = 1 - drop_prob
|
| 222 |
+
shape = (x.shape[0],) + (1,) * (
|
| 223 |
+
x.ndim - 1
|
| 224 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 225 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 226 |
+
random_tensor.floor_() # binarize
|
| 227 |
+
output = x.div(keep_prob) * random_tensor
|
| 228 |
+
return output
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class DropPath(nn.Module):
|
| 232 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 233 |
+
|
| 234 |
+
def __init__(self, drop_prob=None):
|
| 235 |
+
super(DropPath, self).__init__()
|
| 236 |
+
self.drop_prob = drop_prob
|
| 237 |
+
|
| 238 |
+
def forward(self, x):
|
| 239 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 246 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 247 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 248 |
+
def norm_cdf(x):
|
| 249 |
+
# Computes standard normal cumulative distribution function
|
| 250 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
| 251 |
+
|
| 252 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 253 |
+
warnings.warn(
|
| 254 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 255 |
+
"The distribution of values may be incorrect.",
|
| 256 |
+
stacklevel=2,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
with torch.no_grad():
|
| 260 |
+
# Values are generated by using a truncated uniform distribution and
|
| 261 |
+
# then using the inverse CDF for the normal distribution.
|
| 262 |
+
# Get upper and lower cdf values
|
| 263 |
+
l = norm_cdf((a - mean) / std)
|
| 264 |
+
u = norm_cdf((b - mean) / std)
|
| 265 |
+
|
| 266 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 267 |
+
# [2l-1, 2u-1].
|
| 268 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 269 |
+
|
| 270 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 271 |
+
# standard normal
|
| 272 |
+
tensor.erfinv_()
|
| 273 |
+
|
| 274 |
+
# Transform to proper mean, std
|
| 275 |
+
tensor.mul_(std * math.sqrt(2.0))
|
| 276 |
+
tensor.add_(mean)
|
| 277 |
+
|
| 278 |
+
# Clamp to ensure it's in the proper range
|
| 279 |
+
tensor.clamp_(min=a, max=b)
|
| 280 |
+
return tensor
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
| 284 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
| 285 |
+
normal distribution. The values are effectively drawn from the
|
| 286 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
| 287 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
| 288 |
+
the bounds. The method used for generating the random values works
|
| 289 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
| 290 |
+
Args:
|
| 291 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 292 |
+
mean: the mean of the normal distribution
|
| 293 |
+
std: the standard deviation of the normal distribution
|
| 294 |
+
a: the minimum cutoff value
|
| 295 |
+
b: the maximum cutoff value
|
| 296 |
+
Examples:
|
| 297 |
+
>>> w = torch.empty(3, 5)
|
| 298 |
+
>>> nn.init.trunc_normal_(w)
|
| 299 |
+
"""
|
| 300 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
| 304 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
| 305 |
+
if mode == "fan_in":
|
| 306 |
+
denom = fan_in
|
| 307 |
+
elif mode == "fan_out":
|
| 308 |
+
denom = fan_out
|
| 309 |
+
elif mode == "fan_avg":
|
| 310 |
+
denom = (fan_in + fan_out) / 2
|
| 311 |
+
|
| 312 |
+
variance = scale / denom
|
| 313 |
+
|
| 314 |
+
if distribution == "truncated_normal":
|
| 315 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
| 316 |
+
trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
| 317 |
+
elif distribution == "normal":
|
| 318 |
+
tensor.normal_(std=math.sqrt(variance))
|
| 319 |
+
elif distribution == "uniform":
|
| 320 |
+
bound = math.sqrt(3 * variance)
|
| 321 |
+
tensor.uniform_(-bound, bound)
|
| 322 |
+
else:
|
| 323 |
+
raise ValueError(f"invalid distribution {distribution}")
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def lecun_normal_(tensor):
|
| 327 |
+
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def build_model_with_cfg(
|
| 331 |
+
model_cls,
|
| 332 |
+
variant: str,
|
| 333 |
+
pretrained: bool,
|
| 334 |
+
default_cfg: dict,
|
| 335 |
+
model_cfg=None,
|
| 336 |
+
feature_cfg=None,
|
| 337 |
+
pretrained_strict: bool = True,
|
| 338 |
+
pretrained_filter_fn=None,
|
| 339 |
+
pretrained_custom_load=False,
|
| 340 |
+
kwargs_filter=None,
|
| 341 |
+
**kwargs,
|
| 342 |
+
):
|
| 343 |
+
"""Build model with specified default_cfg and optional model_cfg
|
| 344 |
+
|
| 345 |
+
This helper fn aids in the construction of a model including:
|
| 346 |
+
* handling default_cfg and associated pretained weight loading
|
| 347 |
+
* passing through optional model_cfg for models with config based arch spec
|
| 348 |
+
* features_only model adaptation
|
| 349 |
+
* pruning config / model adaptation
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
model_cls (nn.Module): model class
|
| 353 |
+
variant (str): model variant name
|
| 354 |
+
pretrained (bool): load pretrained weights
|
| 355 |
+
default_cfg (dict): model's default pretrained/task config
|
| 356 |
+
model_cfg (Optional[Dict]): model's architecture config
|
| 357 |
+
feature_cfg (Optional[Dict]: feature extraction adapter config
|
| 358 |
+
pretrained_strict (bool): load pretrained weights strictly
|
| 359 |
+
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
|
| 360 |
+
pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
|
| 361 |
+
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
|
| 362 |
+
**kwargs: model args passed through to model __init__
|
| 363 |
+
"""
|
| 364 |
+
pruned = kwargs.pop("pruned", False)
|
| 365 |
+
features = False
|
| 366 |
+
feature_cfg = feature_cfg or {}
|
| 367 |
+
default_cfg = deepcopy(default_cfg) if default_cfg else {}
|
| 368 |
+
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
|
| 369 |
+
default_cfg.setdefault("architecture", variant)
|
| 370 |
+
|
| 371 |
+
# Setup for feature extraction wrapper done at end of this fn
|
| 372 |
+
if kwargs.pop("features_only", False):
|
| 373 |
+
features = True
|
| 374 |
+
feature_cfg.setdefault("out_indices", (0, 1, 2, 3, 4))
|
| 375 |
+
if "out_indices" in kwargs:
|
| 376 |
+
feature_cfg["out_indices"] = kwargs.pop("out_indices")
|
| 377 |
+
|
| 378 |
+
# Build the model
|
| 379 |
+
model = (
|
| 380 |
+
model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
|
| 381 |
+
)
|
| 382 |
+
model.default_cfg = default_cfg
|
| 383 |
+
|
| 384 |
+
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
| 385 |
+
num_classes_pretrained = (
|
| 386 |
+
0
|
| 387 |
+
if features
|
| 388 |
+
else getattr(model, "num_classes", kwargs.get("num_classes", 1000))
|
| 389 |
+
)
|
| 390 |
+
if pretrained:
|
| 391 |
+
assert not pretrained_custom_load, "URL should not contain npz for PASST models"
|
| 392 |
+
load_pretrained(
|
| 393 |
+
model,
|
| 394 |
+
num_classes=num_classes_pretrained,
|
| 395 |
+
in_chans=kwargs.get("in_chans", 3),
|
| 396 |
+
filter_fn=pretrained_filter_fn,
|
| 397 |
+
strict=pretrained_strict,
|
| 398 |
+
)
|
| 399 |
+
return model
|
models/m2d/M2D_wrapper.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.m2d.portable_m2d import PortableM2D as M2D
|
| 2 |
+
from models.transformer_wrapper import BaseModelWrapper
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class M2DWrapper(BaseModelWrapper):
|
| 6 |
+
def __init__(self) -> None:
|
| 7 |
+
super().__init__()
|
| 8 |
+
self.m2d = M2D()
|
| 9 |
+
|
| 10 |
+
def mel_forward(self, x):
|
| 11 |
+
return self.m2d.to_normalized_feature(x)
|
| 12 |
+
|
| 13 |
+
def forward(self, spec):
|
| 14 |
+
return self.m2d.forward_mel(spec)
|
| 15 |
+
|
| 16 |
+
def separate_params(self):
|
| 17 |
+
pt_params = [[], [], [], [], [], [], [], [], [], [], [], []]
|
| 18 |
+
for k, p in self.named_parameters():
|
| 19 |
+
if any(['cls_token' in k,
|
| 20 |
+
'pos_embed' in k,
|
| 21 |
+
'norm_stats' in k,
|
| 22 |
+
'patch_embed' in k]):
|
| 23 |
+
pt_params[0].append(p)
|
| 24 |
+
elif 'blocks.0.' in k:
|
| 25 |
+
pt_params[0].append(p)
|
| 26 |
+
elif 'blocks.1.' in k:
|
| 27 |
+
pt_params[1].append(p)
|
| 28 |
+
elif 'blocks.2.' in k:
|
| 29 |
+
pt_params[2].append(p)
|
| 30 |
+
elif 'blocks.3.' in k:
|
| 31 |
+
pt_params[3].append(p)
|
| 32 |
+
elif 'blocks.4.' in k:
|
| 33 |
+
pt_params[4].append(p)
|
| 34 |
+
elif 'blocks.5.' in k:
|
| 35 |
+
pt_params[5].append(p)
|
| 36 |
+
elif 'blocks.6.' in k:
|
| 37 |
+
pt_params[6].append(p)
|
| 38 |
+
elif 'blocks.7.' in k:
|
| 39 |
+
pt_params[7].append(p)
|
| 40 |
+
elif 'blocks.8.' in k:
|
| 41 |
+
pt_params[8].append(p)
|
| 42 |
+
elif 'blocks.9.' in k:
|
| 43 |
+
pt_params[9].append(p)
|
| 44 |
+
elif 'blocks.10.' in k:
|
| 45 |
+
pt_params[10].append(p)
|
| 46 |
+
elif 'blocks.11.' in k:
|
| 47 |
+
pt_params[11].append(p)
|
| 48 |
+
elif 'backbone.norm.weight' in k or 'backbone.norm.bias' in k:
|
| 49 |
+
pt_params[11].append(p)
|
| 50 |
+
else:
|
| 51 |
+
raise ValueError(f"Check separate params for M2D! Unknown key: {k}")
|
| 52 |
+
return list(reversed(pt_params))
|
models/m2d/portable_m2d.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Masked Modeling Duo (M2D) Portable Runtime.
|
| 2 |
+
|
| 3 |
+
All you need is:
|
| 4 |
+
pip install timm, einops, nnAudio
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from functools import partial
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import nnAudio.features
|
| 12 |
+
import numpy as np
|
| 13 |
+
import timm
|
| 14 |
+
import torch
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
from timm.models.layers import trunc_normal_
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Config:
|
| 20 |
+
weight_file = ''
|
| 21 |
+
feature_d = 768 * 5
|
| 22 |
+
norm_type = all
|
| 23 |
+
pooling_type = 'mean'
|
| 24 |
+
model = ''
|
| 25 |
+
input_size = [80, 208]
|
| 26 |
+
patch_size = [16, 16]
|
| 27 |
+
sr = '16k'
|
| 28 |
+
flat_features = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def expand_size(sz):
|
| 32 |
+
if isinstance(sz, int):
|
| 33 |
+
return [sz, sz]
|
| 34 |
+
return sz
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class PatchEmbed(torch.nn.Module):
|
| 38 |
+
""" 2D Image to Patch Embedding -- borrowed from https://pypi.org/project/timm/0.4.12/"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
| 41 |
+
super().__init__()
|
| 42 |
+
img_size = expand_size(img_size)
|
| 43 |
+
patch_size = expand_size(patch_size)
|
| 44 |
+
self.img_size = img_size
|
| 45 |
+
self.patch_size = patch_size
|
| 46 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 47 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 48 |
+
self.flatten = flatten
|
| 49 |
+
|
| 50 |
+
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 51 |
+
self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
x = self.proj(x)
|
| 55 |
+
if self.flatten:
|
| 56 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
| 57 |
+
x = self.norm(x)
|
| 58 |
+
return x
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class LocalViT(timm.models.vision_transformer.VisionTransformer):
|
| 62 |
+
""" Vision Transformer for M2D Audio"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, **kwargs):
|
| 65 |
+
super().__init__(**kwargs)
|
| 66 |
+
# Workaround for PatchEmbed to avoid unintended assertion failure. ex) AssertionError: Input image width (102) doesn't match model (608).
|
| 67 |
+
self.patch_embed = PatchEmbed(self.patch_embed.img_size, self.patch_embed.patch_size,
|
| 68 |
+
self.patch_embed.proj.in_channels, self.patch_embed.proj.out_channels)
|
| 69 |
+
self.norm_stats = torch.nn.Parameter(torch.tensor([-7.1, 4.2]), requires_grad=False)
|
| 70 |
+
# We do not use the default head
|
| 71 |
+
del self.head
|
| 72 |
+
|
| 73 |
+
def patch_size(self):
|
| 74 |
+
return np.array(self.patch_embed.patch_size)
|
| 75 |
+
|
| 76 |
+
def grid_size(self):
|
| 77 |
+
# Workaround for compatibility issue (timm 0.4.5 fails with: return self.patch_embed.grid_size)
|
| 78 |
+
img_size = np.array(self.patch_embed.img_size)
|
| 79 |
+
patch_size = self.patch_size()
|
| 80 |
+
grid_size = img_size // patch_size
|
| 81 |
+
return grid_size
|
| 82 |
+
|
| 83 |
+
def forward_encoder(self, x):
|
| 84 |
+
x = self.patch_embed(x)
|
| 85 |
+
|
| 86 |
+
# add pos embed w/o cls token
|
| 87 |
+
pos_embed = self.pos_embed[:, 1:, :]
|
| 88 |
+
if x.shape[1] < pos_embed.shape[1]: # shorten pos_embed for a short input
|
| 89 |
+
dims = pos_embed.shape[-1]
|
| 90 |
+
fbins = self.grid_size()[0]
|
| 91 |
+
frames = x.shape[1] // fbins
|
| 92 |
+
pos_embed = pos_embed.reshape(1, fbins, -1, dims)[:, :, :frames, :].reshape(1, fbins * frames, dims)
|
| 93 |
+
x = x + pos_embed
|
| 94 |
+
|
| 95 |
+
# append cls token
|
| 96 |
+
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
| 97 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
| 98 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 99 |
+
|
| 100 |
+
# apply Transformer blocks
|
| 101 |
+
for blk in self.blocks:
|
| 102 |
+
x = blk(x)
|
| 103 |
+
x = self.norm(x)
|
| 104 |
+
|
| 105 |
+
return x
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def parse_sizes_by_name(name):
|
| 109 |
+
# Parse parameters. "m2d_vit_base-80x1001p16x16p16k" -> input size: 80x1001, patch size: 16x16, sr: 16k
|
| 110 |
+
model_cls = name.split('-')[0]
|
| 111 |
+
params = name.split('-')[1]
|
| 112 |
+
params = params.split('p')[:3]
|
| 113 |
+
input_str, patch_str, sr = params[0], params[1], params[2] if len(params) > 2 else '16k'
|
| 114 |
+
input_size = [int(a) for a in input_str.split('x')]
|
| 115 |
+
patch_size = [int(a) for a in patch_str.split('x')]
|
| 116 |
+
return input_size, patch_size, sr, model_cls
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def drop_non_model_weights(model, checkpoint, filename):
|
| 120 |
+
model_keys = [n for n, p in model.named_parameters()]
|
| 121 |
+
new_ckpt, dropped = {}, []
|
| 122 |
+
for k in checkpoint:
|
| 123 |
+
if k not in model_keys:
|
| 124 |
+
dropped.append(k)
|
| 125 |
+
continue
|
| 126 |
+
new_ckpt[k] = checkpoint[k]
|
| 127 |
+
n_org = len(checkpoint.keys())
|
| 128 |
+
n_cur = len(new_ckpt.keys())
|
| 129 |
+
print(
|
| 130 |
+
f' using {n_cur} parameters, while dropped {n_org - n_cur} out of {n_org} parameters from {Path(filename).parent / Path(filename).name}'
|
| 131 |
+
if n_org > n_cur else f' using {n_cur} parameters from {Path(filename).parent / Path(filename).name}')
|
| 132 |
+
print(' (dropped:', dropped[:5], ')' if len(dropped) < 5 else '...)')
|
| 133 |
+
return new_ckpt
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_evar_head_parameters(checkpoint, head_norm, head):
|
| 137 |
+
# Load the weights of the task head trained in the EVAR fine-tuning.
|
| 138 |
+
if 'module.head.norm.running_mean' in checkpoint:
|
| 139 |
+
head_norm.load_state_dict({to_k: checkpoint[k] for to_k, k in {
|
| 140 |
+
'running_mean': 'module.head.norm.running_mean', 'running_var': 'module.head.norm.running_var'}.items()})
|
| 141 |
+
head.load_state_dict({to_k: checkpoint[k] for to_k, k in {
|
| 142 |
+
'weight': 'module.head.mlp.mlp.0.weight', 'bias': 'module.head.mlp.mlp.0.bias'}.items()})
|
| 143 |
+
else:
|
| 144 |
+
print(' Not an EVAR checkpoint for loading head weights.')
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def reformat_ckpt_keys(checkpoint):
|
| 148 |
+
# In case: checkpoint['model']
|
| 149 |
+
checkpoint = checkpoint['model'] if 'model' in checkpoint else checkpoint
|
| 150 |
+
# The checkpoints saved in a EVAR fine-tuning has a prefix of "module.ar.runtime.backbone", the following removes it.
|
| 151 |
+
new_ckpt = {}
|
| 152 |
+
for k in checkpoint:
|
| 153 |
+
new_k = k.replace('module.ar.runtime.backbone.', '') # replace
|
| 154 |
+
new_ckpt[new_k] = checkpoint[k]
|
| 155 |
+
return new_ckpt
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def make_it_CLAP(model, checkpoint):
|
| 159 |
+
# Add projectors if needed
|
| 160 |
+
if 'audio_proj.0.weight' in checkpoint.keys():
|
| 161 |
+
proj_hidden_dim = embed_dim = checkpoint['audio_proj.0.weight'].shape[1]
|
| 162 |
+
model.audio_proj = torch.nn.Sequential(
|
| 163 |
+
torch.nn.Linear(embed_dim, proj_hidden_dim),
|
| 164 |
+
torch.nn.ReLU(),
|
| 165 |
+
torch.nn.Linear(proj_hidden_dim, embed_dim),
|
| 166 |
+
)
|
| 167 |
+
if 'text_proj.weight' in checkpoint.keys():
|
| 168 |
+
dim = checkpoint['text_proj.weight'].shape
|
| 169 |
+
model.text_proj = torch.nn.Linear(dim[1], dim[0])
|
| 170 |
+
else:
|
| 171 |
+
model.text_proj = torch.nn.Identity()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_backbone(args, weight_file):
|
| 175 |
+
name = Path(weight_file).parent.name if weight_file is not None \
|
| 176 |
+
else "m2d_clap_vit_base-80x1001p16x16-240128_AS-FT_enconly"
|
| 177 |
+
args.input_size, args.patch_size, args.sr, args.beats = parse_sizes_by_name(name)
|
| 178 |
+
|
| 179 |
+
# Create a ViT.
|
| 180 |
+
model = LocalViT(
|
| 181 |
+
in_chans=1, img_size=args.input_size, patch_size=args.patch_size, embed_dim=768, depth=12, num_heads=12,
|
| 182 |
+
mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6))
|
| 183 |
+
|
| 184 |
+
if weight_file is None:
|
| 185 |
+
args.mean, args.std = -7.1, 4.2
|
| 186 |
+
model.eval()
|
| 187 |
+
return model, None
|
| 188 |
+
|
| 189 |
+
# Load checkpoint.
|
| 190 |
+
checkpoint = torch.load(weight_file, map_location='cpu')
|
| 191 |
+
checkpoint = reformat_ckpt_keys(checkpoint)
|
| 192 |
+
# Set normalization statistics for backward compatibility. The [-7.1, 4.2] is for 2022 models.
|
| 193 |
+
if 'norm_stats' not in checkpoint:
|
| 194 |
+
checkpoint['norm_stats'] = torch.tensor([-7.1, 4.2])
|
| 195 |
+
print(' using default norm_stats:', checkpoint['norm_stats'])
|
| 196 |
+
|
| 197 |
+
# Modify the model if it should be a M2D-CLAP.
|
| 198 |
+
make_it_CLAP(model, checkpoint)
|
| 199 |
+
|
| 200 |
+
# Load weights.
|
| 201 |
+
dropped = drop_non_model_weights(model, checkpoint, weight_file)
|
| 202 |
+
msg = model.load_state_dict(dropped)
|
| 203 |
+
print(msg);
|
| 204 |
+
logging.info(msg)
|
| 205 |
+
|
| 206 |
+
# Make normalization statistics for the model easy to use in the downstream task.
|
| 207 |
+
args.mean, args.std = model.state_dict()['norm_stats'].to('cpu').numpy()
|
| 208 |
+
|
| 209 |
+
model.eval()
|
| 210 |
+
return model, checkpoint
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_to_melspec(cfg):
|
| 214 |
+
if cfg.sr == '16k':
|
| 215 |
+
cfg.sample_rate, cfg.n_fft, cfg.window_size, cfg.hop_size = 16000, 400, 400, 160
|
| 216 |
+
cfg.n_mels, cfg.f_min, cfg.f_max = 80, 50, 8000
|
| 217 |
+
elif cfg.sr == '32k':
|
| 218 |
+
cfg.sample_rate, cfg.n_fft, cfg.window_size, cfg.hop_size = 32000, 800, 800, 320
|
| 219 |
+
cfg.n_mels, cfg.f_min, cfg.f_max = 80, 50, 16000
|
| 220 |
+
else:
|
| 221 |
+
assert False, f'Unknown input size: {cfg.input_size}'
|
| 222 |
+
|
| 223 |
+
to_spec = nnAudio.features.MelSpectrogram(
|
| 224 |
+
sr=cfg.sample_rate,
|
| 225 |
+
n_fft=cfg.n_fft,
|
| 226 |
+
win_length=cfg.window_size,
|
| 227 |
+
hop_length=cfg.hop_size,
|
| 228 |
+
n_mels=cfg.n_mels,
|
| 229 |
+
fmin=cfg.f_min,
|
| 230 |
+
fmax=cfg.f_max,
|
| 231 |
+
center=True,
|
| 232 |
+
power=2,
|
| 233 |
+
verbose=False,
|
| 234 |
+
)
|
| 235 |
+
logging.info(f'Runtime MelSpectrogram({cfg.sample_rate}, {cfg.n_fft}, {cfg.window_size}, {cfg.hop_size}, '
|
| 236 |
+
+ f'{cfg.n_mels}, {cfg.f_min}, {cfg.f_max}):')
|
| 237 |
+
logging.info(to_spec)
|
| 238 |
+
return to_spec
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def get_timestamps(cfg, batch_audio, x): # Returns timestamps in milliseconds.
|
| 242 |
+
audio_len = len(batch_audio[0])
|
| 243 |
+
sec = audio_len / cfg.sample_rate
|
| 244 |
+
x_len = len(x[0])
|
| 245 |
+
step = sec / x_len * 1000 # sec -> ms
|
| 246 |
+
ts = torch.tensor([step * i for i in range(x_len)]).unsqueeze(0)
|
| 247 |
+
ts = ts.repeat(len(batch_audio), 1)
|
| 248 |
+
return ts
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class PortableM2D(torch.nn.Module):
|
| 252 |
+
def __init__(self, weight_file=None, num_classes=None, freeze_embed=False, flat_features=None):
|
| 253 |
+
super().__init__()
|
| 254 |
+
self.cfg = Config()
|
| 255 |
+
self.cfg.weight_file = weight_file
|
| 256 |
+
self.cfg.freeze_embed = freeze_embed
|
| 257 |
+
self.cfg.flat_features = self.cfg.flat_features if flat_features is None else flat_features
|
| 258 |
+
|
| 259 |
+
# Create backbone model.
|
| 260 |
+
self.backbone, checkpoint = get_backbone(self.cfg, self.cfg.weight_file)
|
| 261 |
+
# Finalize feature dimension.
|
| 262 |
+
d = self.backbone.pos_embed.shape[-1]
|
| 263 |
+
if num_classes is not None and 'module.head.mlp.mlp.0.weight' in checkpoint and \
|
| 264 |
+
checkpoint['module.head.mlp.mlp.0.weight'].shape[-1] == d:
|
| 265 |
+
self.cfg.flat_features = True
|
| 266 |
+
n_stack_feature = 1 if self.cfg.flat_features else (self.cfg.input_size[0] // self.cfg.patch_size[0])
|
| 267 |
+
self.cfg.feature_d = d * n_stack_feature # 768 if flat_features else 768*5=3840
|
| 268 |
+
# Create head.
|
| 269 |
+
if num_classes is not None:
|
| 270 |
+
self.head_norm = torch.nn.BatchNorm1d(self.cfg.feature_d, affine=False)
|
| 271 |
+
self.head = torch.nn.Linear(self.cfg.feature_d, num_classes)
|
| 272 |
+
trunc_normal_(self.head.weight, std=2e-5)
|
| 273 |
+
load_evar_head_parameters(checkpoint, self.head_norm, self.head)
|
| 274 |
+
# Option: freeze patch embedding ([2211.09359] How to Fine-Tune Vision Models with SGD)
|
| 275 |
+
if self.cfg.freeze_embed:
|
| 276 |
+
models_mae.set_requires_grad(self.backbone.patch_embed, False)
|
| 277 |
+
logging.info(' ** Freeze patch_embed **')
|
| 278 |
+
logging.info(self.backbone.patch_embed)
|
| 279 |
+
|
| 280 |
+
logging.info(f'Model input size: {self.cfg.input_size}')
|
| 281 |
+
logging.info(f'Using weights: {self.cfg.weight_file}')
|
| 282 |
+
logging.info(f'Feature dimension: {self.cfg.feature_d}')
|
| 283 |
+
logging.info(f'Norm stats: {self.cfg.mean}, {self.cfg.std}')
|
| 284 |
+
|
| 285 |
+
self.to_spec = get_to_melspec(self.cfg)
|
| 286 |
+
self.eval()
|
| 287 |
+
|
| 288 |
+
def to_log_mel_spec(self, batch_audio):
|
| 289 |
+
x = self.to_spec(batch_audio)
|
| 290 |
+
x = (x + torch.finfo().eps).log()
|
| 291 |
+
x = x.unsqueeze(1)
|
| 292 |
+
return x
|
| 293 |
+
|
| 294 |
+
def normalize_batch(self, x):
|
| 295 |
+
x = (x - self.cfg.mean) / self.cfg.std
|
| 296 |
+
return x
|
| 297 |
+
|
| 298 |
+
def to_normalized_feature(self, batch_audio):
|
| 299 |
+
x = self.to_log_mel_spec(batch_audio)
|
| 300 |
+
x = self.normalize_batch(x)
|
| 301 |
+
return x
|
| 302 |
+
|
| 303 |
+
def encode_lms(self, x, average_per_time_frame=False):
|
| 304 |
+
patch_fbins = self.backbone.grid_size()[0]
|
| 305 |
+
unit_frames = self.cfg.input_size[1]
|
| 306 |
+
patch_frames = self.backbone.patch_size()[1]
|
| 307 |
+
embed_d = self.backbone.patch_embed.proj.out_channels
|
| 308 |
+
n_chunk = (x.shape[-1] + unit_frames - 1) // unit_frames
|
| 309 |
+
pad_frames = (patch_frames - (x.shape[-1] % unit_frames % patch_frames)) % patch_frames
|
| 310 |
+
if pad_frames > 0:
|
| 311 |
+
x = torch.nn.functional.pad(x, (0, pad_frames))
|
| 312 |
+
|
| 313 |
+
embeddings = []
|
| 314 |
+
if self.cfg.flat_features:
|
| 315 |
+
# flatten all patch embeddings
|
| 316 |
+
for i in range(n_chunk):
|
| 317 |
+
emb = self.backbone.forward_encoder(x[..., i * unit_frames:(i + 1) * unit_frames])
|
| 318 |
+
emb = emb[..., 1:, :]
|
| 319 |
+
if average_per_time_frame:
|
| 320 |
+
emb = rearrange(emb, 'b (f t) d -> b t d f', f=patch_fbins, d=embed_d).mean(-1)
|
| 321 |
+
embeddings.append(emb)
|
| 322 |
+
else:
|
| 323 |
+
# stack embeddings along time frame
|
| 324 |
+
for i in range(n_chunk):
|
| 325 |
+
emb = self.backbone.forward_encoder(x[..., i * unit_frames:(i + 1) * unit_frames])
|
| 326 |
+
emb = emb[..., 1:, :]
|
| 327 |
+
emb = rearrange(emb, 'b (f t) d -> b t (f d)', f=patch_fbins, d=embed_d)
|
| 328 |
+
embeddings.append(emb)
|
| 329 |
+
# concatenate embedding chunks in the time axis
|
| 330 |
+
x = torch.cat(embeddings, axis=-2)
|
| 331 |
+
return x
|
| 332 |
+
|
| 333 |
+
def encode(self, batch_audio, average_per_time_frame=False):
|
| 334 |
+
x = self.to_normalized_feature(batch_audio)
|
| 335 |
+
return self.encode_lms(x, average_per_time_frame=average_per_time_frame)
|
| 336 |
+
|
| 337 |
+
def forward(self, batch_audio, average_per_time_frame=False):
|
| 338 |
+
x = self.encode(batch_audio, average_per_time_frame=average_per_time_frame)
|
| 339 |
+
if hasattr(self, 'head'):
|
| 340 |
+
x = x.mean(1) # B, D
|
| 341 |
+
x = self.head_norm(x.unsqueeze(-1)).squeeze(-1)
|
| 342 |
+
x = self.head(x)
|
| 343 |
+
return x
|
| 344 |
+
|
| 345 |
+
def forward_mel(self, batch_mel, average_per_time_frame=False):
|
| 346 |
+
x = self.encode_lms(batch_mel, average_per_time_frame=average_per_time_frame)
|
| 347 |
+
if hasattr(self, 'head'):
|
| 348 |
+
x = x.mean(1) # B, D
|
| 349 |
+
x = self.head_norm(x.unsqueeze(-1)).squeeze(-1)
|
| 350 |
+
x = self.head(x)
|
| 351 |
+
return x
|
| 352 |
+
|
| 353 |
+
def get_scene_embeddings(self, batch_audio):
|
| 354 |
+
x = self.encode(batch_audio)
|
| 355 |
+
x = torch.mean(x, dim=1)
|
| 356 |
+
return x
|
| 357 |
+
|
| 358 |
+
def get_timestamp_embeddings(self, batch_audio):
|
| 359 |
+
x = self.encode(batch_audio, average_per_time_frame=True)
|
| 360 |
+
ts = get_timestamps(self.cfg, batch_audio, x)
|
| 361 |
+
return x, ts
|
| 362 |
+
|
| 363 |
+
def forward_frames(self, batch_audio):
|
| 364 |
+
x, ts = self.get_timestamp_embeddings(batch_audio)
|
| 365 |
+
if hasattr(self, 'head'):
|
| 366 |
+
x = self.head_norm(x.transpose(-1, -2)).transpose(-2, -1)
|
| 367 |
+
x = self.head(x)
|
| 368 |
+
return x, ts
|
| 369 |
+
|
| 370 |
+
def encode_clap_audio(self, batch_audio):
|
| 371 |
+
audio_embeddings = self.forward(batch_audio)
|
| 372 |
+
audio_embeddings = audio_embeddings.mean(dim=-2)
|
| 373 |
+
audio_embeddings = self.backbone.audio_proj(audio_embeddings)
|
| 374 |
+
return audio_embeddings
|
| 375 |
+
|
| 376 |
+
def encode_clap_text(self, batch_text, truncate=False):
|
| 377 |
+
if not hasattr(self, 'text_encoder'):
|
| 378 |
+
self.text_encoder = GTETextEncoder()
|
| 379 |
+
text_embeddings = self.text_encoder(batch_text, truncate=truncate)
|
| 380 |
+
text_embeddings = self.backbone.text_proj(text_embeddings)
|
| 381 |
+
text_embeddings = text_embeddings.detach().cpu().to(torch.float)
|
| 382 |
+
return text_embeddings
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# For the CLAP models
|
| 386 |
+
|
| 387 |
+
class GTETextEncoder:
|
| 388 |
+
def __init__(self, clip_weight="thenlper/gte-base"):
|
| 389 |
+
from transformers import AutoTokenizer, AutoModel
|
| 390 |
+
import os
|
| 391 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true" # To suppress warnings.
|
| 392 |
+
|
| 393 |
+
self.tokenizer = AutoTokenizer.from_pretrained(clip_weight)
|
| 394 |
+
self.model = AutoModel.from_pretrained(clip_weight)
|
| 395 |
+
|
| 396 |
+
def __call__(self, texts, truncate=True, max_length=512):
|
| 397 |
+
def average_pool(last_hidden_states, attention_mask):
|
| 398 |
+
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 399 |
+
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
| 400 |
+
|
| 401 |
+
with torch.no_grad():
|
| 402 |
+
device = next(self.model.parameters()).device
|
| 403 |
+
batch_dict = self.tokenizer(texts, max_length=max_length, padding=True, truncation=truncate,
|
| 404 |
+
return_tensors='pt')
|
| 405 |
+
batch_dict['input_ids'] = batch_dict['input_ids'].to(device)
|
| 406 |
+
batch_dict['token_type_ids'] = batch_dict['token_type_ids'].to(device)
|
| 407 |
+
batch_dict['attention_mask'] = batch_dict['attention_mask'].to(device)
|
| 408 |
+
outputs = self.model.to(device)(**batch_dict)
|
| 409 |
+
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
| 410 |
+
return embeddings
|
models/prediction_wrapper.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.hub import download_url_to_file
|
| 6 |
+
|
| 7 |
+
from config import RESOURCES_FOLDER, CHECKPOINT_URLS
|
| 8 |
+
from models.seq_models import BidirectionalLSTM, BidirectionalGRU
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PredictionsWrapper(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
A wrapper module that adds an optional sequence model and classification heads on top of a transformer.
|
| 14 |
+
It implements equations (1), (2), and (3) in the paper.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
base_model (BaseModelWrapper): The base model (transformer) providing sequence embeddings
|
| 18 |
+
checkpoint (str, optional): checkpoint name for loading pre-trained weights. Default is None.
|
| 19 |
+
n_classes_strong (int): Number of classes for strong predictions. Default is 447.
|
| 20 |
+
n_classes_weak (int, optional): Number of classes for weak predictions. Default is None,
|
| 21 |
+
which sets it equal to n_classes_strong.
|
| 22 |
+
embed_dim (int, optional): Embedding dimension of the base model output. Default is 768.
|
| 23 |
+
seq_len (int, optional): Desired sequence length. Default is 250 (40 ms resolution).
|
| 24 |
+
seq_model_type (str, optional): Type of sequence model to use.
|
| 25 |
+
Default is None, which means no additional sequence model is used.
|
| 26 |
+
head_type (str, optional): Type of classification head. Choices are ["linear", "attention", "None"].
|
| 27 |
+
Default is "linear". "None" means that sequence embeddings are returned.
|
| 28 |
+
rnn_layers (int, optional): Number of RNN layers if seq_model_type is "rnn". Default is 2.
|
| 29 |
+
rnn_type (str, optional): Type of RNN to use. Choices are ["BiGRU", "BiLSTM"]. Default is "BiGRU".
|
| 30 |
+
rnn_dim (int, optional): Dimension of RNN hidden state if seq_model_type is "rnn". Default is 256.
|
| 31 |
+
rnn_dropout (float, optional): Dropout rate for RNN layers. Default is 0.0.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self,
|
| 35 |
+
base_model,
|
| 36 |
+
checkpoint=None,
|
| 37 |
+
n_classes_strong=447,
|
| 38 |
+
n_classes_weak=None,
|
| 39 |
+
embed_dim=768,
|
| 40 |
+
seq_len=250,
|
| 41 |
+
seq_model_type=None,
|
| 42 |
+
head_type="linear",
|
| 43 |
+
rnn_layers=2,
|
| 44 |
+
rnn_type="BiGRU",
|
| 45 |
+
rnn_dim=2048,
|
| 46 |
+
rnn_dropout=0.0
|
| 47 |
+
):
|
| 48 |
+
super(PredictionsWrapper, self).__init__()
|
| 49 |
+
self.model = base_model
|
| 50 |
+
self.seq_len = seq_len
|
| 51 |
+
self.embed_dim = embed_dim
|
| 52 |
+
self.n_classes_strong = n_classes_strong
|
| 53 |
+
self.n_classes_weak = n_classes_weak if n_classes_weak is not None else n_classes_strong
|
| 54 |
+
self.seq_model_type = seq_model_type
|
| 55 |
+
self.head_type = head_type
|
| 56 |
+
|
| 57 |
+
if self.seq_model_type == "rnn":
|
| 58 |
+
if rnn_type == "BiGRU":
|
| 59 |
+
self.seq_model = BidirectionalGRU(
|
| 60 |
+
n_in=self.embed_dim,
|
| 61 |
+
n_hidden=rnn_dim,
|
| 62 |
+
dropout=rnn_dropout,
|
| 63 |
+
num_layers=rnn_layers
|
| 64 |
+
)
|
| 65 |
+
elif rnn_type == "BiLSTM":
|
| 66 |
+
self.seq_model = BidirectionalLSTM(
|
| 67 |
+
nIn=self.embed_dim,
|
| 68 |
+
nHidden=rnn_dim,
|
| 69 |
+
nOut=rnn_dim * 2,
|
| 70 |
+
dropout=rnn_dropout,
|
| 71 |
+
num_layers=rnn_layers
|
| 72 |
+
)
|
| 73 |
+
num_features = rnn_dim * 2
|
| 74 |
+
elif self.seq_model_type is None:
|
| 75 |
+
self.seq_model = nn.Identity()
|
| 76 |
+
# no additional sequence model
|
| 77 |
+
num_features = self.embed_dim
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError(f"Unknown seq_model_type: {self.seq_model_type}")
|
| 80 |
+
|
| 81 |
+
if self.head_type == "attention":
|
| 82 |
+
assert self.n_classes_strong == self.n_classes_weak, "head_type=='attention' requires number of strong and " \
|
| 83 |
+
"weak classes to be the same!"
|
| 84 |
+
|
| 85 |
+
if self.head_type is not None:
|
| 86 |
+
self.strong_head = nn.Linear(num_features, self.n_classes_strong)
|
| 87 |
+
self.weak_head = nn.Linear(num_features, self.n_classes_weak)
|
| 88 |
+
if checkpoint is not None:
|
| 89 |
+
print("Loading pretrained checkpoint: ", checkpoint)
|
| 90 |
+
self.load_checkpoint(checkpoint)
|
| 91 |
+
|
| 92 |
+
def load_checkpoint(self, checkpoint):
|
| 93 |
+
ckpt_file = os.path.join(RESOURCES_FOLDER, checkpoint + ".pt")
|
| 94 |
+
if not os.path.exists(ckpt_file):
|
| 95 |
+
download_url_to_file(CHECKPOINT_URLS[checkpoint], ckpt_file)
|
| 96 |
+
state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
|
| 97 |
+
|
| 98 |
+
# compatibility with uniform wrapper structure we introduced for the public repo
|
| 99 |
+
if 'fpasst' in checkpoint:
|
| 100 |
+
state_dict = {("model.fpasst." + k[len("model."):] if k.startswith("model.")
|
| 101 |
+
else k): v for k, v in state_dict.items()}
|
| 102 |
+
elif 'M2D' in checkpoint:
|
| 103 |
+
state_dict = {("model.m2d." + k[len("model."):] if not k.startswith("model.m2d.") and k.startswith("model.")
|
| 104 |
+
else k): v for k, v in state_dict.items()}
|
| 105 |
+
elif 'BEATs' in checkpoint:
|
| 106 |
+
state_dict = {("model.beats." + k[len("model.model."):] if k.startswith("model.model")
|
| 107 |
+
else k): v for k, v in state_dict.items()}
|
| 108 |
+
elif 'ASIT' in checkpoint:
|
| 109 |
+
state_dict = {("model.asit." + k[len("model."):] if k.startswith("model.")
|
| 110 |
+
else k): v for k, v in state_dict.items()}
|
| 111 |
+
|
| 112 |
+
n_classes_weak_in_sd = state_dict['weak_head.bias'].shape[0] if 'weak_head.bias' in state_dict else -1
|
| 113 |
+
n_classes_strong_in_sd = state_dict['strong_head.bias'].shape[0] if 'strong_head.bias' in state_dict else -1
|
| 114 |
+
seq_model_in_sd = any(['seq_model.' in key for key in state_dict.keys()])
|
| 115 |
+
keys_to_remove = []
|
| 116 |
+
strict = True
|
| 117 |
+
expected_missing = 0
|
| 118 |
+
if self.head_type is None:
|
| 119 |
+
# remove all keys related to head
|
| 120 |
+
keys_to_remove.append('weak_head.bias')
|
| 121 |
+
keys_to_remove.append('weak_head.weight')
|
| 122 |
+
keys_to_remove.append('strong_head.bias')
|
| 123 |
+
keys_to_remove.append('strong_head.weight')
|
| 124 |
+
elif self.seq_model_type is not None and not seq_model_in_sd:
|
| 125 |
+
# we want to train a sequence model (e.g., rnn) on top of a
|
| 126 |
+
# pre-trained transformer (e.g., AS weak pretrained)
|
| 127 |
+
keys_to_remove.append('weak_head.bias')
|
| 128 |
+
keys_to_remove.append('weak_head.weight')
|
| 129 |
+
keys_to_remove.append('strong_head.bias')
|
| 130 |
+
keys_to_remove.append('strong_head.weight')
|
| 131 |
+
num_seq_model_keys = len([key for key in self.seq_model.state_dict()])
|
| 132 |
+
expected_missing = len(keys_to_remove) + num_seq_model_keys
|
| 133 |
+
strict = False
|
| 134 |
+
else:
|
| 135 |
+
# head type is not None
|
| 136 |
+
if n_classes_weak_in_sd != self.n_classes_weak:
|
| 137 |
+
# remove weak head from sd
|
| 138 |
+
keys_to_remove.append('weak_head.bias')
|
| 139 |
+
keys_to_remove.append('weak_head.weight')
|
| 140 |
+
strict = False
|
| 141 |
+
if n_classes_strong_in_sd != self.n_classes_strong:
|
| 142 |
+
# remove strong head from sd
|
| 143 |
+
keys_to_remove.append('strong_head.bias')
|
| 144 |
+
keys_to_remove.append('strong_head.weight')
|
| 145 |
+
strict = False
|
| 146 |
+
expected_missing = len(keys_to_remove)
|
| 147 |
+
|
| 148 |
+
# allow missing mel parameters for compatibility
|
| 149 |
+
num_mel_keys = len([key for key in self.state_dict() if 'mel_transform' in key])
|
| 150 |
+
if num_mel_keys > 0:
|
| 151 |
+
expected_missing += num_mel_keys
|
| 152 |
+
strict = False
|
| 153 |
+
|
| 154 |
+
state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove}
|
| 155 |
+
missing, unexpected = self.load_state_dict(state_dict, strict=strict)
|
| 156 |
+
assert len(missing) == expected_missing
|
| 157 |
+
assert len(unexpected) == 0
|
| 158 |
+
|
| 159 |
+
def separate_params(self):
|
| 160 |
+
if hasattr(self, "separate_params"):
|
| 161 |
+
return self.model.separate_params()
|
| 162 |
+
else:
|
| 163 |
+
raise NotImplementedError("The base model has no 'separate_params' method!'")
|
| 164 |
+
|
| 165 |
+
def has_separate_params(self):
|
| 166 |
+
return hasattr(self.model, "separate_params")
|
| 167 |
+
|
| 168 |
+
def mel_forward(self, x):
|
| 169 |
+
return self.model.mel_forward(x)
|
| 170 |
+
|
| 171 |
+
def forward(self, x):
|
| 172 |
+
# base model is expected to output a sequence (see Eq. (1) in paper)
|
| 173 |
+
# (batch size x sequence length x embedding dimension)
|
| 174 |
+
x = self.model(x)
|
| 175 |
+
|
| 176 |
+
# ATST: x.shape: batch size x 250 x 768
|
| 177 |
+
# PaSST: x.shape: batch size x 250 x 768
|
| 178 |
+
# ASiT: x.shape: batch size x 497 x 768
|
| 179 |
+
# M2D: x.shape: batch size x 62 x 3840
|
| 180 |
+
# BEATs: x.shape: batch size x 496 x 768
|
| 181 |
+
|
| 182 |
+
assert len(x.shape) == 3
|
| 183 |
+
|
| 184 |
+
if x.size(-2) > self.seq_len:
|
| 185 |
+
x = torch.nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), self.seq_len).transpose(1, 2)
|
| 186 |
+
elif x.size(-2) < self.seq_len:
|
| 187 |
+
x = torch.nn.functional.interpolate(x.transpose(1, 2), size=self.seq_len,
|
| 188 |
+
mode='linear').transpose(1, 2)
|
| 189 |
+
|
| 190 |
+
# Eq. (3) in the paper
|
| 191 |
+
# for teachers this is an RNN, for students it is nn.Identity
|
| 192 |
+
x = self.seq_model(x)
|
| 193 |
+
|
| 194 |
+
if self.head_type == "attention":
|
| 195 |
+
# attention head to obtain weak from strong predictions
|
| 196 |
+
# this is typically used for the DESED task, which requires both
|
| 197 |
+
# weak and strong predictions
|
| 198 |
+
strong = torch.sigmoid(self.strong_head(x))
|
| 199 |
+
sof = torch.softmax(self.weak_head(x), dim=-1)
|
| 200 |
+
sof = torch.clamp(sof, min=1e-7, max=1)
|
| 201 |
+
weak = (strong * sof).sum(1) / sof.sum(1)
|
| 202 |
+
return strong.transpose(1, 2), weak
|
| 203 |
+
elif self.head_type == "linear":
|
| 204 |
+
# simple linear layers as head (see Eq. (3) in the paper)
|
| 205 |
+
# on AudioSet strong, only strong predictions are used
|
| 206 |
+
# on AudioSet weak, only weak predictions are used
|
| 207 |
+
# why both? because we tried to simultaneously train on AudioSet weak and strong (less successful)
|
| 208 |
+
strong = self.strong_head(x)
|
| 209 |
+
weak = self.weak_head(x.mean(dim=1))
|
| 210 |
+
return strong.transpose(1, 2), weak
|
| 211 |
+
else:
|
| 212 |
+
# no head means the sequence is returned instead of strong and weak predictions
|
| 213 |
+
return x
|
models/seq_models.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BidirectionalGRU(nn.Module):
|
| 5 |
+
def __init__(self, n_in, n_hidden, dropout=0, num_layers=1):
|
| 6 |
+
super(BidirectionalGRU, self).__init__()
|
| 7 |
+
self.rnn = nn.GRU(
|
| 8 |
+
n_in,
|
| 9 |
+
n_hidden,
|
| 10 |
+
bidirectional=True,
|
| 11 |
+
dropout=dropout,
|
| 12 |
+
batch_first=True,
|
| 13 |
+
num_layers=num_layers,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
def forward(self, input_feat):
|
| 17 |
+
recurrent, _ = self.rnn(input_feat)
|
| 18 |
+
return recurrent
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BidirectionalLSTM(nn.Module):
|
| 22 |
+
def __init__(self, nIn, nHidden, nOut, dropout=0, num_layers=1):
|
| 23 |
+
super(BidirectionalLSTM, self).__init__()
|
| 24 |
+
self.rnn = nn.LSTM(
|
| 25 |
+
nIn,
|
| 26 |
+
nHidden,
|
| 27 |
+
bidirectional=True,
|
| 28 |
+
batch_first=True,
|
| 29 |
+
dropout=dropout,
|
| 30 |
+
num_layers=num_layers,
|
| 31 |
+
)
|
| 32 |
+
self.embedding = nn.Linear(nHidden * 2, nOut)
|
| 33 |
+
|
| 34 |
+
def forward(self, input_feat):
|
| 35 |
+
recurrent, _ = self.rnn(input_feat)
|
| 36 |
+
b, T, h = recurrent.size()
|
| 37 |
+
t_rec = recurrent.contiguous().view(b * T, h)
|
| 38 |
+
output = self.embedding(t_rec)
|
| 39 |
+
output = output.view(b, T, -1)
|
| 40 |
+
return output
|
models/transformer_wrapper.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BaseModelWrapper(ABC, nn.Module):
|
| 6 |
+
@abstractmethod
|
| 7 |
+
def mel_forward(self, x):
|
| 8 |
+
"""Process input waveform to mel spectrogram."""
|
| 9 |
+
pass
|
| 10 |
+
|
| 11 |
+
@abstractmethod
|
| 12 |
+
def forward(self, x):
|
| 13 |
+
"""Extract embedding sequence from mel spectrogram."""
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
@abstractmethod
|
| 17 |
+
def separate_params(self):
|
| 18 |
+
"""Separate model parameters into predefined groups for layer-wise learning rate decay."""
|
| 19 |
+
pass
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy<2
|
| 2 |
+
librosa
|
| 3 |
+
pandas
|
| 4 |
+
timm
|
| 5 |
+
nnAudio
|
| 6 |
+
av>=10.0.0
|
| 7 |
+
h5py>=3.8.0
|
| 8 |
+
jsonpickle>=3.0.1
|
| 9 |
+
hf_transfer>=0.1.4
|
| 10 |
+
hf-fastup>=0.0.5
|
| 11 |
+
datasets>=2.15.0
|
| 12 |
+
pytorch-lightning>=2.0.0
|
| 13 |
+
wandb
|
| 14 |
+
transformers
|
| 15 |
+
sed_scores_eval==0.0.3
|
| 16 |
+
intervaltree
|
| 17 |
+
more-itertools
|
resources/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
In this folder, we place all files that are automatically downloaded (such as model checkpoints).
|
resources/best_model_BEATs.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e610c0ce85b77d15cdba5d25e02618ae47eada299f0c3d77fd802e19316ed821
|
| 3 |
+
size 361619724
|
resources/eval_durations.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
resources/labelvocabulary.csv
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
idx,label
|
| 2 |
+
0,21
|
| 3 |
+
1,22
|
| 4 |
+
2,23
|
| 5 |
+
3,24
|
| 6 |
+
4,25
|
| 7 |
+
5,26
|
| 8 |
+
6,27
|
| 9 |
+
7,28
|
| 10 |
+
8,29
|
| 11 |
+
9,30
|
| 12 |
+
10,31
|
| 13 |
+
11,32
|
| 14 |
+
12,33
|
| 15 |
+
13,34
|
| 16 |
+
14,35
|
| 17 |
+
15,36
|
| 18 |
+
16,37
|
| 19 |
+
17,38
|
| 20 |
+
18,39
|
| 21 |
+
19,40
|
| 22 |
+
20,41
|
| 23 |
+
21,42
|
| 24 |
+
22,43
|
| 25 |
+
23,44
|
| 26 |
+
24,45
|
| 27 |
+
25,46
|
| 28 |
+
26,47
|
| 29 |
+
27,48
|
| 30 |
+
28,49
|
| 31 |
+
29,50
|
| 32 |
+
30,51
|
| 33 |
+
31,52
|
| 34 |
+
32,53
|
| 35 |
+
33,54
|
| 36 |
+
34,55
|
| 37 |
+
35,56
|
| 38 |
+
36,57
|
| 39 |
+
37,58
|
| 40 |
+
38,59
|
| 41 |
+
39,60
|
| 42 |
+
40,61
|
| 43 |
+
41,62
|
| 44 |
+
42,63
|
| 45 |
+
43,64
|
| 46 |
+
44,65
|
| 47 |
+
45,66
|
| 48 |
+
46,67
|
| 49 |
+
47,68
|
| 50 |
+
48,69
|
| 51 |
+
49,70
|
| 52 |
+
50,71
|
| 53 |
+
51,72
|
| 54 |
+
52,73
|
| 55 |
+
53,74
|
| 56 |
+
54,75
|
| 57 |
+
55,76
|
| 58 |
+
56,77
|
| 59 |
+
57,78
|
| 60 |
+
58,79
|
| 61 |
+
59,80
|
| 62 |
+
60,81
|
| 63 |
+
61,82
|
| 64 |
+
62,83
|
| 65 |
+
63,84
|
| 66 |
+
64,85
|
| 67 |
+
65,86
|
| 68 |
+
66,87
|
| 69 |
+
67,88
|
| 70 |
+
68,89
|
| 71 |
+
69,90
|
| 72 |
+
70,91
|
| 73 |
+
71,92
|
| 74 |
+
72,93
|
| 75 |
+
73,94
|
| 76 |
+
74,95
|
| 77 |
+
75,96
|
| 78 |
+
76,97
|
| 79 |
+
77,98
|
| 80 |
+
78,99
|
| 81 |
+
79,100
|
| 82 |
+
80,101
|
| 83 |
+
81,102
|
| 84 |
+
82,103
|
| 85 |
+
83,104
|
| 86 |
+
84,105
|
| 87 |
+
85,106
|
| 88 |
+
86,107
|
| 89 |
+
87,108
|