Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +33 -0
- LICENSE +21 -0
- README.md +88 -0
- UniSpeech/.gitignore +3 -0
- UniSpeech/ILS-SSL/README.md +63 -0
- UniSpeech/LICENSE +74 -0
- UniSpeech/README.md +148 -0
- UniSpeech/SECURITY.md +41 -0
- UniSpeech/UniSpeech-SAT/README.md +76 -0
- UniSpeech/UniSpeech-SAT/UniSpeech_SAT_SUPERB_Results.png +3 -0
- UniSpeech/UniSpeech/README.md +56 -0
- UniSpeech/WavLM/README.md +125 -0
- UniSpeech/WavLM/WavLM.py +743 -0
- UniSpeech/WavLM/WavLM_ASR.PNG +3 -0
- UniSpeech/WavLM/WavLM_SUPERB_Leaderboard.png +3 -0
- UniSpeech/WavLM/WavLM_SUPERB_Results.png +3 -0
- UniSpeech/WavLM/modules.py +827 -0
- UniSpeech/azure-pipelines.yml +47 -0
- UniSpeech/downstreams/speaker_diarization/README.md +34 -0
- UniSpeech/downstreams/speaker_diarization/config/infer_est_nspk1.yaml +31 -0
- UniSpeech/downstreams/speaker_diarization/config/unispeech_sat.th +0 -0
- UniSpeech/downstreams/speaker_diarization/diarization.py +321 -0
- UniSpeech/downstreams/speaker_diarization/models/models.py +391 -0
- UniSpeech/downstreams/speaker_diarization/models/transformer.py +147 -0
- UniSpeech/downstreams/speaker_diarization/models/utils.py +78 -0
- UniSpeech/downstreams/speaker_diarization/requirements.txt +136 -0
- UniSpeech/downstreams/speaker_diarization/tmp/mix_0000496.wav +1 -0
- UniSpeech/downstreams/speaker_diarization/utils/dataset.py +484 -0
- UniSpeech/downstreams/speaker_diarization/utils/kaldi_data.py +162 -0
- UniSpeech/downstreams/speaker_diarization/utils/parse_options.sh +97 -0
- UniSpeech/downstreams/speaker_diarization/utils/utils.py +189 -0
- UniSpeech/downstreams/speaker_verification/README.md +49 -0
- UniSpeech/downstreams/speaker_verification/config/unispeech_sat.th +0 -0
- UniSpeech/downstreams/speaker_verification/models/__init__.py +0 -0
- UniSpeech/downstreams/speaker_verification/models/ecapa_tdnn.py +301 -0
- UniSpeech/downstreams/speaker_verification/models/utils.py +78 -0
- UniSpeech/downstreams/speaker_verification/requirements.txt +85 -0
- UniSpeech/downstreams/speaker_verification/verification.py +67 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav +3 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav +3 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav +3 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav +3 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav +3 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav +3 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav +3 -0
- UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav +3 -0
- UniSpeech/src/CODE_OF_CONDUCT.md +77 -0
- UniSpeech/src/CONTRIBUTING.md +28 -0
- UniSpeech/src/LICENSE +21 -0
- UniSpeech/src/config/config.yaml +111 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
UniSpeech/UniSpeech-SAT/UniSpeech_SAT_SUPERB_Results.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
UniSpeech/WavLM/WavLM_ASR.PNG filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
UniSpeech/WavLM/WavLM_SUPERB_Leaderboard.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
UniSpeech/WavLM/WavLM_SUPERB_Results.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
UniSpeech/src/examples/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
UniSpeech/src/examples/unispeech/data/en/pretrain_1350_16k.id filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
UniSpeech/src/examples/unispeech/data/en/pretrain_1350_16k.tsv filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
UniSpeech/src/examples/unispeech/data/es/pretrain_168_16k.phone filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
UniSpeech/src/examples/unispeech/data/es/pretrain_168_16k_sep.id filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
UniSpeech/src/examples/unispeech/data/es/pretrain_168_16k_share.id filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k.phone filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k.tsv filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k_sep.id filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k_share.id filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
UniSpeech/src/examples/unispeech/data/it/pretrain_90_16k_sep.id filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
UniSpeech/src/fairseq/data/data_utils_fast.cpython-36m-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
UniSpeech/src/fairseq/data/data_utils_fast.cpython-37m-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
audio_high_quality_lmao.txt filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 YE Zhen
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[](https://arxiv.org/abs/2502.04128)
|
| 2 |
+
|
| 3 |
+
# X-Codec-2.0
|
| 4 |
+
Paper: LLaSA: Scaling Train-time and Inference-time Compute for LLaMA-based Speech Synthesis
|
| 5 |
+
|
| 6 |
+
**Update (2025-02-13):** Add [Llasa finetune instruction](https://github.com/zhenye234/LLaSA_training/tree/main/finetune).
|
| 7 |
+
|
| 8 |
+
**Update (2025-02-07):** Our paper has been released!
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
## Directly used on Hugging Face
|
| 12 |
+
|
| 13 |
+
**Codec**: [xcodec2](https://huggingface.co/HKUST-Audio/xcodec2) (Use `xcodec2==0.1.5` for codec inference and llasa fine-tuning. I’ve removed unnecessary dependencies, and it works fine in my testing. However, I’m not sure if other problems may arise. If you prefer more stability, I recommend using `xcodec2==0.1.3` which accurately aligns during my codec training.)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
**Llasa-collections**: [Llasa-collections](https://huggingface.co/collections/HKUSTAudio/llasa-679b87dbd06ac556cc0e0f44)
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- **Single Vector Quantization**
|
| 21 |
+
- 65536 Codebook Size using Finite Scalar Quantization achieving 99% codebook usage. ( comparable to text tokenizers, LLaMA3 128256)
|
| 22 |
+
- 50x1 Tokens per Second
|
| 23 |
+
|
| 24 |
+
- **Multilingual Speech Semantic Support**
|
| 25 |
+
- Uses Wav2Vec2-BERT, a semantic encoder pre-trained on 4.5M hours of unlabeled audio data covering more than 143 languages.
|
| 26 |
+
- Codec trained on 150k hours of multilingual speech data, including Emilia (En/Zh/De/Fr/Ja/Ko) and MLS (En/Fr/De/Nl/Es/It/Pt/Pl).
|
| 27 |
+
|
| 28 |
+
- **High-Quality Speech Reconstruction**
|
| 29 |
+
- Transformer + Vocos Decoder
|
| 30 |
+
- BigCodec encoder
|
| 31 |
+
- Spec discriminator with FFT sizes {78, 126, 206, 334, 542, 876, 1418, 2296} tailored for transformer decoder. [Details here](https://openreview.net/pdf?id=4YpMrGfldX)
|
| 32 |
+
- Achieving UTMOS 4.13 WER 2.47 (hubert-large-ls960-ft) sim 0.82 (wavlm_large_finetune) stoi 0.92 pesq-nb 3.05 pesq-wb 2.44 on librispeech-test-clean reconstruction (gt: WER 1.96 UTMOS 4.09)
|
| 33 |
+
- Only for 16kHz speech
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
## Commandline Usage
|
| 37 |
+
## Setup
|
| 38 |
+
Code is tested on `python3.9`
|
| 39 |
+
|
| 40 |
+
Please follow the following steps to setup your environment
|
| 41 |
+
1. Clone this repo
|
| 42 |
+
2. conda create --name xcodec2 python=3.9
|
| 43 |
+
3. conda activate xcodec2
|
| 44 |
+
2. `pip install -r requirements.txt`
|
| 45 |
+
3. [Download the pretrained checkpoint here](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
## Inference
|
| 49 |
+
```bash
|
| 50 |
+
python inference.py
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Train
|
| 54 |
+
To train a XCodec2, firstly you have to prepare your data
|
| 55 |
+
|
| 56 |
+
1. Make a file list by:
|
| 57 |
+
```bash
|
| 58 |
+
python get_tsv.py
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
2. Train a X-Codec-2.0 with the default setting by:
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
python train.py log_dir=/path/to/log_dir
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Large-scale training, Batch inference and large-scale code extracting:
|
| 68 |
+
|
| 69 |
+
Batch inference
|
| 70 |
+
```bash
|
| 71 |
+
python inference_save_code.py
|
| 72 |
+
```
|
| 73 |
+
Training
|
| 74 |
+
```bash
|
| 75 |
+
Sbatch train_slurm.sh
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
Code extracting
|
| 79 |
+
```bash
|
| 80 |
+
Sbatch large_scale_save_code.sh
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
Code will save in output folder with the same subfolder structure for audio file.
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
## Acknowledgement
|
| 88 |
+
I would like to extend a special thanks to authors of BigCodec, since our code base is mainly borrowed from [BigCodec](https://github.com/Aria-K-Alethia/BigCodec).
|
UniSpeech/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.idea/
|
| 2 |
+
|
| 3 |
+
__pycache__/
|
UniSpeech/ILS-SSL/README.md
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# ILS-SSL
|
| 3 |
+
|
| 4 |
+
> [**ILS-SSL**](https://arxiv.org/pdf/2112.08778.pdf): Self-Supervised Learning for Speech Recognition with Intermediate Layer Supervision
|
| 5 |
+
|
| 6 |
+
The data preparation and pre-training for the first iteration follow the same pipeline as Hubert. We give example scripts for ILS-Hubert pre-training and fine-tuning in src/examples/hubert/scripts
|
| 7 |
+
|
| 8 |
+
## Pre-Trained and Fine-tuned Models
|
| 9 |
+
Model | Pretraining Dataset | Finetuning Dataset | Model
|
| 10 |
+
|---|---|---|---
|
| 11 |
+
ILS-Base | 960h LibriSpeech | - | [Download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/el_hubert_4_12/checkpoint_best.pt?st=2022-01-04T08%3A05%3A24Z&se=2024-01-05T08%3A05%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=JI8ZOgBhrrKUY4DE2ommnKpyAUuX6OrHfWgdjAT2Xnc%3D)
|
| 12 |
+
ILS-Large | 60k hrs Libri-Light | - | [Download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ils_hubert_large/checkpoint_fixed.pt?st=2022-01-04T08%3A24%3A37Z&se=2025-01-05T08%3A24%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=Dv6svAaI7Td%2BZWUTjTFkhChFbpnAAU6xKNjPbPQnIKM%3D)
|
| 13 |
+
ILS-Large | 60k hrs Libri-Light | 960h LibriSpeech | [Download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ils_hubert_large/checkpoint_ft.pt?st=2022-01-04T08%3A40%3A17Z&se=2025-01-05T08%3A40%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=GKIe%2F1kz%2F1fjGTsQsakJy68jlsFDbKmIVYjH61dhrwA%3D)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## Results on Librispeech
|
| 17 |
+
Base Model | Finetuning set| LM | test-clean | test-other
|
| 18 |
+
|---|---|---|---|---
|
| 19 |
+
wav2vec2.0 | 1 hour | None | 24.5 | 29.7
|
| 20 |
+
Hubert | 1 hour | None| 20.9 | 27.5
|
| 21 |
+
ILS-SSL | 1 hour | None | 17.9 | 23.1
|
| 22 |
+
wav2vec2.0 | 1 hour | 4-gram | 5.5 | 11.3
|
| 23 |
+
Hubert | 1 hour | 4-gram | 6.1 | 11.3
|
| 24 |
+
ILS-SSL | 1 hour | 4-gram | 5.4 | 10.2
|
| 25 |
+
wav2vec2.0 | 10 hour | None | 11.1 | 17.6
|
| 26 |
+
Hubert | 10 hour | None| 10.1 | 16.8
|
| 27 |
+
ILS-SSL | 10 hour | None | 8.3 | 13.6
|
| 28 |
+
wav2vec2.0 | 10 hour | 4-gram | 4.3 | 9.5
|
| 29 |
+
Hubert | 10 hour | 4-gram | 4.3 | 9.4
|
| 30 |
+
ILS-SSL | 10 hour | 4-gram | 3.8 | 8.1
|
| 31 |
+
wav2vec2.0 | 100 hour | None | 6.1 | 13.3
|
| 32 |
+
Hubert | 100 hour | None| 6.3 | 13.2
|
| 33 |
+
ILS-SSL | 100 hour | None | 4.7 | 10.1
|
| 34 |
+
wav2vec2.0 | 100 hour | 4-gram | 3.4| 8.0
|
| 35 |
+
Hubert | 100 hour | 4-gram | 3.4 | 8.1
|
| 36 |
+
ILS-SSL | 100 hour | 4-gram | 3.0 | 6.9
|
| 37 |
+
|
| 38 |
+
Large Model | Finetuning set| LM | test-clean | test-other
|
| 39 |
+
|---|---|---|---|---
|
| 40 |
+
wav2vec2.0 | 1 hour | None | 17.2 | 20.3
|
| 41 |
+
Hubert | 1 hour | None| 17.4 | 20.3
|
| 42 |
+
ILS-SSL | 1 hour | None | 14.3 | 16.9
|
| 43 |
+
wav2vec2.0 | 1 hour | Transf | 2.9 | 5.8
|
| 44 |
+
Hubert | 1 hour | Transf | 2.9 | 5.4
|
| 45 |
+
ILS-SSL | 1 hour | Transf | 2.8 | 5.3
|
| 46 |
+
wav2vec2.0 | 10 hour | None | 6.3 | 10.0
|
| 47 |
+
Hubert | 10 hour | None | 6.2 | 9.6
|
| 48 |
+
ILS-SSL | 10 hour | None | 6.1 | 9.1
|
| 49 |
+
wav2vec2.0 | 10 hour | Transf | 2.6 | 4.9
|
| 50 |
+
Hubert | 10 hour | Transf | 2.4 | 4.6
|
| 51 |
+
ILS-SSL | 10 hour | Transf | 2.5 | 4.5
|
| 52 |
+
wav2vec2.0 | 100 hour | None | 3.1 | 6.3
|
| 53 |
+
Hubert | 100 hour | None| 2.9 | 6.0
|
| 54 |
+
ILS-SSL | 100 hour | None | 2.9 | 5.8
|
| 55 |
+
wav2vec2.0 | 100 hour | Transf | 2.0 | 4.0
|
| 56 |
+
Hubert | 100 hour | Transf | 2.1 | 3.9
|
| 57 |
+
ILS-SSL | 100 hour | Transf | 2.0 | 4.0
|
| 58 |
+
wav2vec2.0 | 960 hour | None | 2.2 | 4.5
|
| 59 |
+
Hubert | 960 hour | None | 2.1 | 4.3
|
| 60 |
+
ILS-SSL | 960 hour | None | 1.9 | 3.8
|
| 61 |
+
wav2vec2.0 | 960 hour | Transf | 1.8 | 3.3
|
| 62 |
+
Hubert | 960 hour | Transf | 1.9 | 3.3
|
| 63 |
+
ILS-SSL | 960 hour | Transf | 1.8 | 3.2
|
UniSpeech/LICENSE
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Attribution-ShareAlike 3.0 Unported
|
| 2 |
+
|
| 3 |
+
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS LICENSE DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE INFORMATION PROVIDED, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM ITS USE.
|
| 4 |
+
|
| 5 |
+
License
|
| 6 |
+
|
| 7 |
+
THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED.
|
| 8 |
+
|
| 9 |
+
BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
|
| 10 |
+
|
| 11 |
+
1. Definitions
|
| 12 |
+
|
| 13 |
+
"Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License.
|
| 14 |
+
"Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License.
|
| 15 |
+
"Creative Commons Compatible License" means a license that is listed at https://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License.
|
| 16 |
+
"Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership.
|
| 17 |
+
"License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike.
|
| 18 |
+
"Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License.
|
| 19 |
+
"Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast.
|
| 20 |
+
"Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work.
|
| 21 |
+
"You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
|
| 22 |
+
"Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images.
|
| 23 |
+
"Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium.
|
| 24 |
+
|
| 25 |
+
2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws.
|
| 26 |
+
|
| 27 |
+
3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
|
| 28 |
+
|
| 29 |
+
to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections;
|
| 30 |
+
to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified.";
|
| 31 |
+
to Distribute and Publicly Perform the Work including as incorporated in Collections; and,
|
| 32 |
+
to Distribute and Publicly Perform Adaptations.
|
| 33 |
+
|
| 34 |
+
For the avoidance of doubt:
|
| 35 |
+
Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License;
|
| 36 |
+
Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and,
|
| 37 |
+
Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License.
|
| 38 |
+
|
| 39 |
+
The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved.
|
| 40 |
+
|
| 41 |
+
4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
|
| 42 |
+
|
| 43 |
+
You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(c), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(c), as requested.
|
| 44 |
+
You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US)); (iv) a Creative Commons Compatible License. If you license the Adaptation under one of the licenses mentioned in (iv), you must comply with the terms of that license. If you license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), you must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License.
|
| 45 |
+
If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv) , consistent with Ssection 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties.
|
| 46 |
+
Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise.
|
| 47 |
+
|
| 48 |
+
5. Representations, Warranties and Disclaimer
|
| 49 |
+
|
| 50 |
+
UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU.
|
| 51 |
+
|
| 52 |
+
6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 53 |
+
|
| 54 |
+
7. Termination
|
| 55 |
+
|
| 56 |
+
This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
|
| 57 |
+
Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
|
| 58 |
+
|
| 59 |
+
8. Miscellaneous
|
| 60 |
+
|
| 61 |
+
Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
|
| 62 |
+
Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
|
| 63 |
+
If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
|
| 64 |
+
No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
|
| 65 |
+
This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
|
| 66 |
+
The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law.
|
| 67 |
+
|
| 68 |
+
Creative Commons Notice
|
| 69 |
+
|
| 70 |
+
Creative Commons is not a party to this License, and makes no warranty whatsoever in connection with the Work. Creative Commons will not be liable to You or any party on any legal theory for any damages whatsoever, including without limitation any general, special, incidental or consequential damages arising in connection to this license. Notwithstanding the foregoing two (2) sentences, if Creative Commons has expressly identified itself as the Licensor hereunder, it shall have all rights and obligations of Licensor.
|
| 71 |
+
|
| 72 |
+
Except for the limited purpose of indicating to the public that the Work is licensed under the CCPL, Creative Commons does not authorize the use by either party of the trademark "Creative Commons" or any related trademark or logo of Creative Commons without the prior written consent of Creative Commons. Any permitted use will be in compliance with Creative Commons' then-current trademark usage guidelines, as may be published on its website or otherwise made available upon request from time to time. For the avoidance of doubt, this trademark restriction does not form part of the License.
|
| 73 |
+
|
| 74 |
+
Creative Commons may be contacted at https://creativecommons.org/.
|
UniSpeech/README.md
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# UniSpeech
|
| 2 |
+
|
| 3 |
+
<!--**Pre-trained models for speech related tasks**-->
|
| 4 |
+
|
| 5 |
+
The family of UniSpeech:
|
| 6 |
+
> [**WavLM**](https://arxiv.org/pdf/2110.13900.pdf) (```arXiv```): **WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing**
|
| 7 |
+
|
| 8 |
+
> [**UniSpeech**](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech) (```ICML 2021```): **Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR**
|
| 9 |
+
|
| 10 |
+
> [**UniSpeech-SAT**](https://arxiv.org/pdf/2110.05752.pdf) (```ICASSP 2022 Submission```): **Universal Speech Representation Learning with Speaker Aware Pre-Training**
|
| 11 |
+
|
| 12 |
+
> [**ILS-SSL**](https://arxiv.org/pdf/2112.08778.pdf) (```ICASSP 2022 Submission```): **Self-Supervised Learning for Speech Recognition with Intermediate Layer Supervision**
|
| 13 |
+
|
| 14 |
+
Model introductions, evaluation results, and model inference instructions are located in their corresponding folders. The source code is here [https://github.com/microsoft/UniSpeech/tree/main/src].
|
| 15 |
+
|
| 16 |
+
## Update
|
| 17 |
+
- [HuggingFace Integration] Dec 23, 2021: [**WavLM**](https://huggingface.co/models?other=wavlm) models are on [HuggingFace](https://huggingface.co/models?other=wavlm) .
|
| 18 |
+
- [HuggingFace Integration] Octorber 26, 2021: [**UniSpeech-SAT**](https://huggingface.co/microsoft/unispeech-sat-large) models are on [HuggingFace](https://huggingface.co/models?other=unispeech-sat) .
|
| 19 |
+
- [Model Release] Octorber 13, 2021: [**UniSpeech-SAT**](https://arxiv.org/pdf/2110.05752.pdf) models are releaseed.
|
| 20 |
+
- [HuggingFace Integration] Octorber 11, 2021: [**UniSpeech**](https://huggingface.co/microsoft/unispeech-large-1500h-cv) models are on [HuggingFace](https://huggingface.co/models?other=unispeech) .
|
| 21 |
+
- [Model Release] June, 2021: [**UniSpeech v1**](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech) models are released.
|
| 22 |
+
## Pre-trained models
|
| 23 |
+
We strongly suggest using our UniSpeech-SAT model for speaker related tasks, since it shows very powerful performance on various speaker related benchmarks.
|
| 24 |
+
Model | Pretraining Dataset | Finetuning Dataset | Model
|
| 25 |
+
|---|---|---|-----
|
| 26 |
+
UniSpeech Large EN | [Labeled: 1350 hrs en](https://commonvoice.mozilla.org/) | - | [download](https://releasemodel.blob.core.windows.net/models/CommonVoicePretrainedModel/CommonVoiceEnglishPretrainedModel/checkpoint_best.pt?sv=2019-12-12&st=2021-07-14T09%3A00%3A07Z&se=2022-07-15T09%3A00%3A00Z&sr=b&sp=r&sig=5sxvEwVRoGtkazNQYkOuFLlPYau8nl5Ng%2FfRJa0Vnc4%3D)
|
| 27 |
+
UniSpeech Large Multilingual | [Labeled: 1350 hrs en + 353 hrs fr + 168 hrs es + 90 hrs it](https://commonvoice.mozilla.org/) | - | [download](https://releasemodel.blob.core.windows.net/models/CommonVoicePretrainedModel/CommonVoiceMultilingualPretrainedModel/checkpoint_best.pt?sv=2019-12-12&st=2021-07-14T09%3A00%3A39Z&se=2022-07-15T09%3A00%3A00Z&sr=b&sp=r&sig=y%2Fd3rqtbyqW0ZCwR7Czho5any90khA%2Ft3w9PTZ6N9vU%3D)
|
| 28 |
+
Unispeech Large+ | [Labeled: 1350 hrs en, Unlabeled: 353 hrs fr](https://commonvoice.mozilla.org/) | - | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_fr353.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A44%3A54Z&se=2023-10-26T06%3A44%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=7tYuYMxVFfM2Vgi%2BoqUh%2ByJXD4hSuoafHgBP5VZApw0%3D)
|
| 29 |
+
UniSpeech Large+ | [Labeld: 1350 hrs en, Unlabeled: 168 hrs es](https://commonvoice.mozilla.org/) | - | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_es168.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A39%3A37Z&se=2023-10-26T06%3A39%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=T2B5%2BlOI6v64TNdLSe9rdp3R%2B9Q2E35taUOigGW0nsQ%3D)
|
| 30 |
+
UniSpeech Large+ | [Labeled: 1350 hrs en, Unlabeld: 90 hrs it](https://commonvoice.mozilla.org/) | -| [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_it90.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A52%3A08Z&se=2023-10-26T06%3A52%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=kXsSJXK9r8UEYlUr2LaJxtPf8m9J2G23MfG725k2DBk%3D)
|
| 31 |
+
UniSpeech Large Multilingual | [Labeled: 1350 hrs en + 353 hrs fr + 168 hrs es + 90 hrs it, Unlabeled: 17 hrs ky](https://commonvoice.mozilla.org/) | - | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_ky17.large.many2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A53%3A00Z&se=2022-10-26T06%3A53%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=oCQecalXzC5daaurLLJGQdFNtfYwsBM6pNQrDAsf5i0%3D)
|
| 32 |
+
UniSpeech Large+ | [Labeled: 1350 hrs en, Unlabeled: 353 hrs fr](https://commonvoice.mozilla.org/) | 1 hr fr | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ft_fr-pt_fr353.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A27%3A53Z&se=2023-10-26T06%3A27%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=9vEa3xqzWu7SYkACn9TQqDtcm%2BKmUcOHhabjbjZuPys%3D)
|
| 33 |
+
UniSpeech Large+ | [Labeld: 1350 hrs en, Unlabeled: 168 hrs es](https://commonvoice.mozilla.org/) | 1 hr es | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ft_es-pt_es168.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A21%3A34Z&se=2024-10-26T06%3A21%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=G%2B0RddgOh653UzXG95Ljuwv7aG3tu9gXtPXn1ixCiug%3D)
|
| 34 |
+
UniSpeech Large+ | [Labeled: 1350 hrs en, Unlabeld: 90 hrs it](https://commonvoice.mozilla.org/) | 1 hr it | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ft_it-pt_it90.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A36%3A17Z&se=2023-10-26T06%3A36%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=e1WD9uOCo9sCAdH%2FPZQ4wCD30aCDpZvvu43kJrqq2HE%3D)
|
| 35 |
+
UniSpeech Large Multilingual | [Labeled: 1350 hrs en + 353 hrs fr + 168 hrs es + 90 hrs it, Unlabeled: 17 hrs ky](https://commonvoice.mozilla.org/) | 1 hr ky | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_ky17.large.many2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A54%3A04Z&se=2023-10-26T06%3A54%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=2K3VjMcsbKfBkLVyDlqGhVpIX%2B2ZcA5DTlMhjdkXo3g%3D)
|
| 36 |
+
UniSpeech-SAT Base | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [download](https://valle.blob.core.windows.net/share/unispeech-sat/unispeech_repo/UniSpeech-SAT-Base.pt?sv=2021-10-04&st=2024-01-30T06%3A26%3A06Z&se=2094-01-31T06%3A26%3A00Z&sr=b&sp=r&sig=Ts8al%2FPc%2BksI%2BY4tKvDVZhmyw02c9pMhFDxLrPntSd0%3D)
|
| 37 |
+
UniSpeech-SAT Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main) | - | [download](https://valle.blob.core.windows.net/share/unispeech-sat/unispeech_repo/UniSpeech-SAT-Base+.pt?sv=2021-10-04&st=2024-01-30T06%3A26%3A25Z&se=2094-01-31T06%3A26%3A00Z&sr=b&sp=r&sig=m6XAIXsC4rVNDW%2FFwXPNKjX2A%2BV9zBwmWAV93vAXcvc%3D)
|
| 38 |
+
UniSpeech-SAT Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main) | - | [download](https://valle.blob.core.windows.net/share/unispeech-sat/unispeech_repo/UniSpeech-SAT-Large.pt?sv=2021-10-04&st=2024-01-30T06%3A26%3A43Z&se=2094-01-31T06%3A26%3A00Z&sr=b&sp=r&sig=TJXNIfJzB%2Bsja3uh2xbCUxdbvp8gQP0zzlmZvU2si1I%3D)
|
| 39 |
+
WavLM Base | [960 hrs LibriSpeech](http://www.openslr.org/12)| - | [download](https://valle.blob.core.windows.net/share/wavlm/unispeech_repo/WavLM-Base.pt?sv=2021-10-04&st=2024-01-30T06%3A27%3A08Z&se=2094-01-31T06%3A27%3A00Z&sr=b&sp=r&sig=ThlNPycn578KFcON1NTJ7hzkpNLZR%2B3D4ImTgXQR%2B9E%3D)
|
| 40 |
+
WavLM Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [download](https://valle.blob.core.windows.net/share/wavlm/unispeech_repo/WavLM-Base+.pt?sv=2021-10-04&st=2024-01-30T06%3A27%3A23Z&se=2094-01-31T06%3A27%3A00Z&sr=b&sp=r&sig=6XiFiDMiKNRLRYzNNAL9UWm0dAAFuRweRLFZ2h9IYzg%3D)
|
| 41 |
+
WavLM Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [download](https://valle.blob.core.windows.net/share/wavlm/unispeech_repo/WavLM-Large.pt?sv=2021-10-04&st=2024-01-30T06%3A27%3A39Z&se=2094-01-31T06%3A27%3A00Z&sr=b&sp=r&sig=eIFLuFlZxmBR7a642KnnLen7yoSzv465iLHLKokO7VM%3D)
|
| 42 |
+
|
| 43 |
+
## Universal Representation Evaluation on SUPERB
|
| 44 |
+

|
| 45 |
+
|
| 46 |
+
## Downstream Task Performance
|
| 47 |
+
We also evaluate our models on typical speaker related benchmarks.
|
| 48 |
+
### Speaker Verification
|
| 49 |
+
Finetune the model with VoxCeleb2 dev data, and evaluate it on the [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/#:~:text=VoxCeleb%20is%20an%20audio%2Dvisual,interview%20videos%20uploaded%20to%20YouTube)
|
| 50 |
+
| Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H |
|
| 51 |
+
| ------------- |------------- | ---------- | ---------- | ---------- |
|
| 52 |
+
| ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 |
|
| 53 |
+
| HuBERT large | Yes| 0.888 |0.912| 1.853 |
|
| 54 |
+
| Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895|
|
| 55 |
+
| UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669|
|
| 56 |
+
| WavLM large | Yes | 0.59 | 0.65| 1.328|
|
| 57 |
+
| WavLM large | No | 0.505 | 0.579| 1.176|
|
| 58 |
+
|+Large Margin Finetune and Score Calibration|
|
| 59 |
+
| HuBERT large | No| 0.585| 0.654 |1.342|
|
| 60 |
+
| Wav2Vec2.0 (XLSR) | No| 0.564| 0.605 |1.23|
|
| 61 |
+
| UniSpeech-SAT large | No | 0.564 | 0.561| 1.23 |
|
| 62 |
+
| **WavLM large (New)** | No | **0.33** | **0.477**| **0.984** |
|
| 63 |
+
|
| 64 |
+
[Large-scale Self-Supervised Speech Representation Learning for Automatic Speaker Verification](https://arxiv.org/pdf/2110.05777.pdf)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
### Speech Separation
|
| 69 |
+
|
| 70 |
+
Evaluation on [LibriCSS](https://github.com/chenzhuo1011/libri_css)
|
| 71 |
+
| Model |0S | 0L | OV10 | OV20 |OV30 |OV40 |
|
| 72 |
+
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
|
| 73 |
+
| [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6|
|
| 74 |
+
| UniSpeech-SAT base | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5|
|
| 75 |
+
| UniSpeech-SAT large | 4.3| 4.2 |5.0 |6.3| 8.2| 8.8|
|
| 76 |
+
| WavLM base+ | 4.5| 4.4 |5.6| 7.5| 9.4 |10.9|
|
| 77 |
+
| **WavLM large** | 4.2| 4.1 | 4.8 | 5.8 | 7.4| 8.5|
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
### Speaker Diarization
|
| 81 |
+
|
| 82 |
+
Evaluation on CALLHOME
|
| 83 |
+
| Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all |
|
| 84 |
+
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
|
| 85 |
+
| [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96| 11.93 |16.38| 21.21| 23.1 |12.49||
|
| 86 |
+
| [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11| 11.88 |14.37| 25.95| 21.95 |11.84||
|
| 87 |
+
| UniSpeech-SAT large | 5.93| 10.66| 12.9 |16.48| 23.25| 10.92|
|
| 88 |
+
| WavLM Base| 6.99| 11.12| 15.20 |16.48| 21.61| 11.75|
|
| 89 |
+
| **WavLm large** | 6.46| 10.69| 11.84 |12.89| 20.70| 10.35|
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
## License
|
| 93 |
+
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
|
| 94 |
+
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
|
| 95 |
+
|
| 96 |
+
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
### Reference
|
| 100 |
+
If you find our work is useful in your research, please cite the following paper:
|
| 101 |
+
``` latex
|
| 102 |
+
@inproceedings{Wang2021UniSpeech,
|
| 103 |
+
author = {Chengyi Wang and Yu Wu and Yao Qian and Kenichi Kumatani and Shujie Liu and Furu Wei and Michael Zeng and Xuedong Huang},
|
| 104 |
+
editor = {Marina Meila and Tong Zhang},
|
| 105 |
+
title = {UniSpeech: Unified Speech Representation Learning with Labeled and
|
| 106 |
+
Unlabeled Data},
|
| 107 |
+
booktitle = {Proceedings of the 38th International Conference on Machine Learning,
|
| 108 |
+
{ICML} 2021, 18-24 July 2021, Virtual Event},
|
| 109 |
+
series = {Proceedings of Machine Learning Research},
|
| 110 |
+
volume = {139},
|
| 111 |
+
pages = {10937--10947},
|
| 112 |
+
publisher = {{PMLR}},
|
| 113 |
+
year = {2021},
|
| 114 |
+
url = {http://proceedings.mlr.press/v139/wang21y.html},
|
| 115 |
+
timestamp = {Thu, 21 Oct 2021 16:06:12 +0200},
|
| 116 |
+
biburl = {https://dblp.org/rec/conf/icml/0002WQK0WZ021.bib},
|
| 117 |
+
bibsource = {dblp computer science bibliography, https://dblp.org}
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
``` latex
|
| 122 |
+
@article{Chen2021WavLM,
|
| 123 |
+
title = {WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing},
|
| 124 |
+
author = {Sanyuan Chen and Chengyi Wang and Zhengyang Chen and Yu Wu and Shujie Liu and Zhuo Chen and Jinyu Li and Naoyuki Kanda and Takuya Yoshioka and Xiong Xiao and Jian Wu and Long Zhou and Shuo Ren and Yanmin Qian and Yao Qian and Jian Wu and Michael Zeng and Furu Wei},
|
| 125 |
+
eprint={2110.13900},
|
| 126 |
+
archivePrefix={arXiv},
|
| 127 |
+
primaryClass={cs.CL},
|
| 128 |
+
year={2021}
|
| 129 |
+
}
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
``` latex
|
| 133 |
+
@article{Chen2021UniSpeechSAT,
|
| 134 |
+
title = {UniSpeech-SAT: Universal Speech Representation Learning with Speaker Aware Pre-Training},
|
| 135 |
+
author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Zhengyang Chen and Zhuo Chen and Shujie Liu and Jian Wu and Yao Qian and Furu Wei and Jinyu Li and Xiangzhan Yu},
|
| 136 |
+
eprint={2110.05752},
|
| 137 |
+
archivePrefix={arXiv},
|
| 138 |
+
primaryClass={cs.CL},
|
| 139 |
+
year={2021}
|
| 140 |
+
}
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
### Contact Information
|
| 145 |
+
|
| 146 |
+
For help or issues using UniSpeech models, please submit a GitHub issue.
|
| 147 |
+
|
| 148 |
+
For other communications related to UniSpeech, please contact Yu Wu (`yuwu1@microsoft.com`).
|
UniSpeech/SECURITY.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.8 BLOCK -->
|
| 2 |
+
|
| 3 |
+
## Security
|
| 4 |
+
|
| 5 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
| 6 |
+
|
| 7 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
|
| 8 |
+
|
| 9 |
+
## Reporting Security Issues
|
| 10 |
+
|
| 11 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
| 12 |
+
|
| 13 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
|
| 14 |
+
|
| 15 |
+
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
|
| 16 |
+
|
| 17 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
|
| 18 |
+
|
| 19 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
| 20 |
+
|
| 21 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
| 22 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
| 23 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
| 24 |
+
* Any special configuration required to reproduce the issue
|
| 25 |
+
* Step-by-step instructions to reproduce the issue
|
| 26 |
+
* Proof-of-concept or exploit code (if possible)
|
| 27 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
| 28 |
+
|
| 29 |
+
This information will help us triage your report more quickly.
|
| 30 |
+
|
| 31 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
|
| 32 |
+
|
| 33 |
+
## Preferred Languages
|
| 34 |
+
|
| 35 |
+
We prefer all communications to be in English.
|
| 36 |
+
|
| 37 |
+
## Policy
|
| 38 |
+
|
| 39 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
|
| 40 |
+
|
| 41 |
+
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
UniSpeech/UniSpeech-SAT/README.md
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# UniSpeech-SAT
|
| 2 |
+
|
| 3 |
+
> [**UniSpeech-SAT**](https://arxiv.org/pdf/2110.05752.pdf) (```ICASSP 2022 Submission```): **Universal Speech Representation Learning with Speaker Aware Pre-Training**
|
| 4 |
+
|
| 5 |
+
## Universal Representation Evaluation on SUPERB
|
| 6 |
+

|
| 7 |
+
|
| 8 |
+
## Downstream Task Performance
|
| 9 |
+
We also evaluate our models on typical speaker related benchmarks.
|
| 10 |
+
### Speaker Verification
|
| 11 |
+
| Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H |
|
| 12 |
+
| ------------- |------------- | ---------- | ---------- | ---------- |
|
| 13 |
+
| ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 |
|
| 14 |
+
| HuBERT large | Yes| 0.888 |0.912| 1.853 |
|
| 15 |
+
| Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895|
|
| 16 |
+
| UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669|
|
| 17 |
+
| HuBERT large | No| 0.585| 0.654 |1.342|
|
| 18 |
+
| Wav2Vec2.0 (XLSR) | No| **0.564**| 0.605 |**1.23**|
|
| 19 |
+
| **UniSpeech-SAT large** | No | **0.564** | **0.561** | **1.23** |
|
| 20 |
+
|
| 21 |
+
[Our paper for verification](https://arxiv.org/pdf/2110.05777.pdf)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
### Speech Separation
|
| 26 |
+
|
| 27 |
+
Evaluation on [LibriCSS](https://github.com/chenzhuo1011/libri_css)
|
| 28 |
+
|
| 29 |
+
| Model |0S | 0L | OV10 | OV20 |OV30 |OV40 |
|
| 30 |
+
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
|
| 31 |
+
| [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6|
|
| 32 |
+
| HuBERT base | 4.7| 4.6 | 6.1 | 7.9| 10.6| 12.3|
|
| 33 |
+
| UniSpeech-SAT base+ | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5|
|
| 34 |
+
| **UniSpeech-SAT large** | **4.3**| **4.2** |**5.0** |**6.3**| **8.2**| **8.8**|
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
### Speaker Diarization
|
| 38 |
+
|
| 39 |
+
Evaluation on CALLHOME
|
| 40 |
+
|
| 41 |
+
| Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all |
|
| 42 |
+
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
|
| 43 |
+
| [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96| 11.93 |16.38| 21.21| 23.1 |12.49||
|
| 44 |
+
| [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11| 11.88 |14.37| 25.95| 21.95 |11.84||
|
| 45 |
+
| HuBERT base| 7.93|12.07| 15.21 |19.59| 23.32| 12.63|
|
| 46 |
+
| HuBERT large| 7.39| 11.97| 15.76 |19.82| 22.10| 12.40|
|
| 47 |
+
| **UniSpeech-SAT large** | **5.93**| **10.66**| **12.9** |**16.48**| **23.25**| **10.92**|
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## License
|
| 51 |
+
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
|
| 52 |
+
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
|
| 53 |
+
|
| 54 |
+
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
### Reference
|
| 58 |
+
If you find our work is useful in your research, please cite the following paper:
|
| 59 |
+
|
| 60 |
+
``` latex
|
| 61 |
+
@article{Chen2021UniSpeechSAT,
|
| 62 |
+
title = {UniSpeech-SAT: Universal Speech Representation Learning with Speaker Aware Pre-Training},
|
| 63 |
+
author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Zhengyang Chen and Zhuo Chen and Shujie Liu and Jian Wu and Yao Qian and Furu Wei and Jinyu Li and Xiangzhan Yu},
|
| 64 |
+
eprint={2110.05752},
|
| 65 |
+
archivePrefix={arXiv},
|
| 66 |
+
primaryClass={cs.CL},
|
| 67 |
+
year={2021}
|
| 68 |
+
}
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
### Contact Information
|
| 73 |
+
|
| 74 |
+
For help or issues using UniSpeech models, please submit a GitHub issue.
|
| 75 |
+
|
| 76 |
+
For other communications related to UniSpeech, please contact Yu Wu (`yuwu1@microsoft.com`).
|
UniSpeech/UniSpeech-SAT/UniSpeech_SAT_SUPERB_Results.png
ADDED
|
Git LFS Details
|
UniSpeech/UniSpeech/README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# UniSpeech
|
| 2 |
+
|
| 3 |
+
This is the official implementation of paper "[UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597)". The implementation mainly based on [fairseq](https://github.com/pytorch/fairseq) codebase. We release the training recipes on CommonVoice dataset.
|
| 4 |
+
|
| 5 |
+
## Requirements and Installation
|
| 6 |
+
|
| 7 |
+
- Pytorch >= 1.6.0
|
| 8 |
+
- python version >= 3.6
|
| 9 |
+
``` bash
|
| 10 |
+
cd src
|
| 11 |
+
pip install soundfile
|
| 12 |
+
pip install librosa
|
| 13 |
+
pip install pydub
|
| 14 |
+
pip install --editable ./
|
| 15 |
+
```
|
| 16 |
+
## Data Preparation
|
| 17 |
+
Download pretraining audio data from [here](https://commonvoice.mozilla.org/datasets). (We use the June 2020 release version in our paper).
|
| 18 |
+
Get the wav list and the transcription for each dataset by run:
|
| 19 |
+
```
|
| 20 |
+
python examples/unispeech/unispeech_manifest.py input_meta_file --dest examples/unispeech/data/LANG
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Then convert the audio files in common voices to 16k HZ using the commond:
|
| 24 |
+
```
|
| 25 |
+
python examples/unispeech/adjust_sample_rate.py --wav-path /path/to/wav/ --dest-path /path/to/16kwav/ --input examples/unispeech/data/LANG/*.tsv --output examples/unispeech/data/LANG/*_16k.tsv
|
| 26 |
+
```
|
| 27 |
+
For the finetuning data, our train/val/test splits are following [this](https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz).
|
| 28 |
+
The phoneme transcriptions are generated by [phonemizer](https://github.com/bootphon/phonemizer) to convert texts to phonemes. Then we create .id files using different vocabularies. All our pre-processed data as well as the dictionaries can be downloaded from [here].
|
| 29 |
+
|
| 30 |
+
## Pretraining
|
| 31 |
+
|
| 32 |
+
We give the training examples for large model here.
|
| 33 |
+
### Stage 1. Pretraining UniSpeech with labeled data.
|
| 34 |
+
The following script can be used to pre-train an English model:
|
| 35 |
+
```
|
| 36 |
+
bash examples/unispeech/scripts/one2one_large_pretrain_en1350.sh
|
| 37 |
+
```
|
| 38 |
+
To train a multilingual model:
|
| 39 |
+
```
|
| 40 |
+
bash examples/unispeech/scripts/multilingual_large_pretrain.sh
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
### Stage 2. Continue pre-training with low-resource unlabeled data. (Optional)
|
| 44 |
+
After stage 1, you can continue pre-training the UniSpeech model with only contrastive loss:
|
| 45 |
+
```
|
| 46 |
+
bash examples/unispeech/scripts/continue_pretran.sh
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### Stage 3. Finetuning with low-resource labeled data.
|
| 50 |
+
Finally, fint-tune the model with 1 hour labeled data.
|
| 51 |
+
For multilingual models, you can choose to use separate vocabulary (examples/unispeech/data/en/vocab_sep.json) or shared vocabulary (examples/unispeech/data/en/vocab_share.json)
|
| 52 |
+
```
|
| 53 |
+
bash examples/unispeech/scripts/finetune.sh
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
|
UniSpeech/WavLM/README.md
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# WavLM
|
| 3 |
+
|
| 4 |
+
<!--**Pre-trained models for speech related tasks**-->
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
[**WavLM**](https://arxiv.org/pdf/2110.13900.pdf) : **WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing**
|
| 8 |
+
|
| 9 |
+
Official PyTorch implementation and pretrained models of WavLM
|
| 10 |
+
|
| 11 |
+
- Oct 2021: release preprint in [arXiv](https://arxiv.org/pdf/2110.13900.pdf)
|
| 12 |
+
|
| 13 |
+
## Pre-Trained Models
|
| 14 |
+
Model | Pre-training Dataset | Fine-tuning Dataset | Model
|
| 15 |
+
|---|---|---|---
|
| 16 |
+
WavLM Base | [960 hrs LibriSpeech](http://www.openslr.org/12)| - | [Azure Storage](https://valle.blob.core.windows.net/share/wavlm/WavLM-Base.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) <br> [Google Drive](https://drive.google.com/file/d/1BhTPLUkfN6e2xkqR8LEm9lByXbLY1IYd/view?usp=share_link)
|
| 17 |
+
WavLM Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) <br> [Google Drive](https://drive.google.com/file/d/1-zlAj2SyVJVsbhifwpTlAfrgc9qu-HDb/view?usp=share_link)
|
| 18 |
+
WavLM Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://valle.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) <br> [Google Drive](https://drive.google.com/file/d/12-cB34qCTvByWT-QtOcZaqwwO21FLSqU/view?usp=share_link)
|
| 19 |
+
|
| 20 |
+
## Load Pre-Trained Models for Inference
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
import torch
|
| 24 |
+
from WavLM import WavLM, WavLMConfig
|
| 25 |
+
|
| 26 |
+
# load the pre-trained checkpoints
|
| 27 |
+
checkpoint = torch.load('/path/to/wavlm.pt')
|
| 28 |
+
cfg = WavLMConfig(checkpoint['cfg'])
|
| 29 |
+
model = WavLM(cfg)
|
| 30 |
+
model.load_state_dict(checkpoint['model'])
|
| 31 |
+
model.eval()
|
| 32 |
+
|
| 33 |
+
# extract the the representation of last layer
|
| 34 |
+
wav_input_16khz = torch.randn(1,10000)
|
| 35 |
+
rep = model.extract_features(wav_input_16khz)[0]
|
| 36 |
+
|
| 37 |
+
# extract the the representation of each layer
|
| 38 |
+
wav_input_16khz = torch.randn(1,10000)
|
| 39 |
+
rep, layer_results = model.extract_features(wav_input_16khz, output_layer=model.cfg.encoder_layers, ret_layer_results=True)[0]
|
| 40 |
+
layer_reps = [x.transpose(0, 1) for x, _ in layer_results]
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
## Universal Representation Evaluation on SUPERB
|
| 45 |
+

|
| 46 |
+
|
| 47 |
+

|
| 48 |
+
## Downstream Task Performance
|
| 49 |
+
We also evaluate our models on typical speech processing benchmarks.
|
| 50 |
+
### Speaker Verification
|
| 51 |
+
|
| 52 |
+
Evaluate on the [VoxCeleb](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/#:~:text=VoxCeleb%20is%20an%20audio%2Dvisual,interview%20videos%20uploaded%20to%20YouTube)
|
| 53 |
+
|
| 54 |
+
| Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H |
|
| 55 |
+
| ------------- |------------- | ---------- | ---------- | ---------- |
|
| 56 |
+
| ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 |
|
| 57 |
+
| HuBERT large | Yes| 0.888 |0.912| 1.853 |
|
| 58 |
+
| Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895|
|
| 59 |
+
| UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669|
|
| 60 |
+
| WavLM large | Yes | 0.638 | 0.687| 1.457|
|
| 61 |
+
| HuBERT large | No| 0.585| 0.654 |1.342|
|
| 62 |
+
| Wav2Vec2.0 (XLSR) | No| 0.564| 0.605 |1.23|
|
| 63 |
+
| UniSpeech-SAT large | No | 0.564 | 0.561| 1.23 |
|
| 64 |
+
| **WavLM large** | No | **0.431** | **0.538**| **1.154** |
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
### Speech Separation
|
| 69 |
+
|
| 70 |
+
Evaluation on the [LibriCSS](https://github.com/chenzhuo1011/libri_css)
|
| 71 |
+
| Model |0S | 0L | OV10 | OV20 |OV30 |OV40 |
|
| 72 |
+
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
|
| 73 |
+
| [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6|
|
| 74 |
+
| HuBERT base | 4.7| 4.6 | 6.1 | 7.9| 10.6| 12.3|
|
| 75 |
+
| UniSpeech-SAT base | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5|
|
| 76 |
+
| UniSpeech-SAT large | 4.3| 4.2 |5.0 |6.3| 8.2| 8.8|
|
| 77 |
+
| WavLM base+ | 4.5| 4.4 |5.6| 7.5| 9.4 |10.9|
|
| 78 |
+
| **WavLM large** | 4.2| 4.1 | 4.8 | 5.8 | 7.4| 8.5|
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
### Speaker Diarization
|
| 82 |
+
|
| 83 |
+
Evaluation on the [CALLHOME](https://arxiv.org/pdf/1909.06247.pdf)
|
| 84 |
+
| Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all |
|
| 85 |
+
| ---------------- |------| ------ | ------ | ------ | ------ | ------ |
|
| 86 |
+
| [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96| 11.93 |16.38| 21.21| 23.1 |12.49||
|
| 87 |
+
| [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11| 11.88 |14.37| 25.95| 21.95 |11.84||
|
| 88 |
+
| HuBERT base| 7.93|12.07| 15.21 |19.59| 23.32| 12.63|
|
| 89 |
+
| HuBERT large| 7.39| 11.97| 15.76 |19.82| 22.10| 12.40|
|
| 90 |
+
| UniSpeech-SAT large| 5.93| 10.66| 12.9 |16.48| 23.25| 10.92|
|
| 91 |
+
| WavLM Base| 6.99| 11.12| 15.20 |16.48| 21.61| 11.75|
|
| 92 |
+
| **WavLm large** | 6.46| 10.69| 11.84 |12.89| 20.70| 10.35|
|
| 93 |
+
|
| 94 |
+
### Speech Recogntion
|
| 95 |
+
Evaluate on the [LibriSpeech](https://www.openslr.org/12)
|
| 96 |
+
|
| 97 |
+

|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
## License
|
| 101 |
+
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
|
| 102 |
+
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
|
| 103 |
+
|
| 104 |
+
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
### Reference
|
| 108 |
+
If you find our work is useful in your research, please cite the following paper:
|
| 109 |
+
``` latex
|
| 110 |
+
@article{Chen2021WavLM,
|
| 111 |
+
title = {WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing},
|
| 112 |
+
author = {Sanyuan Chen and Chengyi Wang and Zhengyang Chen and Yu Wu and Shujie Liu and Zhuo Chen and Jinyu Li and Naoyuki Kanda and Takuya Yoshioka and Xiong Xiao and Jian Wu and Long Zhou and Shuo Ren and Yanmin Qian and Yao Qian and Jian Wu and Micheal Zeng and Furu Wei},
|
| 113 |
+
eprint={2110.13900},
|
| 114 |
+
archivePrefix={arXiv},
|
| 115 |
+
primaryClass={cs.CL},
|
| 116 |
+
year={2021}
|
| 117 |
+
}
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
### Contact Information
|
| 122 |
+
|
| 123 |
+
For help or issues using WavLM models, please submit a GitHub issue.
|
| 124 |
+
|
| 125 |
+
For other communications related to WavLM, please contact Yu Wu (`yuwu1@microsoft.com`).
|
UniSpeech/WavLM/WavLM.py
ADDED
|
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
| 4 |
+
# Copyright (c) 2021 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch.nn import LayerNorm
|
| 20 |
+
from modules import (
|
| 21 |
+
Fp32GroupNorm,
|
| 22 |
+
Fp32LayerNorm,
|
| 23 |
+
GradMultiply,
|
| 24 |
+
MultiheadAttention,
|
| 25 |
+
SamePad,
|
| 26 |
+
init_bert_params,
|
| 27 |
+
get_activation_fn,
|
| 28 |
+
TransposeLast,
|
| 29 |
+
GLU_Linear,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def compute_mask_indices(
|
| 36 |
+
shape: Tuple[int, int],
|
| 37 |
+
padding_mask: Optional[torch.Tensor],
|
| 38 |
+
mask_prob: float,
|
| 39 |
+
mask_length: int,
|
| 40 |
+
mask_type: str = "static",
|
| 41 |
+
mask_other: float = 0.0,
|
| 42 |
+
min_masks: int = 0,
|
| 43 |
+
no_overlap: bool = False,
|
| 44 |
+
min_space: int = 0,
|
| 45 |
+
) -> np.ndarray:
|
| 46 |
+
"""
|
| 47 |
+
Computes random mask spans for a given shape
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
shape: the the shape for which to compute masks.
|
| 51 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
| 52 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
| 53 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
| 54 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
| 55 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
| 56 |
+
mask_type: how to compute mask lengths
|
| 57 |
+
static = fixed size
|
| 58 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
| 59 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
| 60 |
+
poisson = sample from possion distribution with lambda = mask length
|
| 61 |
+
min_masks: minimum number of masked spans
|
| 62 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
| 63 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
bsz, all_sz = shape
|
| 67 |
+
mask = np.full((bsz, all_sz), False)
|
| 68 |
+
|
| 69 |
+
all_num_mask = int(
|
| 70 |
+
# add a random number for probabilistic rounding
|
| 71 |
+
mask_prob * all_sz / float(mask_length)
|
| 72 |
+
+ np.random.rand()
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
all_num_mask = max(min_masks, all_num_mask)
|
| 76 |
+
|
| 77 |
+
mask_idcs = []
|
| 78 |
+
for i in range(bsz):
|
| 79 |
+
if padding_mask is not None:
|
| 80 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
| 81 |
+
num_mask = int(
|
| 82 |
+
# add a random number for probabilistic rounding
|
| 83 |
+
mask_prob * sz / float(mask_length)
|
| 84 |
+
+ np.random.rand()
|
| 85 |
+
)
|
| 86 |
+
num_mask = max(min_masks, num_mask)
|
| 87 |
+
else:
|
| 88 |
+
sz = all_sz
|
| 89 |
+
num_mask = all_num_mask
|
| 90 |
+
|
| 91 |
+
if mask_type == "static":
|
| 92 |
+
lengths = np.full(num_mask, mask_length)
|
| 93 |
+
elif mask_type == "uniform":
|
| 94 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
| 95 |
+
elif mask_type == "normal":
|
| 96 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
| 97 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
| 98 |
+
elif mask_type == "poisson":
|
| 99 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
| 100 |
+
lengths = [int(round(x)) for x in lengths]
|
| 101 |
+
else:
|
| 102 |
+
raise Exception("unknown mask selection " + mask_type)
|
| 103 |
+
|
| 104 |
+
if sum(lengths) == 0:
|
| 105 |
+
lengths[0] = min(mask_length, sz - 1)
|
| 106 |
+
|
| 107 |
+
if no_overlap:
|
| 108 |
+
mask_idc = []
|
| 109 |
+
|
| 110 |
+
def arrange(s, e, length, keep_length):
|
| 111 |
+
span_start = np.random.randint(s, e - length)
|
| 112 |
+
mask_idc.extend(span_start + i for i in range(length))
|
| 113 |
+
|
| 114 |
+
new_parts = []
|
| 115 |
+
if span_start - s - min_space >= keep_length:
|
| 116 |
+
new_parts.append((s, span_start - min_space + 1))
|
| 117 |
+
if e - span_start - keep_length - min_space > keep_length:
|
| 118 |
+
new_parts.append((span_start + length + min_space, e))
|
| 119 |
+
return new_parts
|
| 120 |
+
|
| 121 |
+
parts = [(0, sz)]
|
| 122 |
+
min_length = min(lengths)
|
| 123 |
+
for length in sorted(lengths, reverse=True):
|
| 124 |
+
lens = np.fromiter(
|
| 125 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
| 126 |
+
np.int,
|
| 127 |
+
)
|
| 128 |
+
l_sum = np.sum(lens)
|
| 129 |
+
if l_sum == 0:
|
| 130 |
+
break
|
| 131 |
+
probs = lens / np.sum(lens)
|
| 132 |
+
c = np.random.choice(len(parts), p=probs)
|
| 133 |
+
s, e = parts.pop(c)
|
| 134 |
+
parts.extend(arrange(s, e, length, min_length))
|
| 135 |
+
mask_idc = np.asarray(mask_idc)
|
| 136 |
+
else:
|
| 137 |
+
min_len = min(lengths)
|
| 138 |
+
if sz - min_len <= num_mask:
|
| 139 |
+
min_len = sz - num_mask - 1
|
| 140 |
+
|
| 141 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
| 142 |
+
|
| 143 |
+
mask_idc = np.asarray(
|
| 144 |
+
[
|
| 145 |
+
mask_idc[j] + offset
|
| 146 |
+
for j in range(len(mask_idc))
|
| 147 |
+
for offset in range(lengths[j])
|
| 148 |
+
]
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
| 152 |
+
|
| 153 |
+
min_len = min([len(m) for m in mask_idcs])
|
| 154 |
+
for i, mask_idc in enumerate(mask_idcs):
|
| 155 |
+
if len(mask_idc) > min_len:
|
| 156 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
| 157 |
+
mask[i, mask_idc] = True
|
| 158 |
+
|
| 159 |
+
return mask
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class WavLMConfig:
|
| 163 |
+
def __init__(self, cfg=None):
|
| 164 |
+
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
| 165 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
| 166 |
+
|
| 167 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
| 168 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
| 169 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
| 170 |
+
self.activation_fn: str = "gelu" # activation function to use
|
| 171 |
+
|
| 172 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
| 173 |
+
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
| 174 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
| 175 |
+
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
| 176 |
+
|
| 177 |
+
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
| 178 |
+
|
| 179 |
+
# dropouts
|
| 180 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
| 181 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
| 182 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
| 183 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
| 184 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
| 185 |
+
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
| 186 |
+
|
| 187 |
+
# masking
|
| 188 |
+
self.mask_length: int = 10 # mask length)
|
| 189 |
+
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
| 190 |
+
self.mask_selection: str = "static" # how to choose mask length
|
| 191 |
+
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
|
| 192 |
+
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
| 193 |
+
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
| 194 |
+
|
| 195 |
+
# channel masking
|
| 196 |
+
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
| 197 |
+
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
| 198 |
+
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
| 199 |
+
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
| 200 |
+
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
| 201 |
+
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
| 202 |
+
|
| 203 |
+
# positional embeddings
|
| 204 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
| 205 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
| 206 |
+
|
| 207 |
+
# relative position embedding
|
| 208 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
| 209 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
| 210 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
| 211 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
| 212 |
+
|
| 213 |
+
if cfg is not None:
|
| 214 |
+
self.update(cfg)
|
| 215 |
+
|
| 216 |
+
def update(self, cfg: dict):
|
| 217 |
+
self.__dict__.update(cfg)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class WavLM(nn.Module):
|
| 221 |
+
def __init__(
|
| 222 |
+
self,
|
| 223 |
+
cfg: WavLMConfig,
|
| 224 |
+
) -> None:
|
| 225 |
+
super().__init__()
|
| 226 |
+
logger.info(f"WavLM Config: {cfg.__dict__}")
|
| 227 |
+
|
| 228 |
+
self.cfg = cfg
|
| 229 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
| 230 |
+
self.embed = feature_enc_layers[-1][0]
|
| 231 |
+
|
| 232 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
| 233 |
+
conv_layers=feature_enc_layers,
|
| 234 |
+
dropout=0.0,
|
| 235 |
+
mode=cfg.extractor_mode,
|
| 236 |
+
conv_bias=cfg.conv_bias,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
self.post_extract_proj = (
|
| 240 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
| 241 |
+
if self.embed != cfg.encoder_embed_dim
|
| 242 |
+
else None
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
self.mask_prob = cfg.mask_prob
|
| 246 |
+
self.mask_selection = cfg.mask_selection
|
| 247 |
+
self.mask_other = cfg.mask_other
|
| 248 |
+
self.mask_length = cfg.mask_length
|
| 249 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
| 250 |
+
self.mask_min_space = cfg.mask_min_space
|
| 251 |
+
|
| 252 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
| 253 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
| 254 |
+
self.mask_channel_other = cfg.mask_channel_other
|
| 255 |
+
self.mask_channel_length = cfg.mask_channel_length
|
| 256 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
| 257 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
| 258 |
+
|
| 259 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
| 260 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
| 261 |
+
|
| 262 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
| 263 |
+
|
| 264 |
+
self.mask_emb = nn.Parameter(
|
| 265 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
self.encoder = TransformerEncoder(cfg)
|
| 269 |
+
self.layer_norm = LayerNorm(self.embed)
|
| 270 |
+
|
| 271 |
+
def apply_mask(self, x, padding_mask):
|
| 272 |
+
B, T, C = x.shape
|
| 273 |
+
if self.mask_prob > 0:
|
| 274 |
+
mask_indices = compute_mask_indices(
|
| 275 |
+
(B, T),
|
| 276 |
+
padding_mask,
|
| 277 |
+
self.mask_prob,
|
| 278 |
+
self.mask_length,
|
| 279 |
+
self.mask_selection,
|
| 280 |
+
self.mask_other,
|
| 281 |
+
min_masks=2,
|
| 282 |
+
no_overlap=self.no_mask_overlap,
|
| 283 |
+
min_space=self.mask_min_space,
|
| 284 |
+
)
|
| 285 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
| 286 |
+
x[mask_indices] = self.mask_emb
|
| 287 |
+
else:
|
| 288 |
+
mask_indices = None
|
| 289 |
+
|
| 290 |
+
if self.mask_channel_prob > 0:
|
| 291 |
+
mask_channel_indices = compute_mask_indices(
|
| 292 |
+
(B, C),
|
| 293 |
+
None,
|
| 294 |
+
self.mask_channel_prob,
|
| 295 |
+
self.mask_channel_length,
|
| 296 |
+
self.mask_channel_selection,
|
| 297 |
+
self.mask_channel_other,
|
| 298 |
+
no_overlap=self.no_mask_channel_overlap,
|
| 299 |
+
min_space=self.mask_channel_min_space,
|
| 300 |
+
)
|
| 301 |
+
mask_channel_indices = (
|
| 302 |
+
torch.from_numpy(mask_channel_indices)
|
| 303 |
+
.to(x.device)
|
| 304 |
+
.unsqueeze(1)
|
| 305 |
+
.expand(-1, T, -1)
|
| 306 |
+
)
|
| 307 |
+
x[mask_channel_indices] = 0
|
| 308 |
+
|
| 309 |
+
return x, mask_indices
|
| 310 |
+
|
| 311 |
+
def forward_padding_mask(
|
| 312 |
+
self, features: torch.Tensor, padding_mask: torch.Tensor,
|
| 313 |
+
) -> torch.Tensor:
|
| 314 |
+
extra = padding_mask.size(1) % features.size(1)
|
| 315 |
+
if extra > 0:
|
| 316 |
+
padding_mask = padding_mask[:, :-extra]
|
| 317 |
+
padding_mask = padding_mask.view(
|
| 318 |
+
padding_mask.size(0), features.size(1), -1
|
| 319 |
+
)
|
| 320 |
+
padding_mask = padding_mask.all(-1)
|
| 321 |
+
return padding_mask
|
| 322 |
+
|
| 323 |
+
def extract_features(
|
| 324 |
+
self,
|
| 325 |
+
source: torch.Tensor,
|
| 326 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 327 |
+
mask: bool = False,
|
| 328 |
+
ret_conv: bool = False,
|
| 329 |
+
output_layer: Optional[int] = None,
|
| 330 |
+
ret_layer_results: bool = False,
|
| 331 |
+
):
|
| 332 |
+
|
| 333 |
+
if self.feature_grad_mult > 0:
|
| 334 |
+
features = self.feature_extractor(source)
|
| 335 |
+
if self.feature_grad_mult != 1.0:
|
| 336 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
| 337 |
+
else:
|
| 338 |
+
with torch.no_grad():
|
| 339 |
+
features = self.feature_extractor(source)
|
| 340 |
+
|
| 341 |
+
features = features.transpose(1, 2)
|
| 342 |
+
features = self.layer_norm(features)
|
| 343 |
+
|
| 344 |
+
if padding_mask is not None:
|
| 345 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
| 346 |
+
|
| 347 |
+
if self.post_extract_proj is not None:
|
| 348 |
+
features = self.post_extract_proj(features)
|
| 349 |
+
|
| 350 |
+
features = self.dropout_input(features)
|
| 351 |
+
|
| 352 |
+
if mask:
|
| 353 |
+
x, mask_indices = self.apply_mask(
|
| 354 |
+
features, padding_mask
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
x = features
|
| 358 |
+
|
| 359 |
+
# feature: (B, T, D), float
|
| 360 |
+
# target: (B, T), long
|
| 361 |
+
# x: (B, T, D), float
|
| 362 |
+
# padding_mask: (B, T), bool
|
| 363 |
+
# mask_indices: (B, T), bool
|
| 364 |
+
x, layer_results = self.encoder(
|
| 365 |
+
x,
|
| 366 |
+
padding_mask=padding_mask,
|
| 367 |
+
layer=None if output_layer is None else output_layer - 1
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
| 371 |
+
|
| 372 |
+
feature = res["features"] if ret_conv else res["x"]
|
| 373 |
+
if ret_layer_results:
|
| 374 |
+
feature = (feature, res["layer_results"])
|
| 375 |
+
return feature, res["padding_mask"]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class ConvFeatureExtractionModel(nn.Module):
|
| 379 |
+
def __init__(
|
| 380 |
+
self,
|
| 381 |
+
conv_layers: List[Tuple[int, int, int]],
|
| 382 |
+
dropout: float = 0.0,
|
| 383 |
+
mode: str = "default",
|
| 384 |
+
conv_bias: bool = False,
|
| 385 |
+
conv_type: str = "default"
|
| 386 |
+
):
|
| 387 |
+
super().__init__()
|
| 388 |
+
|
| 389 |
+
assert mode in {"default", "layer_norm"}
|
| 390 |
+
|
| 391 |
+
def block(
|
| 392 |
+
n_in,
|
| 393 |
+
n_out,
|
| 394 |
+
k,
|
| 395 |
+
stride,
|
| 396 |
+
is_layer_norm=False,
|
| 397 |
+
is_group_norm=False,
|
| 398 |
+
conv_bias=False,
|
| 399 |
+
):
|
| 400 |
+
def make_conv():
|
| 401 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
| 402 |
+
nn.init.kaiming_normal_(conv.weight)
|
| 403 |
+
return conv
|
| 404 |
+
|
| 405 |
+
assert (
|
| 406 |
+
is_layer_norm and is_group_norm
|
| 407 |
+
) == False, "layer norm and group norm are exclusive"
|
| 408 |
+
|
| 409 |
+
if is_layer_norm:
|
| 410 |
+
return nn.Sequential(
|
| 411 |
+
make_conv(),
|
| 412 |
+
nn.Dropout(p=dropout),
|
| 413 |
+
nn.Sequential(
|
| 414 |
+
TransposeLast(),
|
| 415 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
| 416 |
+
TransposeLast(),
|
| 417 |
+
),
|
| 418 |
+
nn.GELU(),
|
| 419 |
+
)
|
| 420 |
+
elif is_group_norm:
|
| 421 |
+
return nn.Sequential(
|
| 422 |
+
make_conv(),
|
| 423 |
+
nn.Dropout(p=dropout),
|
| 424 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
| 425 |
+
nn.GELU(),
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
| 429 |
+
|
| 430 |
+
self.conv_type = conv_type
|
| 431 |
+
if self.conv_type == "default":
|
| 432 |
+
in_d = 1
|
| 433 |
+
self.conv_layers = nn.ModuleList()
|
| 434 |
+
for i, cl in enumerate(conv_layers):
|
| 435 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
| 436 |
+
(dim, k, stride) = cl
|
| 437 |
+
|
| 438 |
+
self.conv_layers.append(
|
| 439 |
+
block(
|
| 440 |
+
in_d,
|
| 441 |
+
dim,
|
| 442 |
+
k,
|
| 443 |
+
stride,
|
| 444 |
+
is_layer_norm=mode == "layer_norm",
|
| 445 |
+
is_group_norm=mode == "default" and i == 0,
|
| 446 |
+
conv_bias=conv_bias,
|
| 447 |
+
)
|
| 448 |
+
)
|
| 449 |
+
in_d = dim
|
| 450 |
+
elif self.conv_type == "conv2d":
|
| 451 |
+
in_d = 1
|
| 452 |
+
self.conv_layers = nn.ModuleList()
|
| 453 |
+
for i, cl in enumerate(conv_layers):
|
| 454 |
+
assert len(cl) == 3
|
| 455 |
+
(dim, k, stride) = cl
|
| 456 |
+
|
| 457 |
+
self.conv_layers.append(
|
| 458 |
+
torch.nn.Conv2d(in_d, dim, k, stride)
|
| 459 |
+
)
|
| 460 |
+
self.conv_layers.append(torch.nn.ReLU())
|
| 461 |
+
in_d = dim
|
| 462 |
+
elif self.conv_type == "custom":
|
| 463 |
+
in_d = 1
|
| 464 |
+
idim = 80
|
| 465 |
+
self.conv_layers = nn.ModuleList()
|
| 466 |
+
for i, cl in enumerate(conv_layers):
|
| 467 |
+
assert len(cl) == 3
|
| 468 |
+
(dim, k, stride) = cl
|
| 469 |
+
self.conv_layers.append(
|
| 470 |
+
torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
|
| 471 |
+
)
|
| 472 |
+
self.conv_layers.append(
|
| 473 |
+
torch.nn.LayerNorm([dim, idim])
|
| 474 |
+
)
|
| 475 |
+
self.conv_layers.append(torch.nn.ReLU())
|
| 476 |
+
in_d = dim
|
| 477 |
+
if (i + 1) % 2 == 0:
|
| 478 |
+
self.conv_layers.append(
|
| 479 |
+
torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 480 |
+
)
|
| 481 |
+
idim = int(math.ceil(idim / 2))
|
| 482 |
+
else:
|
| 483 |
+
pass
|
| 484 |
+
|
| 485 |
+
def forward(self, x, mask=None):
|
| 486 |
+
|
| 487 |
+
# BxT -> BxCxT
|
| 488 |
+
x = x.unsqueeze(1)
|
| 489 |
+
if self.conv_type == "custom":
|
| 490 |
+
for conv in self.conv_layers:
|
| 491 |
+
if isinstance(conv, nn.LayerNorm):
|
| 492 |
+
x = x.transpose(1, 2)
|
| 493 |
+
x = conv(x).transpose(1, 2)
|
| 494 |
+
else:
|
| 495 |
+
x = conv(x)
|
| 496 |
+
x = x.transpose(2, 3).contiguous()
|
| 497 |
+
x = x.view(x.size(0), -1, x.size(-1))
|
| 498 |
+
else:
|
| 499 |
+
for conv in self.conv_layers:
|
| 500 |
+
x = conv(x)
|
| 501 |
+
if self.conv_type == "conv2d":
|
| 502 |
+
b, c, t, f = x.size()
|
| 503 |
+
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
| 504 |
+
return x
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class TransformerEncoder(nn.Module):
|
| 508 |
+
def __init__(self, args):
|
| 509 |
+
super().__init__()
|
| 510 |
+
|
| 511 |
+
self.dropout = args.dropout
|
| 512 |
+
self.embedding_dim = args.encoder_embed_dim
|
| 513 |
+
|
| 514 |
+
self.pos_conv = nn.Conv1d(
|
| 515 |
+
self.embedding_dim,
|
| 516 |
+
self.embedding_dim,
|
| 517 |
+
kernel_size=args.conv_pos,
|
| 518 |
+
padding=args.conv_pos // 2,
|
| 519 |
+
groups=args.conv_pos_groups,
|
| 520 |
+
)
|
| 521 |
+
dropout = 0
|
| 522 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
| 523 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
| 524 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
| 525 |
+
|
| 526 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
| 527 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
| 528 |
+
|
| 529 |
+
if hasattr(args, "relative_position_embedding"):
|
| 530 |
+
self.relative_position_embedding = args.relative_position_embedding
|
| 531 |
+
self.num_buckets = args.num_buckets
|
| 532 |
+
self.max_distance = args.max_distance
|
| 533 |
+
else:
|
| 534 |
+
self.relative_position_embedding = False
|
| 535 |
+
self.num_buckets = 0
|
| 536 |
+
self.max_distance = 0
|
| 537 |
+
|
| 538 |
+
self.layers = nn.ModuleList(
|
| 539 |
+
[
|
| 540 |
+
TransformerSentenceEncoderLayer(
|
| 541 |
+
embedding_dim=self.embedding_dim,
|
| 542 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
| 543 |
+
num_attention_heads=args.encoder_attention_heads,
|
| 544 |
+
dropout=self.dropout,
|
| 545 |
+
attention_dropout=args.attention_dropout,
|
| 546 |
+
activation_dropout=args.activation_dropout,
|
| 547 |
+
activation_fn=args.activation_fn,
|
| 548 |
+
layer_norm_first=args.layer_norm_first,
|
| 549 |
+
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
|
| 550 |
+
num_buckets=self.num_buckets,
|
| 551 |
+
max_distance=self.max_distance,
|
| 552 |
+
gru_rel_pos=args.gru_rel_pos,
|
| 553 |
+
)
|
| 554 |
+
for i in range(args.encoder_layers)
|
| 555 |
+
]
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
self.layer_norm_first = args.layer_norm_first
|
| 559 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
| 560 |
+
self.layerdrop = args.encoder_layerdrop
|
| 561 |
+
|
| 562 |
+
self.apply(init_bert_params)
|
| 563 |
+
|
| 564 |
+
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
| 565 |
+
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
| 566 |
+
|
| 567 |
+
if self.layer_norm_first and layer is None:
|
| 568 |
+
x = self.layer_norm(x)
|
| 569 |
+
|
| 570 |
+
return x, layer_results
|
| 571 |
+
|
| 572 |
+
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
|
| 573 |
+
|
| 574 |
+
if padding_mask is not None:
|
| 575 |
+
x[padding_mask] = 0
|
| 576 |
+
|
| 577 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
| 578 |
+
x_conv = x_conv.transpose(1, 2)
|
| 579 |
+
x += x_conv
|
| 580 |
+
|
| 581 |
+
if not self.layer_norm_first:
|
| 582 |
+
x = self.layer_norm(x)
|
| 583 |
+
|
| 584 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 585 |
+
|
| 586 |
+
# B x T x C -> T x B x C
|
| 587 |
+
x = x.transpose(0, 1)
|
| 588 |
+
|
| 589 |
+
layer_results = []
|
| 590 |
+
z = None
|
| 591 |
+
if tgt_layer is not None:
|
| 592 |
+
layer_results.append((x, z))
|
| 593 |
+
r = None
|
| 594 |
+
pos_bias = None
|
| 595 |
+
for i, layer in enumerate(self.layers):
|
| 596 |
+
dropout_probability = np.random.random()
|
| 597 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
| 598 |
+
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
|
| 599 |
+
self_attn_mask=streaming_mask, pos_bias=pos_bias)
|
| 600 |
+
if tgt_layer is not None:
|
| 601 |
+
layer_results.append((x, z))
|
| 602 |
+
if i == tgt_layer:
|
| 603 |
+
r = x
|
| 604 |
+
break
|
| 605 |
+
|
| 606 |
+
if r is not None:
|
| 607 |
+
x = r
|
| 608 |
+
|
| 609 |
+
# T x B x C -> B x T x C
|
| 610 |
+
x = x.transpose(0, 1)
|
| 611 |
+
|
| 612 |
+
return x, layer_results
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
| 616 |
+
"""
|
| 617 |
+
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
| 618 |
+
models.
|
| 619 |
+
"""
|
| 620 |
+
|
| 621 |
+
def __init__(
|
| 622 |
+
self,
|
| 623 |
+
embedding_dim: float = 768,
|
| 624 |
+
ffn_embedding_dim: float = 3072,
|
| 625 |
+
num_attention_heads: float = 8,
|
| 626 |
+
dropout: float = 0.1,
|
| 627 |
+
attention_dropout: float = 0.1,
|
| 628 |
+
activation_dropout: float = 0.1,
|
| 629 |
+
activation_fn: str = "relu",
|
| 630 |
+
layer_norm_first: bool = False,
|
| 631 |
+
has_relative_attention_bias: bool = False,
|
| 632 |
+
num_buckets: int = 0,
|
| 633 |
+
max_distance: int = 0,
|
| 634 |
+
rescale_init: bool = False,
|
| 635 |
+
gru_rel_pos: bool = False,
|
| 636 |
+
) -> None:
|
| 637 |
+
|
| 638 |
+
super().__init__()
|
| 639 |
+
# Initialize parameters
|
| 640 |
+
self.embedding_dim = embedding_dim
|
| 641 |
+
self.dropout = dropout
|
| 642 |
+
self.activation_dropout = activation_dropout
|
| 643 |
+
|
| 644 |
+
# Initialize blocks
|
| 645 |
+
self.activation_name = activation_fn
|
| 646 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
| 647 |
+
self.self_attn = MultiheadAttention(
|
| 648 |
+
self.embedding_dim,
|
| 649 |
+
num_attention_heads,
|
| 650 |
+
dropout=attention_dropout,
|
| 651 |
+
self_attention=True,
|
| 652 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
| 653 |
+
num_buckets=num_buckets,
|
| 654 |
+
max_distance=max_distance,
|
| 655 |
+
rescale_init=rescale_init,
|
| 656 |
+
gru_rel_pos=gru_rel_pos,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 660 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
| 661 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 662 |
+
|
| 663 |
+
self.layer_norm_first = layer_norm_first
|
| 664 |
+
|
| 665 |
+
# layer norm associated with the self attention layer
|
| 666 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
| 667 |
+
|
| 668 |
+
if self.activation_name == "glu":
|
| 669 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
| 670 |
+
else:
|
| 671 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
| 672 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
| 673 |
+
|
| 674 |
+
# layer norm associated with the position wise feed-forward NN
|
| 675 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
| 676 |
+
|
| 677 |
+
def forward(
|
| 678 |
+
self,
|
| 679 |
+
x: torch.Tensor,
|
| 680 |
+
self_attn_mask: torch.Tensor = None,
|
| 681 |
+
self_attn_padding_mask: torch.Tensor = None,
|
| 682 |
+
need_weights: bool = False,
|
| 683 |
+
pos_bias=None
|
| 684 |
+
):
|
| 685 |
+
"""
|
| 686 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
| 687 |
+
modules similar to the original Transformer imlementation.
|
| 688 |
+
"""
|
| 689 |
+
residual = x
|
| 690 |
+
|
| 691 |
+
if self.layer_norm_first:
|
| 692 |
+
x = self.self_attn_layer_norm(x)
|
| 693 |
+
x, attn, pos_bias = self.self_attn(
|
| 694 |
+
query=x,
|
| 695 |
+
key=x,
|
| 696 |
+
value=x,
|
| 697 |
+
key_padding_mask=self_attn_padding_mask,
|
| 698 |
+
need_weights=False,
|
| 699 |
+
attn_mask=self_attn_mask,
|
| 700 |
+
position_bias=pos_bias
|
| 701 |
+
)
|
| 702 |
+
x = self.dropout1(x)
|
| 703 |
+
x = residual + x
|
| 704 |
+
|
| 705 |
+
residual = x
|
| 706 |
+
x = self.final_layer_norm(x)
|
| 707 |
+
if self.activation_name == "glu":
|
| 708 |
+
x = self.fc1(x)
|
| 709 |
+
else:
|
| 710 |
+
x = self.activation_fn(self.fc1(x))
|
| 711 |
+
x = self.dropout2(x)
|
| 712 |
+
x = self.fc2(x)
|
| 713 |
+
x = self.dropout3(x)
|
| 714 |
+
x = residual + x
|
| 715 |
+
else:
|
| 716 |
+
x, attn, pos_bias = self.self_attn(
|
| 717 |
+
query=x,
|
| 718 |
+
key=x,
|
| 719 |
+
value=x,
|
| 720 |
+
key_padding_mask=self_attn_padding_mask,
|
| 721 |
+
need_weights=need_weights,
|
| 722 |
+
attn_mask=self_attn_mask,
|
| 723 |
+
position_bias=pos_bias
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
x = self.dropout1(x)
|
| 727 |
+
x = residual + x
|
| 728 |
+
|
| 729 |
+
x = self.self_attn_layer_norm(x)
|
| 730 |
+
|
| 731 |
+
residual = x
|
| 732 |
+
if self.activation_name == "glu":
|
| 733 |
+
x = self.fc1(x)
|
| 734 |
+
else:
|
| 735 |
+
x = self.activation_fn(self.fc1(x))
|
| 736 |
+
x = self.dropout2(x)
|
| 737 |
+
x = self.fc2(x)
|
| 738 |
+
x = self.dropout3(x)
|
| 739 |
+
x = residual + x
|
| 740 |
+
x = self.final_layer_norm(x)
|
| 741 |
+
|
| 742 |
+
return x, attn, pos_bias
|
| 743 |
+
|
UniSpeech/WavLM/WavLM_ASR.PNG
ADDED
|
|
Git LFS Details
|
UniSpeech/WavLM/WavLM_SUPERB_Leaderboard.png
ADDED
|
Git LFS Details
|
UniSpeech/WavLM/WavLM_SUPERB_Results.png
ADDED
|
Git LFS Details
|
UniSpeech/WavLM/modules.py
ADDED
|
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
| 4 |
+
# Copyright (c) 2021 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import warnings
|
| 12 |
+
from typing import Dict, Optional, Tuple
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
from torch.nn import Parameter
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TransposeLast(nn.Module):
|
| 20 |
+
def __init__(self, deconstruct_idx=None):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.deconstruct_idx = deconstruct_idx
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
if self.deconstruct_idx is not None:
|
| 26 |
+
x = x[self.deconstruct_idx]
|
| 27 |
+
return x.transpose(-2, -1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Fp32LayerNorm(nn.LayerNorm):
|
| 31 |
+
def __init__(self, *args, **kwargs):
|
| 32 |
+
super().__init__(*args, **kwargs)
|
| 33 |
+
|
| 34 |
+
def forward(self, input):
|
| 35 |
+
output = F.layer_norm(
|
| 36 |
+
input.float(),
|
| 37 |
+
self.normalized_shape,
|
| 38 |
+
self.weight.float() if self.weight is not None else None,
|
| 39 |
+
self.bias.float() if self.bias is not None else None,
|
| 40 |
+
self.eps,
|
| 41 |
+
)
|
| 42 |
+
return output.type_as(input)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Fp32GroupNorm(nn.GroupNorm):
|
| 46 |
+
def __init__(self, *args, **kwargs):
|
| 47 |
+
super().__init__(*args, **kwargs)
|
| 48 |
+
|
| 49 |
+
def forward(self, input):
|
| 50 |
+
output = F.group_norm(
|
| 51 |
+
input.float(),
|
| 52 |
+
self.num_groups,
|
| 53 |
+
self.weight.float() if self.weight is not None else None,
|
| 54 |
+
self.bias.float() if self.bias is not None else None,
|
| 55 |
+
self.eps,
|
| 56 |
+
)
|
| 57 |
+
return output.type_as(input)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class GradMultiply(torch.autograd.Function):
|
| 61 |
+
@staticmethod
|
| 62 |
+
def forward(ctx, x, scale):
|
| 63 |
+
ctx.scale = scale
|
| 64 |
+
res = x.new(x)
|
| 65 |
+
return res
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def backward(ctx, grad):
|
| 69 |
+
return grad * ctx.scale, None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class SamePad(nn.Module):
|
| 73 |
+
def __init__(self, kernel_size, causal=False):
|
| 74 |
+
super().__init__()
|
| 75 |
+
if causal:
|
| 76 |
+
self.remove = kernel_size - 1
|
| 77 |
+
else:
|
| 78 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
if self.remove > 0:
|
| 82 |
+
x = x[:, :, : -self.remove]
|
| 83 |
+
return x
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Swish(nn.Module):
|
| 87 |
+
"""Swish function
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self):
|
| 91 |
+
"""Construct an MultiHeadedAttention object."""
|
| 92 |
+
super(Swish, self).__init__()
|
| 93 |
+
self.act = torch.nn.Sigmoid()
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
return x * self.act(x)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class GLU_Linear(nn.Module):
|
| 100 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
| 101 |
+
super(GLU_Linear, self).__init__()
|
| 102 |
+
|
| 103 |
+
self.glu_type = glu_type
|
| 104 |
+
self.output_dim = output_dim
|
| 105 |
+
|
| 106 |
+
if glu_type == "sigmoid":
|
| 107 |
+
self.glu_act = torch.nn.Sigmoid()
|
| 108 |
+
elif glu_type == "swish":
|
| 109 |
+
self.glu_act = Swish()
|
| 110 |
+
elif glu_type == "relu":
|
| 111 |
+
self.glu_act = torch.nn.ReLU()
|
| 112 |
+
elif glu_type == "gelu":
|
| 113 |
+
self.glu_act = torch.nn.GELU()
|
| 114 |
+
|
| 115 |
+
if bias_in_glu:
|
| 116 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
| 117 |
+
else:
|
| 118 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
| 119 |
+
|
| 120 |
+
def forward(self, x):
|
| 121 |
+
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
| 122 |
+
x = self.linear(x)
|
| 123 |
+
|
| 124 |
+
if self.glu_type == "bilinear":
|
| 125 |
+
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
| 126 |
+
else:
|
| 127 |
+
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
| 128 |
+
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def gelu_accurate(x):
|
| 133 |
+
if not hasattr(gelu_accurate, "_a"):
|
| 134 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
| 135 |
+
return (
|
| 136 |
+
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
| 141 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_activation_fn(activation: str):
|
| 145 |
+
"""Returns the activation function corresponding to `activation`"""
|
| 146 |
+
|
| 147 |
+
if activation == "relu":
|
| 148 |
+
return F.relu
|
| 149 |
+
elif activation == "gelu":
|
| 150 |
+
return gelu
|
| 151 |
+
elif activation == "gelu_fast":
|
| 152 |
+
warnings.warn(
|
| 153 |
+
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
| 154 |
+
)
|
| 155 |
+
return gelu_accurate
|
| 156 |
+
elif activation == "gelu_accurate":
|
| 157 |
+
return gelu_accurate
|
| 158 |
+
elif activation == "tanh":
|
| 159 |
+
return torch.tanh
|
| 160 |
+
elif activation == "linear":
|
| 161 |
+
return lambda x: x
|
| 162 |
+
elif activation == "glu":
|
| 163 |
+
return lambda x: x
|
| 164 |
+
else:
|
| 165 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def init_bert_params(module):
|
| 169 |
+
"""
|
| 170 |
+
Initialize the weights specific to the BERT Model.
|
| 171 |
+
This overrides the default initializations depending on the specified arguments.
|
| 172 |
+
1. If normal_init_linear_weights is set then weights of linear
|
| 173 |
+
layer will be initialized using the normal distribution and
|
| 174 |
+
bais will be set to the specified value.
|
| 175 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
| 176 |
+
layer will be initialized using the normal distribution.
|
| 177 |
+
3. If normal_init_proj_weights is set then weights of
|
| 178 |
+
in_project_weight for MultiHeadAttention initialized using
|
| 179 |
+
the normal distribution (to be validated).
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def normal_(data):
|
| 183 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
| 184 |
+
# so that the RNG is consistent with and without FSDP
|
| 185 |
+
data.copy_(
|
| 186 |
+
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
if isinstance(module, nn.Linear):
|
| 190 |
+
normal_(module.weight.data)
|
| 191 |
+
if module.bias is not None:
|
| 192 |
+
module.bias.data.zero_()
|
| 193 |
+
if isinstance(module, nn.Embedding):
|
| 194 |
+
normal_(module.weight.data)
|
| 195 |
+
if module.padding_idx is not None:
|
| 196 |
+
module.weight.data[module.padding_idx].zero_()
|
| 197 |
+
if isinstance(module, MultiheadAttention):
|
| 198 |
+
normal_(module.q_proj.weight.data)
|
| 199 |
+
normal_(module.k_proj.weight.data)
|
| 200 |
+
normal_(module.v_proj.weight.data)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def quant_noise(module, p, block_size):
|
| 204 |
+
"""
|
| 205 |
+
Wraps modules and applies quantization noise to the weights for
|
| 206 |
+
subsequent quantization with Iterative Product Quantization as
|
| 207 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
- module: nn.Module
|
| 211 |
+
- p: amount of Quantization Noise
|
| 212 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
| 213 |
+
|
| 214 |
+
Remarks:
|
| 215 |
+
- Module weights must have the right sizes wrt the block size
|
| 216 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
| 217 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
| 218 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
| 219 |
+
- We implement the simplest form of noise here as stated in the paper
|
| 220 |
+
which consists in randomly dropping blocks
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
# if no quantization noise, don't register hook
|
| 224 |
+
if p <= 0:
|
| 225 |
+
return module
|
| 226 |
+
|
| 227 |
+
# supported modules
|
| 228 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
| 229 |
+
|
| 230 |
+
# test whether module.weight has the right sizes wrt block_size
|
| 231 |
+
is_conv = module.weight.ndim == 4
|
| 232 |
+
|
| 233 |
+
# 2D matrix
|
| 234 |
+
if not is_conv:
|
| 235 |
+
assert (
|
| 236 |
+
module.weight.size(1) % block_size == 0
|
| 237 |
+
), "Input features must be a multiple of block sizes"
|
| 238 |
+
|
| 239 |
+
# 4D matrix
|
| 240 |
+
else:
|
| 241 |
+
# 1x1 convolutions
|
| 242 |
+
if module.kernel_size == (1, 1):
|
| 243 |
+
assert (
|
| 244 |
+
module.in_channels % block_size == 0
|
| 245 |
+
), "Input channels must be a multiple of block sizes"
|
| 246 |
+
# regular convolutions
|
| 247 |
+
else:
|
| 248 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
| 249 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
| 250 |
+
|
| 251 |
+
def _forward_pre_hook(mod, input):
|
| 252 |
+
# no noise for evaluation
|
| 253 |
+
if mod.training:
|
| 254 |
+
if not is_conv:
|
| 255 |
+
# gather weight and sizes
|
| 256 |
+
weight = mod.weight
|
| 257 |
+
in_features = weight.size(1)
|
| 258 |
+
out_features = weight.size(0)
|
| 259 |
+
|
| 260 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 261 |
+
mask = torch.zeros(
|
| 262 |
+
in_features // block_size * out_features, device=weight.device
|
| 263 |
+
)
|
| 264 |
+
mask.bernoulli_(p)
|
| 265 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
| 266 |
+
|
| 267 |
+
else:
|
| 268 |
+
# gather weight and sizes
|
| 269 |
+
weight = mod.weight
|
| 270 |
+
in_channels = mod.in_channels
|
| 271 |
+
out_channels = mod.out_channels
|
| 272 |
+
|
| 273 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 274 |
+
if mod.kernel_size == (1, 1):
|
| 275 |
+
mask = torch.zeros(
|
| 276 |
+
int(in_channels // block_size * out_channels),
|
| 277 |
+
device=weight.device,
|
| 278 |
+
)
|
| 279 |
+
mask.bernoulli_(p)
|
| 280 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
| 281 |
+
else:
|
| 282 |
+
mask = torch.zeros(
|
| 283 |
+
weight.size(0), weight.size(1), device=weight.device
|
| 284 |
+
)
|
| 285 |
+
mask.bernoulli_(p)
|
| 286 |
+
mask = (
|
| 287 |
+
mask.unsqueeze(2)
|
| 288 |
+
.unsqueeze(3)
|
| 289 |
+
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# scale weights and apply mask
|
| 293 |
+
mask = mask.to(
|
| 294 |
+
torch.bool
|
| 295 |
+
) # x.bool() is not currently supported in TorchScript
|
| 296 |
+
s = 1 / (1 - p)
|
| 297 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
| 298 |
+
|
| 299 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
| 300 |
+
return module
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class MultiheadAttention(nn.Module):
|
| 304 |
+
"""Multi-headed attention.
|
| 305 |
+
|
| 306 |
+
See "Attention Is All You Need" for more details.
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
def __init__(
|
| 310 |
+
self,
|
| 311 |
+
embed_dim,
|
| 312 |
+
num_heads,
|
| 313 |
+
kdim=None,
|
| 314 |
+
vdim=None,
|
| 315 |
+
dropout=0.0,
|
| 316 |
+
bias=True,
|
| 317 |
+
add_bias_kv=False,
|
| 318 |
+
add_zero_attn=False,
|
| 319 |
+
self_attention=False,
|
| 320 |
+
encoder_decoder_attention=False,
|
| 321 |
+
q_noise=0.0,
|
| 322 |
+
qn_block_size=8,
|
| 323 |
+
has_relative_attention_bias=False,
|
| 324 |
+
num_buckets=32,
|
| 325 |
+
max_distance=128,
|
| 326 |
+
gru_rel_pos=False,
|
| 327 |
+
rescale_init=False,
|
| 328 |
+
):
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.embed_dim = embed_dim
|
| 331 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 332 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 333 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 334 |
+
|
| 335 |
+
self.num_heads = num_heads
|
| 336 |
+
self.dropout_module = nn.Dropout(dropout)
|
| 337 |
+
|
| 338 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
| 339 |
+
self.num_buckets = num_buckets
|
| 340 |
+
self.max_distance = max_distance
|
| 341 |
+
if self.has_relative_attention_bias:
|
| 342 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
| 343 |
+
|
| 344 |
+
self.head_dim = embed_dim // num_heads
|
| 345 |
+
self.q_head_dim = self.head_dim
|
| 346 |
+
self.k_head_dim = self.head_dim
|
| 347 |
+
assert (
|
| 348 |
+
self.head_dim * num_heads == self.embed_dim
|
| 349 |
+
), "embed_dim must be divisible by num_heads"
|
| 350 |
+
self.scaling = self.head_dim ** -0.5
|
| 351 |
+
|
| 352 |
+
self.self_attention = self_attention
|
| 353 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 354 |
+
|
| 355 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
| 356 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
k_bias = True
|
| 360 |
+
if rescale_init:
|
| 361 |
+
k_bias = False
|
| 362 |
+
|
| 363 |
+
k_embed_dim = embed_dim
|
| 364 |
+
q_embed_dim = embed_dim
|
| 365 |
+
|
| 366 |
+
self.k_proj = quant_noise(
|
| 367 |
+
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
| 368 |
+
)
|
| 369 |
+
self.v_proj = quant_noise(
|
| 370 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 371 |
+
)
|
| 372 |
+
self.q_proj = quant_noise(
|
| 373 |
+
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
self.out_proj = quant_noise(
|
| 377 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
if add_bias_kv:
|
| 381 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 382 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 383 |
+
else:
|
| 384 |
+
self.bias_k = self.bias_v = None
|
| 385 |
+
|
| 386 |
+
self.add_zero_attn = add_zero_attn
|
| 387 |
+
|
| 388 |
+
self.gru_rel_pos = gru_rel_pos
|
| 389 |
+
if self.gru_rel_pos:
|
| 390 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
| 391 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
| 392 |
+
|
| 393 |
+
self.reset_parameters()
|
| 394 |
+
|
| 395 |
+
def reset_parameters(self):
|
| 396 |
+
if self.qkv_same_dim:
|
| 397 |
+
# Empirically observed the convergence to be much better with
|
| 398 |
+
# the scaled initialization
|
| 399 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
| 400 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
| 401 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
| 402 |
+
else:
|
| 403 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
| 404 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
| 405 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 406 |
+
|
| 407 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 408 |
+
if self.out_proj.bias is not None:
|
| 409 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
| 410 |
+
if self.bias_k is not None:
|
| 411 |
+
nn.init.xavier_normal_(self.bias_k)
|
| 412 |
+
if self.bias_v is not None:
|
| 413 |
+
nn.init.xavier_normal_(self.bias_v)
|
| 414 |
+
if self.has_relative_attention_bias:
|
| 415 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
| 416 |
+
|
| 417 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
| 418 |
+
num_buckets = self.num_buckets
|
| 419 |
+
max_distance = self.max_distance
|
| 420 |
+
relative_buckets = 0
|
| 421 |
+
|
| 422 |
+
if bidirectional:
|
| 423 |
+
num_buckets = num_buckets // 2
|
| 424 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
| 425 |
+
relative_positions = torch.abs(relative_positions)
|
| 426 |
+
else:
|
| 427 |
+
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
| 428 |
+
|
| 429 |
+
max_exact = num_buckets // 2
|
| 430 |
+
is_small = relative_positions < max_exact
|
| 431 |
+
|
| 432 |
+
relative_postion_if_large = max_exact + (
|
| 433 |
+
torch.log(relative_positions.float() / max_exact)
|
| 434 |
+
/ math.log(max_distance / max_exact)
|
| 435 |
+
* (num_buckets - max_exact)
|
| 436 |
+
).to(torch.long)
|
| 437 |
+
relative_postion_if_large = torch.min(
|
| 438 |
+
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
| 442 |
+
return relative_buckets
|
| 443 |
+
|
| 444 |
+
def compute_bias(self, query_length, key_length):
|
| 445 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
| 446 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
| 447 |
+
relative_position = memory_position - context_position
|
| 448 |
+
relative_position_bucket = self._relative_positions_bucket(
|
| 449 |
+
relative_position,
|
| 450 |
+
bidirectional=True
|
| 451 |
+
)
|
| 452 |
+
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
| 453 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
| 454 |
+
values = values.permute([2, 0, 1])
|
| 455 |
+
return values
|
| 456 |
+
|
| 457 |
+
def forward(
|
| 458 |
+
self,
|
| 459 |
+
query,
|
| 460 |
+
key: Optional[Tensor],
|
| 461 |
+
value: Optional[Tensor],
|
| 462 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 463 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 464 |
+
need_weights: bool = True,
|
| 465 |
+
static_kv: bool = False,
|
| 466 |
+
attn_mask: Optional[Tensor] = None,
|
| 467 |
+
before_softmax: bool = False,
|
| 468 |
+
need_head_weights: bool = False,
|
| 469 |
+
position_bias: Optional[Tensor] = None
|
| 470 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
| 471 |
+
"""Input shape: Time x Batch x Channel
|
| 472 |
+
|
| 473 |
+
Args:
|
| 474 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
| 475 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
| 476 |
+
padding elements are indicated by 1s.
|
| 477 |
+
need_weights (bool, optional): return the attention weights,
|
| 478 |
+
averaged over heads (default: False).
|
| 479 |
+
attn_mask (ByteTensor, optional): typically used to
|
| 480 |
+
implement causal attention, where the mask prevents the
|
| 481 |
+
attention from looking forward in time (default: None).
|
| 482 |
+
before_softmax (bool, optional): return the raw attention
|
| 483 |
+
weights and values before the attention softmax.
|
| 484 |
+
need_head_weights (bool, optional): return the attention
|
| 485 |
+
weights for each head. Implies *need_weights*. Default:
|
| 486 |
+
return the average attention weights over all heads.
|
| 487 |
+
"""
|
| 488 |
+
if need_head_weights:
|
| 489 |
+
need_weights = True
|
| 490 |
+
|
| 491 |
+
is_tpu = query.device.type == "xla"
|
| 492 |
+
|
| 493 |
+
tgt_len, bsz, embed_dim = query.size()
|
| 494 |
+
src_len = tgt_len
|
| 495 |
+
assert embed_dim == self.embed_dim
|
| 496 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 497 |
+
if key is not None:
|
| 498 |
+
src_len, key_bsz, _ = key.size()
|
| 499 |
+
if not torch.jit.is_scripting():
|
| 500 |
+
assert key_bsz == bsz
|
| 501 |
+
assert value is not None
|
| 502 |
+
assert src_len, bsz == value.shape[:2]
|
| 503 |
+
|
| 504 |
+
if self.has_relative_attention_bias and position_bias is None:
|
| 505 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
| 506 |
+
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
| 507 |
+
|
| 508 |
+
if (
|
| 509 |
+
not is_tpu # don't use PyTorch version on TPUs
|
| 510 |
+
and incremental_state is None
|
| 511 |
+
and not static_kv
|
| 512 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
| 513 |
+
# treats bias in linear module as method.
|
| 514 |
+
and not torch.jit.is_scripting()
|
| 515 |
+
and self.q_head_dim == self.head_dim
|
| 516 |
+
):
|
| 517 |
+
assert key is not None and value is not None
|
| 518 |
+
assert attn_mask is None
|
| 519 |
+
|
| 520 |
+
attn_mask_rel_pos = None
|
| 521 |
+
if position_bias is not None:
|
| 522 |
+
attn_mask_rel_pos = position_bias
|
| 523 |
+
if self.gru_rel_pos:
|
| 524 |
+
query_layer = query.transpose(0, 1)
|
| 525 |
+
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
| 526 |
+
query_layer = query_layer.view(*new_x_shape)
|
| 527 |
+
query_layer = query_layer.permute(0, 2, 1, 3)
|
| 528 |
+
_B, _H, _L, __ = query_layer.size()
|
| 529 |
+
|
| 530 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
| 531 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
| 532 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
| 533 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
| 534 |
+
|
| 535 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
| 536 |
+
k_proj_bias = self.k_proj.bias
|
| 537 |
+
if k_proj_bias is None:
|
| 538 |
+
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
| 539 |
+
|
| 540 |
+
x, attn = F.multi_head_attention_forward(
|
| 541 |
+
query,
|
| 542 |
+
key,
|
| 543 |
+
value,
|
| 544 |
+
self.embed_dim,
|
| 545 |
+
self.num_heads,
|
| 546 |
+
torch.empty([0]),
|
| 547 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
| 548 |
+
self.bias_k,
|
| 549 |
+
self.bias_v,
|
| 550 |
+
self.add_zero_attn,
|
| 551 |
+
self.dropout_module.p,
|
| 552 |
+
self.out_proj.weight,
|
| 553 |
+
self.out_proj.bias,
|
| 554 |
+
self.training,
|
| 555 |
+
# self.training or self.dropout_module.apply_during_inference,
|
| 556 |
+
key_padding_mask,
|
| 557 |
+
need_weights,
|
| 558 |
+
attn_mask_rel_pos,
|
| 559 |
+
use_separate_proj_weight=True,
|
| 560 |
+
q_proj_weight=self.q_proj.weight,
|
| 561 |
+
k_proj_weight=self.k_proj.weight,
|
| 562 |
+
v_proj_weight=self.v_proj.weight,
|
| 563 |
+
)
|
| 564 |
+
return x, attn, position_bias
|
| 565 |
+
|
| 566 |
+
if incremental_state is not None:
|
| 567 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 568 |
+
if saved_state is not None and "prev_key" in saved_state:
|
| 569 |
+
# previous time steps are cached - no need to recompute
|
| 570 |
+
# key and value if they are static
|
| 571 |
+
if static_kv:
|
| 572 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
| 573 |
+
key = value = None
|
| 574 |
+
else:
|
| 575 |
+
saved_state = None
|
| 576 |
+
|
| 577 |
+
if self.self_attention:
|
| 578 |
+
q = self.q_proj(query)
|
| 579 |
+
k = self.k_proj(query)
|
| 580 |
+
v = self.v_proj(query)
|
| 581 |
+
elif self.encoder_decoder_attention:
|
| 582 |
+
# encoder-decoder attention
|
| 583 |
+
q = self.q_proj(query)
|
| 584 |
+
if key is None:
|
| 585 |
+
assert value is None
|
| 586 |
+
k = v = None
|
| 587 |
+
else:
|
| 588 |
+
k = self.k_proj(key)
|
| 589 |
+
v = self.v_proj(key)
|
| 590 |
+
|
| 591 |
+
else:
|
| 592 |
+
assert key is not None and value is not None
|
| 593 |
+
q = self.q_proj(query)
|
| 594 |
+
k = self.k_proj(key)
|
| 595 |
+
v = self.v_proj(value)
|
| 596 |
+
q *= self.scaling
|
| 597 |
+
|
| 598 |
+
if self.bias_k is not None:
|
| 599 |
+
assert self.bias_v is not None
|
| 600 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
| 601 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
| 602 |
+
if attn_mask is not None:
|
| 603 |
+
attn_mask = torch.cat(
|
| 604 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 605 |
+
)
|
| 606 |
+
if key_padding_mask is not None:
|
| 607 |
+
key_padding_mask = torch.cat(
|
| 608 |
+
[
|
| 609 |
+
key_padding_mask,
|
| 610 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
| 611 |
+
],
|
| 612 |
+
dim=1,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
q = (
|
| 616 |
+
q.contiguous()
|
| 617 |
+
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
| 618 |
+
.transpose(0, 1)
|
| 619 |
+
)
|
| 620 |
+
if k is not None:
|
| 621 |
+
k = (
|
| 622 |
+
k.contiguous()
|
| 623 |
+
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
| 624 |
+
.transpose(0, 1)
|
| 625 |
+
)
|
| 626 |
+
if v is not None:
|
| 627 |
+
v = (
|
| 628 |
+
v.contiguous()
|
| 629 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
| 630 |
+
.transpose(0, 1)
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
if saved_state is not None:
|
| 634 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
| 635 |
+
if "prev_key" in saved_state:
|
| 636 |
+
_prev_key = saved_state["prev_key"]
|
| 637 |
+
assert _prev_key is not None
|
| 638 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
| 639 |
+
if static_kv:
|
| 640 |
+
k = prev_key
|
| 641 |
+
else:
|
| 642 |
+
assert k is not None
|
| 643 |
+
k = torch.cat([prev_key, k], dim=1)
|
| 644 |
+
src_len = k.size(1)
|
| 645 |
+
if "prev_value" in saved_state:
|
| 646 |
+
_prev_value = saved_state["prev_value"]
|
| 647 |
+
assert _prev_value is not None
|
| 648 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
| 649 |
+
if static_kv:
|
| 650 |
+
v = prev_value
|
| 651 |
+
else:
|
| 652 |
+
assert v is not None
|
| 653 |
+
v = torch.cat([prev_value, v], dim=1)
|
| 654 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
| 655 |
+
if "prev_key_padding_mask" in saved_state:
|
| 656 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
| 657 |
+
assert k is not None and v is not None
|
| 658 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
| 659 |
+
key_padding_mask=key_padding_mask,
|
| 660 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
| 661 |
+
batch_size=bsz,
|
| 662 |
+
src_len=k.size(1),
|
| 663 |
+
static_kv=static_kv,
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
| 667 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
| 668 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
| 669 |
+
# In this branch incremental_state is never None
|
| 670 |
+
assert incremental_state is not None
|
| 671 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
| 672 |
+
assert k is not None
|
| 673 |
+
assert k.size(1) == src_len
|
| 674 |
+
|
| 675 |
+
# This is part of a workaround to get around fork/join parallelism
|
| 676 |
+
# not supporting Optional types.
|
| 677 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
| 678 |
+
key_padding_mask = None
|
| 679 |
+
|
| 680 |
+
if key_padding_mask is not None:
|
| 681 |
+
assert key_padding_mask.size(0) == bsz
|
| 682 |
+
assert key_padding_mask.size(1) == src_len
|
| 683 |
+
|
| 684 |
+
if self.add_zero_attn:
|
| 685 |
+
assert v is not None
|
| 686 |
+
src_len += 1
|
| 687 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
| 688 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
| 689 |
+
if attn_mask is not None:
|
| 690 |
+
attn_mask = torch.cat(
|
| 691 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 692 |
+
)
|
| 693 |
+
if key_padding_mask is not None:
|
| 694 |
+
key_padding_mask = torch.cat(
|
| 695 |
+
[
|
| 696 |
+
key_padding_mask,
|
| 697 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
| 698 |
+
key_padding_mask
|
| 699 |
+
),
|
| 700 |
+
],
|
| 701 |
+
dim=1,
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 705 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 706 |
+
|
| 707 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 708 |
+
|
| 709 |
+
if attn_mask is not None:
|
| 710 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 711 |
+
attn_weights += attn_mask
|
| 712 |
+
|
| 713 |
+
if key_padding_mask is not None:
|
| 714 |
+
# don't attend to padding symbols
|
| 715 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 716 |
+
if not is_tpu:
|
| 717 |
+
attn_weights = attn_weights.masked_fill(
|
| 718 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
| 719 |
+
float("-inf"),
|
| 720 |
+
)
|
| 721 |
+
else:
|
| 722 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 723 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
| 724 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 725 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 726 |
+
|
| 727 |
+
if before_softmax:
|
| 728 |
+
return attn_weights, v, position_bias
|
| 729 |
+
|
| 730 |
+
if position_bias is not None:
|
| 731 |
+
if self.gru_rel_pos == 1:
|
| 732 |
+
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
| 733 |
+
_B, _H, _L, __ = query_layer.size()
|
| 734 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
| 735 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
| 736 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
| 737 |
+
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
| 738 |
+
|
| 739 |
+
position_bias = position_bias.view(attn_weights.size())
|
| 740 |
+
|
| 741 |
+
attn_weights = attn_weights + position_bias
|
| 742 |
+
|
| 743 |
+
attn_weights_float = F.softmax(
|
| 744 |
+
attn_weights, dim=-1
|
| 745 |
+
)
|
| 746 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 747 |
+
attn_probs = self.dropout_module(attn_weights)
|
| 748 |
+
|
| 749 |
+
assert v is not None
|
| 750 |
+
attn = torch.bmm(attn_probs, v)
|
| 751 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 752 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 753 |
+
attn = self.out_proj(attn)
|
| 754 |
+
attn_weights: Optional[Tensor] = None
|
| 755 |
+
if need_weights:
|
| 756 |
+
attn_weights = attn_weights_float.view(
|
| 757 |
+
bsz, self.num_heads, tgt_len, src_len
|
| 758 |
+
).transpose(1, 0)
|
| 759 |
+
if not need_head_weights:
|
| 760 |
+
# average attention weights over heads
|
| 761 |
+
attn_weights = attn_weights.mean(dim=0)
|
| 762 |
+
|
| 763 |
+
return attn, attn_weights, position_bias
|
| 764 |
+
|
| 765 |
+
@staticmethod
|
| 766 |
+
def _append_prev_key_padding_mask(
|
| 767 |
+
key_padding_mask: Optional[Tensor],
|
| 768 |
+
prev_key_padding_mask: Optional[Tensor],
|
| 769 |
+
batch_size: int,
|
| 770 |
+
src_len: int,
|
| 771 |
+
static_kv: bool,
|
| 772 |
+
) -> Optional[Tensor]:
|
| 773 |
+
# saved key padding masks have shape (bsz, seq_len)
|
| 774 |
+
if prev_key_padding_mask is not None and static_kv:
|
| 775 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 776 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
| 777 |
+
new_key_padding_mask = torch.cat(
|
| 778 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
| 779 |
+
)
|
| 780 |
+
# During incremental decoding, as the padding token enters and
|
| 781 |
+
# leaves the frame, there will be a time when prev or current
|
| 782 |
+
# is None
|
| 783 |
+
elif prev_key_padding_mask is not None:
|
| 784 |
+
if src_len > prev_key_padding_mask.size(1):
|
| 785 |
+
filler = torch.zeros(
|
| 786 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
| 787 |
+
device=prev_key_padding_mask.device,
|
| 788 |
+
)
|
| 789 |
+
new_key_padding_mask = torch.cat(
|
| 790 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
| 791 |
+
)
|
| 792 |
+
else:
|
| 793 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
| 794 |
+
elif key_padding_mask is not None:
|
| 795 |
+
if src_len > key_padding_mask.size(1):
|
| 796 |
+
filler = torch.zeros(
|
| 797 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
| 798 |
+
device=key_padding_mask.device,
|
| 799 |
+
)
|
| 800 |
+
new_key_padding_mask = torch.cat(
|
| 801 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
| 802 |
+
)
|
| 803 |
+
else:
|
| 804 |
+
new_key_padding_mask = key_padding_mask.float()
|
| 805 |
+
else:
|
| 806 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 807 |
+
return new_key_padding_mask
|
| 808 |
+
|
| 809 |
+
def _get_input_buffer(
|
| 810 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
| 811 |
+
) -> Dict[str, Optional[Tensor]]:
|
| 812 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
| 813 |
+
if result is not None:
|
| 814 |
+
return result
|
| 815 |
+
else:
|
| 816 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
| 817 |
+
return empty_result
|
| 818 |
+
|
| 819 |
+
def _set_input_buffer(
|
| 820 |
+
self,
|
| 821 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
| 822 |
+
buffer: Dict[str, Optional[Tensor]],
|
| 823 |
+
):
|
| 824 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 825 |
+
|
| 826 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
| 827 |
+
return attn_weights
|
UniSpeech/azure-pipelines.yml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
trigger:
|
| 2 |
+
- master
|
| 3 |
+
|
| 4 |
+
pool:
|
| 5 |
+
vmImage: 'windows-latest'
|
| 6 |
+
|
| 7 |
+
steps:
|
| 8 |
+
- script: echo Hello, world!
|
| 9 |
+
displayName: 'Run a one-line script'
|
| 10 |
+
|
| 11 |
+
- script: |
|
| 12 |
+
echo Add other tasks to build, test, and deploy your project.
|
| 13 |
+
echo See https://aka.ms/yaml
|
| 14 |
+
displayName: 'Run a multi-line script'
|
| 15 |
+
|
| 16 |
+
- task: CredScan@2
|
| 17 |
+
inputs:
|
| 18 |
+
toolMajorVersion: 'V2'
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
- task: Semmle@0
|
| 22 |
+
env:
|
| 23 |
+
SYSTEM_ACCESSTOKEN: $(PATToken)
|
| 24 |
+
inputs:
|
| 25 |
+
sourceCodeDirectory: '$(Build.SourcesDirectory)'
|
| 26 |
+
language: 'python'
|
| 27 |
+
includeNodeModules: true
|
| 28 |
+
querySuite: 'Recommended'
|
| 29 |
+
timeout: '1800'
|
| 30 |
+
ram: '16384'
|
| 31 |
+
addProjectDirToScanningExclusionList: true
|
| 32 |
+
|
| 33 |
+
- task: ComponentGovernanceComponentDetection@0
|
| 34 |
+
inputs:
|
| 35 |
+
scanType: 'Register'
|
| 36 |
+
verbosity: 'Verbose'
|
| 37 |
+
alertWarningLevel: 'High'
|
| 38 |
+
|
| 39 |
+
- task: PublishSecurityAnalysisLogs@2
|
| 40 |
+
inputs:
|
| 41 |
+
ArtifactName: 'CodeAnalysisLogs'
|
| 42 |
+
ArtifactType: 'Container'
|
| 43 |
+
AllTools: true
|
| 44 |
+
ToolLogsNotFoundAction: 'Standard'
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
UniSpeech/downstreams/speaker_diarization/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Pre-training Representations for Speaker Diarization
|
| 2 |
+
|
| 3 |
+
### Downstream Model
|
| 4 |
+
|
| 5 |
+
[EEND-vector-clustering](https://arxiv.org/abs/2105.09040)
|
| 6 |
+
|
| 7 |
+
### Pre-trained models
|
| 8 |
+
|
| 9 |
+
- It should be noted that the diarization system is trained on 8k audio data.
|
| 10 |
+
|
| 11 |
+
| Model | 2 spk DER | 3 spk DER | 4 spk DER | 5 spk DER | 6 spk DER | ALL spk DER |
|
| 12 |
+
| ------------------------------------------------------------ | --------- | --------- | --------- | --------- | --------- | ----------- |
|
| 13 |
+
| EEND-vector-clustering | 7.96 | 11.93 | 16.38 | 21.21 | 23.1 | 12.49 |
|
| 14 |
+
| [**UniSpeech-SAT large**](https://drive.google.com/file/d/16OwIyOk2uYm0aWtSPaS0S12xE8RxF7k_/view?usp=sharing) | 5.93 | 10.66 | 12.90 | 16.48 | 23.25 | 10.92 |
|
| 15 |
+
|
| 16 |
+
### How to use?
|
| 17 |
+
|
| 18 |
+
#### Environment Setup
|
| 19 |
+
|
| 20 |
+
1. `pip install --require-hashes -r requirements.txt`
|
| 21 |
+
2. Install fairseq code
|
| 22 |
+
- For UniSpeech-SAT large, we should install the [Unispeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code.
|
| 23 |
+
|
| 24 |
+
#### Example
|
| 25 |
+
|
| 26 |
+
1. First, you should download the pre-trained model in the above table to `checkpoint_path`.
|
| 27 |
+
2. Then, run the following codes:
|
| 28 |
+
- The wav file is the multi-talker simulated speech from Librispeech corpus.
|
| 29 |
+
3. The output will be written in `out.rttm` by default.
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
python diarization.py --wav_path tmp/mix_0000496.wav --model_init $checkpoint_path
|
| 33 |
+
```
|
| 34 |
+
|
UniSpeech/downstreams/speaker_diarization/config/infer_est_nspk1.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
|
| 2 |
+
# All rights reserved
|
| 3 |
+
|
| 4 |
+
# inference options
|
| 5 |
+
est_nspk: 1
|
| 6 |
+
sil_spk_th: 0.05
|
| 7 |
+
ahc_dis_th: 1.0
|
| 8 |
+
clink_dis: 1.0e+4
|
| 9 |
+
model:
|
| 10 |
+
n_speakers: 3
|
| 11 |
+
all_n_speakers: 0
|
| 12 |
+
feat_dim: 1024
|
| 13 |
+
n_units: 256
|
| 14 |
+
n_heads: 8
|
| 15 |
+
n_layers: 6
|
| 16 |
+
dropout_rate: 0.1
|
| 17 |
+
spk_emb_dim: 256
|
| 18 |
+
sr: 8000
|
| 19 |
+
frame_shift: 320
|
| 20 |
+
frame_size: 200
|
| 21 |
+
context_size: 0
|
| 22 |
+
subsampling: 1
|
| 23 |
+
feat_type: "config/unispeech_sat.th"
|
| 24 |
+
feature_selection: "hidden_states"
|
| 25 |
+
interpolate_mode: "linear"
|
| 26 |
+
dataset:
|
| 27 |
+
chunk_size: 750
|
| 28 |
+
frame_shift: 320
|
| 29 |
+
sampling_rate: 8000
|
| 30 |
+
subsampling: 1
|
| 31 |
+
num_speakers: 3
|
UniSpeech/downstreams/speaker_diarization/config/unispeech_sat.th
ADDED
|
Binary file (20.7 kB). View file
|
|
|
UniSpeech/downstreams/speaker_diarization/diarization.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import h5py
|
| 3 |
+
import soundfile as sf
|
| 4 |
+
import fire
|
| 5 |
+
import math
|
| 6 |
+
import yamlargparse
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
import torch
|
| 10 |
+
from utils.utils import parse_config_or_kwargs
|
| 11 |
+
from utils.dataset import DiarizationDataset
|
| 12 |
+
from models.models import TransformerDiarization
|
| 13 |
+
from scipy.signal import medfilt
|
| 14 |
+
from sklearn.cluster import AgglomerativeClustering
|
| 15 |
+
from scipy.spatial import distance
|
| 16 |
+
from utils.kaldi_data import KaldiData
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_cl_sil(args, acti, cls_num):
|
| 20 |
+
n_chunks = len(acti)
|
| 21 |
+
mean_acti = np.array([np.mean(acti[i], axis=0)
|
| 22 |
+
for i in range(n_chunks)]).flatten()
|
| 23 |
+
n = args.num_speakers
|
| 24 |
+
sil_spk_th = args.sil_spk_th
|
| 25 |
+
|
| 26 |
+
cl_lst = []
|
| 27 |
+
sil_lst = []
|
| 28 |
+
for chunk_idx in range(n_chunks):
|
| 29 |
+
if cls_num is not None:
|
| 30 |
+
if args.num_speakers > cls_num:
|
| 31 |
+
mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx]
|
| 32 |
+
for s_loc_idx in range(n)])
|
| 33 |
+
min_idx = np.argmin(mean_acti_bi)
|
| 34 |
+
mean_acti[n * chunk_idx + min_idx] = 0.0
|
| 35 |
+
|
| 36 |
+
for s_loc_idx in range(n):
|
| 37 |
+
a = n * chunk_idx + (s_loc_idx + 0) % n
|
| 38 |
+
b = n * chunk_idx + (s_loc_idx + 1) % n
|
| 39 |
+
if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th:
|
| 40 |
+
cl_lst.append((a, b))
|
| 41 |
+
else:
|
| 42 |
+
if mean_acti[a] <= sil_spk_th:
|
| 43 |
+
sil_lst.append(a)
|
| 44 |
+
|
| 45 |
+
return cl_lst, sil_lst
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst):
|
| 49 |
+
org_svec_len = len(svec)
|
| 50 |
+
svec = np.delete(svec, sil_lst, 0)
|
| 51 |
+
|
| 52 |
+
# update cl_lst idx
|
| 53 |
+
_tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)]
|
| 54 |
+
cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst]
|
| 55 |
+
|
| 56 |
+
distMat = distance.cdist(svec, svec, metric='euclidean')
|
| 57 |
+
for cl in cl_lst:
|
| 58 |
+
distMat[cl[0], cl[1]] = args.clink_dis
|
| 59 |
+
distMat[cl[1], cl[0]] = args.clink_dis
|
| 60 |
+
|
| 61 |
+
clusterer = AgglomerativeClustering(
|
| 62 |
+
n_clusters=cls_num,
|
| 63 |
+
affinity='precomputed',
|
| 64 |
+
linkage='average',
|
| 65 |
+
distance_threshold=ahc_dis_th)
|
| 66 |
+
clusterer.fit(distMat)
|
| 67 |
+
|
| 68 |
+
if cls_num is not None:
|
| 69 |
+
print("oracle n_clusters is known")
|
| 70 |
+
else:
|
| 71 |
+
print("oracle n_clusters is unknown")
|
| 72 |
+
print("estimated n_clusters by constraind AHC: {}"
|
| 73 |
+
.format(len(np.unique(clusterer.labels_))))
|
| 74 |
+
cls_num = len(np.unique(clusterer.labels_))
|
| 75 |
+
|
| 76 |
+
sil_lab = cls_num
|
| 77 |
+
insert_sil_lab = [sil_lab for i in range(len(sil_lst))]
|
| 78 |
+
insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))]
|
| 79 |
+
print("insert_sil_lab : {}".format(insert_sil_lab))
|
| 80 |
+
print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx))
|
| 81 |
+
clslab = np.insert(clusterer.labels_,
|
| 82 |
+
insert_sil_lab_idx,
|
| 83 |
+
insert_sil_lab).reshape(-1, args.num_speakers)
|
| 84 |
+
print("clslab : {}".format(clslab))
|
| 85 |
+
|
| 86 |
+
return clslab, cls_num
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def merge_act_max(act, i, j):
|
| 90 |
+
for k in range(len(act)):
|
| 91 |
+
act[k, i] = max(act[k, i], act[k, j])
|
| 92 |
+
act[k, j] = 0.0
|
| 93 |
+
return act
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def merge_acti_clslab(args, acti, clslab, cls_num):
|
| 97 |
+
sil_lab = cls_num
|
| 98 |
+
for i in range(len(clslab)):
|
| 99 |
+
_lab = clslab[i].reshape(-1, 1)
|
| 100 |
+
distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64)
|
| 101 |
+
for j in range(len(distM)):
|
| 102 |
+
distM[j][:j] = -1
|
| 103 |
+
idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1)
|
| 104 |
+
merge_done = []
|
| 105 |
+
for j in idx_lst[0]:
|
| 106 |
+
for k in (np.where(distM[j] == 0))[0]:
|
| 107 |
+
if j != k and clslab[i, j] != sil_lab and k not in merge_done:
|
| 108 |
+
print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k))
|
| 109 |
+
acti[i] = merge_act_max(acti[i], j, k)
|
| 110 |
+
clslab[i, k] = sil_lab
|
| 111 |
+
merge_done.append(j)
|
| 112 |
+
|
| 113 |
+
return acti, clslab
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def stitching(args, acti, clslab, cls_num):
|
| 117 |
+
n_chunks = len(acti)
|
| 118 |
+
s_loc = args.num_speakers
|
| 119 |
+
sil_lab = cls_num
|
| 120 |
+
s_tot = max(cls_num, s_loc-1)
|
| 121 |
+
|
| 122 |
+
# Extend the max value of s_loc_idx to s_tot+1
|
| 123 |
+
add_acti = []
|
| 124 |
+
for chunk_idx in range(n_chunks):
|
| 125 |
+
zeros = np.zeros((len(acti[chunk_idx]), s_tot+1))
|
| 126 |
+
if s_tot+1 > s_loc:
|
| 127 |
+
zeros[:, :-(s_tot+1-s_loc)] = acti[chunk_idx]
|
| 128 |
+
else:
|
| 129 |
+
zeros = acti[chunk_idx]
|
| 130 |
+
add_acti.append(zeros)
|
| 131 |
+
acti = np.array(add_acti)
|
| 132 |
+
|
| 133 |
+
out_chunks = []
|
| 134 |
+
for chunk_idx in range(n_chunks):
|
| 135 |
+
# Make sloci2lab_dct.
|
| 136 |
+
# key: s_loc_idx
|
| 137 |
+
# value: estimated label by clustering or sil_lab
|
| 138 |
+
cls_set = set()
|
| 139 |
+
for s_loc_idx in range(s_tot+1):
|
| 140 |
+
cls_set.add(s_loc_idx)
|
| 141 |
+
|
| 142 |
+
sloci2lab_dct = {}
|
| 143 |
+
for s_loc_idx in range(s_tot+1):
|
| 144 |
+
if s_loc_idx < s_loc:
|
| 145 |
+
sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx]
|
| 146 |
+
if clslab[chunk_idx][s_loc_idx] in cls_set:
|
| 147 |
+
cls_set.remove(clslab[chunk_idx][s_loc_idx])
|
| 148 |
+
else:
|
| 149 |
+
if clslab[chunk_idx][s_loc_idx] != sil_lab:
|
| 150 |
+
raise ValueError
|
| 151 |
+
else:
|
| 152 |
+
sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx-s_loc]
|
| 153 |
+
|
| 154 |
+
# Sort by label value
|
| 155 |
+
sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1])
|
| 156 |
+
|
| 157 |
+
# Select sil_lab_idx
|
| 158 |
+
sil_lab_idx = None
|
| 159 |
+
for idx_lab in sloci2lab_lst:
|
| 160 |
+
if idx_lab[1] == sil_lab:
|
| 161 |
+
sil_lab_idx = idx_lab[0]
|
| 162 |
+
break
|
| 163 |
+
if sil_lab_idx is None:
|
| 164 |
+
raise ValueError
|
| 165 |
+
|
| 166 |
+
# Get swap_idx
|
| 167 |
+
# [idx of label(0), idx of label(1), ..., idx of label(s_tot)]
|
| 168 |
+
swap_idx = [sil_lab_idx for j in range(s_tot+1)]
|
| 169 |
+
for lab in range(s_tot+1):
|
| 170 |
+
for idx_lab in sloci2lab_lst:
|
| 171 |
+
if lab == idx_lab[1]:
|
| 172 |
+
swap_idx[lab] = idx_lab[0]
|
| 173 |
+
|
| 174 |
+
print("swap_idx {}".format(swap_idx))
|
| 175 |
+
swap_acti = acti[chunk_idx][:, swap_idx]
|
| 176 |
+
swap_acti = np.delete(swap_acti, sil_lab, 1)
|
| 177 |
+
out_chunks.append(swap_acti)
|
| 178 |
+
|
| 179 |
+
return out_chunks
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def prediction(num_speakers, net, wav_list, chunk_len_list):
|
| 183 |
+
acti_lst = []
|
| 184 |
+
svec_lst = []
|
| 185 |
+
len_list = []
|
| 186 |
+
|
| 187 |
+
with torch.no_grad():
|
| 188 |
+
for wav, chunk_len in zip(wav_list, chunk_len_list):
|
| 189 |
+
wav = wav.to('cuda')
|
| 190 |
+
outputs = net.batch_estimate(torch.unsqueeze(wav, 0))
|
| 191 |
+
ys = outputs[0]
|
| 192 |
+
|
| 193 |
+
for i in range(num_speakers):
|
| 194 |
+
spkivecs = outputs[i+1]
|
| 195 |
+
svec_lst.append(spkivecs[0].cpu().detach().numpy())
|
| 196 |
+
|
| 197 |
+
acti = ys[0][-chunk_len:].cpu().detach().numpy()
|
| 198 |
+
acti_lst.append(acti)
|
| 199 |
+
len_list.append(chunk_len)
|
| 200 |
+
|
| 201 |
+
acti_arr = np.concatenate(acti_lst, axis=0) # totol_len x num_speakers
|
| 202 |
+
svec_arr = np.stack(svec_lst) # (chunk_num x num_speakers) x emb_dim
|
| 203 |
+
len_arr = np.array(len_list) # chunk_num
|
| 204 |
+
|
| 205 |
+
return acti_arr, svec_arr, len_arr
|
| 206 |
+
|
| 207 |
+
def cluster(args, conf, acti_arr, svec_arr, len_arr):
|
| 208 |
+
|
| 209 |
+
acti_list = []
|
| 210 |
+
n_chunks = len_arr.shape[0]
|
| 211 |
+
start = 0
|
| 212 |
+
for i in range(n_chunks):
|
| 213 |
+
chunk_len = len_arr[i]
|
| 214 |
+
acti_list.append(acti_arr[start: start+chunk_len])
|
| 215 |
+
start += chunk_len
|
| 216 |
+
acti = np.array(acti_list)
|
| 217 |
+
svec = svec_arr
|
| 218 |
+
|
| 219 |
+
# initialize clustering setting
|
| 220 |
+
cls_num = None
|
| 221 |
+
ahc_dis_th = args.ahc_dis_th
|
| 222 |
+
# Get cannot-link index list and silence index list
|
| 223 |
+
cl_lst, sil_lst = get_cl_sil(args, acti, cls_num)
|
| 224 |
+
|
| 225 |
+
n_samples = n_chunks * args.num_speakers - len(sil_lst)
|
| 226 |
+
min_n_samples = 2
|
| 227 |
+
if cls_num is not None:
|
| 228 |
+
min_n_samples = cls_num
|
| 229 |
+
|
| 230 |
+
if n_samples >= min_n_samples:
|
| 231 |
+
# clustering (if cls_num is None, update cls_num)
|
| 232 |
+
clslab, cls_num =\
|
| 233 |
+
clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst)
|
| 234 |
+
# merge
|
| 235 |
+
acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num)
|
| 236 |
+
# stitching
|
| 237 |
+
out_chunks = stitching(args, acti, clslab, cls_num)
|
| 238 |
+
else:
|
| 239 |
+
out_chunks = acti
|
| 240 |
+
|
| 241 |
+
outdata = np.vstack(out_chunks)
|
| 242 |
+
# Saving the resuts
|
| 243 |
+
return outdata
|
| 244 |
+
|
| 245 |
+
def make_rttm(args, conf, cluster_data):
|
| 246 |
+
args.frame_shift = conf['model']['frame_shift']
|
| 247 |
+
args.subsampling = conf['model']['subsampling']
|
| 248 |
+
args.sampling_rate = conf['dataset']['sampling_rate']
|
| 249 |
+
|
| 250 |
+
with open(args.out_rttm_file, 'w') as wf:
|
| 251 |
+
a = np.where(cluster_data > args.threshold, 1, 0)
|
| 252 |
+
if args.median > 1:
|
| 253 |
+
a = medfilt(a, (args.median, 1))
|
| 254 |
+
for spkid, frames in enumerate(a.T):
|
| 255 |
+
frames = np.pad(frames, (1, 1), 'constant')
|
| 256 |
+
changes, = np.where(np.diff(frames, axis=0) != 0)
|
| 257 |
+
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
|
| 258 |
+
for s, e in zip(changes[::2], changes[1::2]):
|
| 259 |
+
print(fmt.format(
|
| 260 |
+
args.session,
|
| 261 |
+
s * args.frame_shift * args.subsampling / args.sampling_rate,
|
| 262 |
+
(e - s) * args.frame_shift * args.subsampling / args.sampling_rate,
|
| 263 |
+
args.session + "_" + str(spkid)), file=wf)
|
| 264 |
+
|
| 265 |
+
def main(args):
|
| 266 |
+
conf = parse_config_or_kwargs(args.config_path)
|
| 267 |
+
num_speakers = conf['dataset']['num_speakers']
|
| 268 |
+
args.num_speakers = num_speakers
|
| 269 |
+
|
| 270 |
+
# Prepare model
|
| 271 |
+
model_parameter_dict = torch.load(args.model_init)['model']
|
| 272 |
+
model_all_n_speakers = model_parameter_dict["embed.weight"].shape[0]
|
| 273 |
+
conf['model']['all_n_speakers'] = model_all_n_speakers
|
| 274 |
+
net = TransformerDiarization(**conf['model'])
|
| 275 |
+
net.load_state_dict(model_parameter_dict, strict=False)
|
| 276 |
+
net.eval()
|
| 277 |
+
net = net.to("cuda")
|
| 278 |
+
|
| 279 |
+
audio, sr = sf.read(args.wav_path, dtype="float32")
|
| 280 |
+
audio_len = audio.shape[0]
|
| 281 |
+
chunk_size, frame_shift, subsampling = conf['dataset']['chunk_size'], conf['model']['frame_shift'], conf['model']['subsampling']
|
| 282 |
+
scale_ratio = int(frame_shift * subsampling)
|
| 283 |
+
chunk_audio_size = chunk_size * scale_ratio
|
| 284 |
+
wav_list, chunk_len_list = [], []
|
| 285 |
+
for i in range(0, math.ceil(1.0 * audio_len / chunk_audio_size)):
|
| 286 |
+
start, end = i*chunk_audio_size, (i+1)*chunk_audio_size
|
| 287 |
+
if end > audio_len:
|
| 288 |
+
chunk_len_list.append(int((audio_len-start) / scale_ratio))
|
| 289 |
+
end = audio_len
|
| 290 |
+
start = max(0, audio_len - chunk_audio_size)
|
| 291 |
+
else:
|
| 292 |
+
chunk_len_list.append(chunk_size)
|
| 293 |
+
wav_list.append(audio[start:end])
|
| 294 |
+
wav_list = [torch.from_numpy(wav).float() for wav in wav_list]
|
| 295 |
+
|
| 296 |
+
acti_arr, svec_arr, len_arr = prediction(num_speakers, net, wav_list, chunk_len_list)
|
| 297 |
+
cluster_data = cluster(args, conf, acti_arr, svec_arr, len_arr)
|
| 298 |
+
make_rttm(args, conf, cluster_data)
|
| 299 |
+
|
| 300 |
+
if __name__ == '__main__':
|
| 301 |
+
parser = yamlargparse.ArgumentParser(description='decoding')
|
| 302 |
+
parser.add_argument('--wav_path',
|
| 303 |
+
help='the input wav path',
|
| 304 |
+
default="tmp/mix_0000496.wav")
|
| 305 |
+
parser.add_argument('--config_path',
|
| 306 |
+
help='config file path',
|
| 307 |
+
default="config/infer_est_nspk1.yaml")
|
| 308 |
+
parser.add_argument('--model_init',
|
| 309 |
+
help='model initialize path',
|
| 310 |
+
default="")
|
| 311 |
+
parser.add_argument('--sil_spk_th', default=0.05, type=float)
|
| 312 |
+
parser.add_argument('--ahc_dis_th', default=1.0, type=float)
|
| 313 |
+
parser.add_argument('--clink_dis', default=1.0e+4, type=float)
|
| 314 |
+
parser.add_argument('--session', default='Anonymous', help='the name of the output speaker')
|
| 315 |
+
parser.add_argument('--out_rttm_file', default='out.rttm', help='the output rttm file')
|
| 316 |
+
parser.add_argument('--threshold', default=0.4, type=float)
|
| 317 |
+
parser.add_argument('--median', default=25, type=int)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
args = parser.parse_args()
|
| 321 |
+
main(args)
|
UniSpeech/downstreams/speaker_diarization/models/models.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
|
| 2 |
+
# All rights reserved
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torchaudio.transforms as trans
|
| 10 |
+
from collections import OrderedDict
|
| 11 |
+
from itertools import permutations
|
| 12 |
+
from models.transformer import TransformerEncoder
|
| 13 |
+
from .utils import UpstreamExpert
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GradMultiply(torch.autograd.Function):
|
| 17 |
+
@staticmethod
|
| 18 |
+
def forward(ctx, x, scale):
|
| 19 |
+
ctx.scale = scale
|
| 20 |
+
res = x.new(x)
|
| 21 |
+
return res
|
| 22 |
+
|
| 23 |
+
@staticmethod
|
| 24 |
+
def backward(ctx, grad):
|
| 25 |
+
return grad * ctx.scale, None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
P: number of permutation
|
| 31 |
+
T: number of frames
|
| 32 |
+
C: number of speakers (classes)
|
| 33 |
+
B: mini-batch size
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def batch_pit_loss_parallel(outputs, labels, ilens=None):
|
| 38 |
+
""" calculate the batch pit loss parallelly
|
| 39 |
+
Args:
|
| 40 |
+
outputs (torch.Tensor): B x T x C
|
| 41 |
+
labels (torch.Tensor): B x T x C
|
| 42 |
+
ilens (torch.Tensor): B
|
| 43 |
+
Returns:
|
| 44 |
+
perm (torch.Tensor): permutation for outputs (Batch, num_spk)
|
| 45 |
+
loss
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
if ilens is None:
|
| 49 |
+
mask, scale = 1.0, outputs.shape[1]
|
| 50 |
+
else:
|
| 51 |
+
scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device)
|
| 52 |
+
mask = outputs.new_zeros(outputs.size()[:-1])
|
| 53 |
+
for i, chunk_len in enumerate(ilens):
|
| 54 |
+
mask[i, :chunk_len] += 1.0
|
| 55 |
+
mask /= scale
|
| 56 |
+
|
| 57 |
+
def loss_func(output, label):
|
| 58 |
+
# return torch.mean(F.binary_cross_entropy_with_logits(output, label, reduction='none'), dim=tuple(range(1, output.dim())))
|
| 59 |
+
return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1)
|
| 60 |
+
|
| 61 |
+
def pair_loss(outputs, labels, permutation):
|
| 62 |
+
return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation)
|
| 63 |
+
|
| 64 |
+
device = outputs.device
|
| 65 |
+
num_spk = outputs.shape[-1]
|
| 66 |
+
all_permutations = list(permutations(range(num_spk)))
|
| 67 |
+
losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1)
|
| 68 |
+
loss, perm = torch.min(losses, dim=1)
|
| 69 |
+
perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm)
|
| 70 |
+
return torch.mean(loss), perm
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def fix_state_dict(state_dict):
|
| 74 |
+
new_state_dict = OrderedDict()
|
| 75 |
+
for k, v in state_dict.items():
|
| 76 |
+
if k.startswith('module.'):
|
| 77 |
+
# remove 'module.' of DataParallel
|
| 78 |
+
k = k[7:]
|
| 79 |
+
if k.startswith('net.'):
|
| 80 |
+
# remove 'net.' of PadertorchModel
|
| 81 |
+
k = k[4:]
|
| 82 |
+
new_state_dict[k] = v
|
| 83 |
+
return new_state_dict
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class TransformerDiarization(nn.Module):
|
| 87 |
+
def __init__(self,
|
| 88 |
+
n_speakers,
|
| 89 |
+
all_n_speakers,
|
| 90 |
+
feat_dim,
|
| 91 |
+
n_units,
|
| 92 |
+
n_heads,
|
| 93 |
+
n_layers,
|
| 94 |
+
dropout_rate,
|
| 95 |
+
spk_emb_dim,
|
| 96 |
+
sr=8000,
|
| 97 |
+
frame_shift=256,
|
| 98 |
+
frame_size=1024,
|
| 99 |
+
context_size=0,
|
| 100 |
+
subsampling=1,
|
| 101 |
+
feat_type='fbank',
|
| 102 |
+
feature_selection='default',
|
| 103 |
+
interpolate_mode='linear',
|
| 104 |
+
update_extract=False,
|
| 105 |
+
feature_grad_mult=1.0
|
| 106 |
+
):
|
| 107 |
+
super(TransformerDiarization, self).__init__()
|
| 108 |
+
self.context_size = context_size
|
| 109 |
+
self.subsampling = subsampling
|
| 110 |
+
self.feat_type = feat_type
|
| 111 |
+
self.feature_selection = feature_selection
|
| 112 |
+
self.sr = sr
|
| 113 |
+
self.frame_shift = frame_shift
|
| 114 |
+
self.interpolate_mode = interpolate_mode
|
| 115 |
+
self.update_extract = update_extract
|
| 116 |
+
self.feature_grad_mult = feature_grad_mult
|
| 117 |
+
|
| 118 |
+
if feat_type == 'fbank':
|
| 119 |
+
self.feature_extract = trans.MelSpectrogram(sample_rate=sr,
|
| 120 |
+
n_fft=frame_size,
|
| 121 |
+
win_length=frame_size,
|
| 122 |
+
hop_length=frame_shift,
|
| 123 |
+
f_min=0.0,
|
| 124 |
+
f_max=sr // 2,
|
| 125 |
+
pad=0,
|
| 126 |
+
n_mels=feat_dim)
|
| 127 |
+
else:
|
| 128 |
+
self.feature_extract = UpstreamExpert(feat_type)
|
| 129 |
+
# self.feature_extract = torch.hub.load('s3prl/s3prl', 'hubert_local', ckpt=feat_type)
|
| 130 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
|
| 131 |
+
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
|
| 132 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
|
| 133 |
+
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
|
| 134 |
+
self.feat_num = self.get_feat_num()
|
| 135 |
+
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
| 136 |
+
# for param in self.feature_extract.parameters():
|
| 137 |
+
# param.requires_grad = False
|
| 138 |
+
self.resample = trans.Resample(orig_freq=sr, new_freq=16000)
|
| 139 |
+
|
| 140 |
+
if feat_type != 'fbank' and feat_type != 'mfcc':
|
| 141 |
+
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer', 'spk_proj', 'layer_norm_for_extract']
|
| 142 |
+
for name, param in self.feature_extract.named_parameters():
|
| 143 |
+
for freeze_val in freeze_list:
|
| 144 |
+
if freeze_val in name:
|
| 145 |
+
param.requires_grad = False
|
| 146 |
+
break
|
| 147 |
+
if not self.update_extract:
|
| 148 |
+
for param in self.feature_extract.parameters():
|
| 149 |
+
param.requires_grad = False
|
| 150 |
+
|
| 151 |
+
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
| 152 |
+
|
| 153 |
+
feat_dim = feat_dim * (self.context_size*2 + 1)
|
| 154 |
+
self.enc = TransformerEncoder(
|
| 155 |
+
feat_dim, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate)
|
| 156 |
+
self.linear = nn.Linear(n_units, n_speakers)
|
| 157 |
+
|
| 158 |
+
for i in range(n_speakers):
|
| 159 |
+
setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, spk_emb_dim))
|
| 160 |
+
|
| 161 |
+
self.n_speakers = n_speakers
|
| 162 |
+
self.embed = nn.Embedding(all_n_speakers, spk_emb_dim)
|
| 163 |
+
self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
|
| 164 |
+
self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
|
| 165 |
+
|
| 166 |
+
def get_feat_num(self):
|
| 167 |
+
self.feature_extract.eval()
|
| 168 |
+
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
features = self.feature_extract(wav)
|
| 171 |
+
select_feature = features[self.feature_selection]
|
| 172 |
+
if isinstance(select_feature, (list, tuple)):
|
| 173 |
+
return len(select_feature)
|
| 174 |
+
else:
|
| 175 |
+
return 1
|
| 176 |
+
|
| 177 |
+
def fix_except_embedding(self, requires_grad=False):
|
| 178 |
+
for name, param in self.named_parameters():
|
| 179 |
+
if 'embed' not in name:
|
| 180 |
+
param.requires_grad = requires_grad
|
| 181 |
+
|
| 182 |
+
def modfy_emb(self, weight):
|
| 183 |
+
self.embed = nn.Embedding.from_pretrained(weight)
|
| 184 |
+
|
| 185 |
+
def splice(self, data, context_size):
|
| 186 |
+
# data: B x feat_dim x time_len
|
| 187 |
+
data = torch.unsqueeze(data, -1)
|
| 188 |
+
kernel_size = context_size*2 + 1
|
| 189 |
+
splice_data = F.unfold(data, kernel_size=(kernel_size, 1), padding=(context_size, 0))
|
| 190 |
+
return splice_data
|
| 191 |
+
|
| 192 |
+
def get_feat(self, xs):
|
| 193 |
+
wav_len = xs.shape[-1]
|
| 194 |
+
chunk_size = int(wav_len / self.frame_shift)
|
| 195 |
+
chunk_size = int(chunk_size / self.subsampling)
|
| 196 |
+
|
| 197 |
+
self.feature_extract.eval()
|
| 198 |
+
if self.update_extract:
|
| 199 |
+
xs = self.resample(xs)
|
| 200 |
+
feature = self.feature_extract([sample for sample in xs])
|
| 201 |
+
else:
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
if self.feat_type == 'fbank':
|
| 204 |
+
feature = self.feature_extract(xs) + 1e-6 # B x feat_dim x time_len
|
| 205 |
+
feature = feature.log()
|
| 206 |
+
else:
|
| 207 |
+
xs = self.resample(xs)
|
| 208 |
+
feature = self.feature_extract([sample for sample in xs])
|
| 209 |
+
|
| 210 |
+
if self.feat_type != "fbank" and self.feat_type != "mfcc":
|
| 211 |
+
feature = feature[self.feature_selection]
|
| 212 |
+
if isinstance(feature, (list, tuple)):
|
| 213 |
+
feature = torch.stack(feature, dim=0)
|
| 214 |
+
else:
|
| 215 |
+
feature = feature.unsqueeze(0)
|
| 216 |
+
|
| 217 |
+
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
| 218 |
+
feature = (norm_weights * feature).sum(dim=0)
|
| 219 |
+
feature = torch.transpose(feature, 1, 2) + 1e-6
|
| 220 |
+
|
| 221 |
+
feature = self.instance_norm(feature)
|
| 222 |
+
feature = self.splice(feature, self.context_size)
|
| 223 |
+
feature = feature[:, :, ::self.subsampling]
|
| 224 |
+
feature = F.interpolate(feature, chunk_size, mode=self.interpolate_mode)
|
| 225 |
+
feature = torch.transpose(feature, 1, 2)
|
| 226 |
+
|
| 227 |
+
if self.feature_grad_mult != 1.0:
|
| 228 |
+
feature = GradMultiply.apply(feature, self.feature_grad_mult)
|
| 229 |
+
|
| 230 |
+
return feature
|
| 231 |
+
|
| 232 |
+
def forward(self, inputs):
|
| 233 |
+
if isinstance(inputs, list):
|
| 234 |
+
xs = inputs[0]
|
| 235 |
+
else:
|
| 236 |
+
xs = inputs
|
| 237 |
+
feature = self.get_feat(xs)
|
| 238 |
+
|
| 239 |
+
pad_shape = feature.shape
|
| 240 |
+
emb = self.enc(feature)
|
| 241 |
+
ys = self.linear(emb)
|
| 242 |
+
ys = ys.reshape(pad_shape[0], pad_shape[1], -1)
|
| 243 |
+
|
| 244 |
+
spksvecs = []
|
| 245 |
+
for i in range(self.n_speakers):
|
| 246 |
+
spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb)
|
| 247 |
+
spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1)
|
| 248 |
+
spksvecs.append(spkivecs)
|
| 249 |
+
|
| 250 |
+
return ys, spksvecs
|
| 251 |
+
|
| 252 |
+
def get_loss(self, inputs, ys, spksvecs, cal_spk_loss=True):
|
| 253 |
+
ts = inputs[1]
|
| 254 |
+
ss = inputs[2]
|
| 255 |
+
ns = inputs[3]
|
| 256 |
+
ilens = inputs[4]
|
| 257 |
+
ilens = [ilen.item() for ilen in ilens]
|
| 258 |
+
|
| 259 |
+
pit_loss, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
|
| 260 |
+
if cal_spk_loss:
|
| 261 |
+
spk_loss = self.spk_loss_parallel(spksvecs, ys, ts, ss, sigmas, ns, ilens)
|
| 262 |
+
else:
|
| 263 |
+
spk_loss = torch.tensor(0.0).to(pit_loss.device)
|
| 264 |
+
|
| 265 |
+
alpha = torch.clamp(self.alpha, min=sys.float_info.epsilon)
|
| 266 |
+
|
| 267 |
+
return {'spk_loss':spk_loss,
|
| 268 |
+
'pit_loss': pit_loss}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def batch_estimate(self, xs):
|
| 272 |
+
out = self(xs)
|
| 273 |
+
ys = out[0]
|
| 274 |
+
spksvecs = out[1]
|
| 275 |
+
spksvecs = list(zip(*spksvecs))
|
| 276 |
+
outputs = [
|
| 277 |
+
self.estimate(spksvec, y)
|
| 278 |
+
for (spksvec, y) in zip(spksvecs, ys)]
|
| 279 |
+
outputs = list(zip(*outputs))
|
| 280 |
+
|
| 281 |
+
return outputs
|
| 282 |
+
|
| 283 |
+
def batch_estimate_with_perm(self, xs, ts, ilens=None):
|
| 284 |
+
out = self(xs)
|
| 285 |
+
ys = out[0]
|
| 286 |
+
if ts[0].shape[1] > ys[0].shape[1]:
|
| 287 |
+
# e.g. the case of training 3-spk model with 4-spk data
|
| 288 |
+
add_dim = ts[0].shape[1] - ys[0].shape[1]
|
| 289 |
+
y_device = ys[0].device
|
| 290 |
+
zeros = [torch.zeros(ts[0].shape).to(y_device)
|
| 291 |
+
for i in range(len(ts))]
|
| 292 |
+
_ys = []
|
| 293 |
+
for zero, y in zip(zeros, ys):
|
| 294 |
+
_zero = zero
|
| 295 |
+
_zero[:, :-add_dim] = y
|
| 296 |
+
_ys.append(_zero)
|
| 297 |
+
_, sigmas = batch_pit_loss_parallel(_ys, ts, ilens)
|
| 298 |
+
else:
|
| 299 |
+
_, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
|
| 300 |
+
spksvecs = out[1]
|
| 301 |
+
spksvecs = list(zip(*spksvecs))
|
| 302 |
+
outputs = [self.estimate(spksvec, y)
|
| 303 |
+
for (spksvec, y) in zip(spksvecs, ys)]
|
| 304 |
+
outputs = list(zip(*outputs))
|
| 305 |
+
zs = outputs[0]
|
| 306 |
+
|
| 307 |
+
if ts[0].shape[1] > ys[0].shape[1]:
|
| 308 |
+
# e.g. the case of training 3-spk model with 4-spk data
|
| 309 |
+
add_dim = ts[0].shape[1] - ys[0].shape[1]
|
| 310 |
+
z_device = zs[0].device
|
| 311 |
+
zeros = [torch.zeros(ts[0].shape).to(z_device)
|
| 312 |
+
for i in range(len(ts))]
|
| 313 |
+
_zs = []
|
| 314 |
+
for zero, z in zip(zeros, zs):
|
| 315 |
+
_zero = zero
|
| 316 |
+
_zero[:, :-add_dim] = z
|
| 317 |
+
_zs.append(_zero)
|
| 318 |
+
zs = _zs
|
| 319 |
+
outputs[0] = zs
|
| 320 |
+
outputs.append(sigmas)
|
| 321 |
+
|
| 322 |
+
# outputs: [zs, nmz_wavg_spk0vecs, nmz_wavg_spk1vecs, ..., sigmas]
|
| 323 |
+
return outputs
|
| 324 |
+
|
| 325 |
+
def estimate(self, spksvec, y):
|
| 326 |
+
outputs = []
|
| 327 |
+
z = torch.sigmoid(y.transpose(1, 0))
|
| 328 |
+
|
| 329 |
+
outputs.append(z.transpose(1, 0))
|
| 330 |
+
for spkid, spkvec in enumerate(spksvec):
|
| 331 |
+
norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1)
|
| 332 |
+
# Normalize speaker vectors before weighted average
|
| 333 |
+
spkvec = torch.mul(
|
| 334 |
+
spkvec.transpose(1, 0), norm_spkvec_inv
|
| 335 |
+
).transpose(1, 0)
|
| 336 |
+
wavg_spkvec = torch.mul(
|
| 337 |
+
spkvec.transpose(1, 0), z[spkid]
|
| 338 |
+
).transpose(1, 0)
|
| 339 |
+
sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0)
|
| 340 |
+
nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec)
|
| 341 |
+
outputs.append(nmz_wavg_spkvec)
|
| 342 |
+
|
| 343 |
+
# outputs: [z, nmz_wavg_spk0vec, nmz_wavg_spk1vec, ...]
|
| 344 |
+
return outputs
|
| 345 |
+
|
| 346 |
+
def spk_loss_parallel(self, spksvecs, ys, ts, ss, sigmas, ns, ilens):
|
| 347 |
+
'''
|
| 348 |
+
spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...]
|
| 349 |
+
ys (torch.Tensor): B x T x 3
|
| 350 |
+
ts (torch.Tensor): B x T x 3
|
| 351 |
+
ss (torch.Tensor): B x 3
|
| 352 |
+
sigmas (torch.Tensor): B x 3
|
| 353 |
+
ns (torch.Tensor): B x total_spk_num x 1
|
| 354 |
+
ilens (List): B
|
| 355 |
+
'''
|
| 356 |
+
chunk_spk_num = len(spksvecs) # 3
|
| 357 |
+
|
| 358 |
+
len_mask = ys.new_zeros((ys.size()[:-1])) # B x T
|
| 359 |
+
for i, len_val in enumerate(ilens):
|
| 360 |
+
len_mask[i,:len_val] += 1.0
|
| 361 |
+
ts = ts * len_mask.unsqueeze(-1)
|
| 362 |
+
len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T
|
| 363 |
+
|
| 364 |
+
spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim
|
| 365 |
+
# Normalize speaker vectors before weighted average
|
| 366 |
+
spk_vecs = F.normalize(spk_vecs, dim=-1)
|
| 367 |
+
|
| 368 |
+
ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T
|
| 369 |
+
ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1
|
| 370 |
+
|
| 371 |
+
weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim
|
| 372 |
+
weight_spk_vec *= len_mask.unsqueeze(-1)
|
| 373 |
+
sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim
|
| 374 |
+
norm_spk_vec = F.normalize(sum_spk_vec, dim=1)
|
| 375 |
+
|
| 376 |
+
embeds = F.normalize(self.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim
|
| 377 |
+
dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num
|
| 378 |
+
logits = -1.0 * torch.add(torch.clamp(self.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.beta)
|
| 379 |
+
label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3
|
| 380 |
+
label[label==-1] = 0
|
| 381 |
+
valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B
|
| 382 |
+
valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3
|
| 383 |
+
|
| 384 |
+
valid_spk_loss_num = torch.sum(valid_spk_mask).item()
|
| 385 |
+
if valid_spk_loss_num > 0:
|
| 386 |
+
loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num
|
| 387 |
+
# uncomment the line below, the loss result is same as batch_spk_loss
|
| 388 |
+
# loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0]
|
| 389 |
+
return torch.sum(loss)
|
| 390 |
+
else:
|
| 391 |
+
return torch.tensor(0.0).to(ys.device)
|
UniSpeech/downstreams/speaker_diarization/models/transformer.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
|
| 2 |
+
# All rights reserved
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class NoamScheduler(_LRScheduler):
|
| 12 |
+
""" learning rate scheduler used in the transformer
|
| 13 |
+
See https://arxiv.org/pdf/1706.03762.pdf
|
| 14 |
+
lrate = d_model**(-0.5) * \
|
| 15 |
+
min(step_num**(-0.5), step_num*warmup_steps**(-1.5))
|
| 16 |
+
Scaling factor is implemented as in
|
| 17 |
+
http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self, optimizer, d_model, warmup_steps, tot_step, scale,
|
| 22 |
+
last_epoch=-1
|
| 23 |
+
):
|
| 24 |
+
self.d_model = d_model
|
| 25 |
+
self.warmup_steps = warmup_steps
|
| 26 |
+
self.tot_step = tot_step
|
| 27 |
+
self.scale = scale
|
| 28 |
+
super(NoamScheduler, self).__init__(optimizer, last_epoch)
|
| 29 |
+
|
| 30 |
+
def get_lr(self):
|
| 31 |
+
self.last_epoch = max(1, self.last_epoch)
|
| 32 |
+
step_num = self.last_epoch
|
| 33 |
+
val = self.scale * self.d_model ** (-0.5) * \
|
| 34 |
+
min(step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5))
|
| 35 |
+
|
| 36 |
+
return [base_lr / base_lr * val for base_lr in self.base_lrs]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MultiHeadSelfAttention(nn.Module):
|
| 40 |
+
""" Multi head "self" attention layer
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, n_units, h=8, dropout_rate=0.1):
|
| 44 |
+
super(MultiHeadSelfAttention, self).__init__()
|
| 45 |
+
self.linearQ = nn.Linear(n_units, n_units)
|
| 46 |
+
self.linearK = nn.Linear(n_units, n_units)
|
| 47 |
+
self.linearV = nn.Linear(n_units, n_units)
|
| 48 |
+
self.linearO = nn.Linear(n_units, n_units)
|
| 49 |
+
self.d_k = n_units // h
|
| 50 |
+
self.h = h
|
| 51 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 52 |
+
# attention for plot
|
| 53 |
+
self.att = None
|
| 54 |
+
|
| 55 |
+
def forward(self, x, batch_size):
|
| 56 |
+
# x: (BT, F)
|
| 57 |
+
q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k)
|
| 58 |
+
k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k)
|
| 59 |
+
v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k)
|
| 60 |
+
|
| 61 |
+
scores = torch.matmul(
|
| 62 |
+
q.transpose(1, 2), k.permute(0, 2, 3, 1)) / np.sqrt(self.d_k)
|
| 63 |
+
# scores: (B, h, T, T) = (B, h, T, d_k) x (B, h, d_k, T)
|
| 64 |
+
self.att = F.softmax(scores, dim=3)
|
| 65 |
+
p_att = self.dropout(self.att)
|
| 66 |
+
x = torch.matmul(p_att, v.transpose(1, 2))
|
| 67 |
+
x = x.transpose(1, 2).reshape(-1, self.h * self.d_k)
|
| 68 |
+
|
| 69 |
+
return self.linearO(x)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class PositionwiseFeedForward(nn.Module):
|
| 73 |
+
""" Positionwise feed-forward layer
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def __init__(self, n_units, d_units, dropout_rate):
|
| 77 |
+
super(PositionwiseFeedForward, self).__init__()
|
| 78 |
+
self.linear1 = nn.Linear(n_units, d_units)
|
| 79 |
+
self.linear2 = nn.Linear(d_units, n_units)
|
| 80 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
return self.linear2(self.dropout(F.relu(self.linear1(x))))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class PositionalEncoding(nn.Module):
|
| 87 |
+
""" Positional encoding function
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, n_units, dropout_rate, max_len):
|
| 91 |
+
super(PositionalEncoding, self).__init__()
|
| 92 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 93 |
+
positions = np.arange(0, max_len, dtype='f')[:, None]
|
| 94 |
+
dens = np.exp(
|
| 95 |
+
np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) / n_units))
|
| 96 |
+
self.enc = np.zeros((max_len, n_units), dtype='f')
|
| 97 |
+
self.enc[:, ::2] = np.sin(positions * dens)
|
| 98 |
+
self.enc[:, 1::2] = np.cos(positions * dens)
|
| 99 |
+
self.scale = np.sqrt(n_units)
|
| 100 |
+
|
| 101 |
+
def forward(self, x):
|
| 102 |
+
x = x * self.scale + self.xp.array(self.enc[:, :x.shape[1]])
|
| 103 |
+
return self.dropout(x)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class TransformerEncoder(nn.Module):
|
| 107 |
+
def __init__(self, idim, n_layers, n_units,
|
| 108 |
+
e_units=2048, h=8, dropout_rate=0.1):
|
| 109 |
+
super(TransformerEncoder, self).__init__()
|
| 110 |
+
self.linear_in = nn.Linear(idim, n_units)
|
| 111 |
+
# self.lnorm_in = nn.LayerNorm(n_units)
|
| 112 |
+
self.pos_enc = PositionalEncoding(n_units, dropout_rate, 5000)
|
| 113 |
+
self.n_layers = n_layers
|
| 114 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 115 |
+
for i in range(n_layers):
|
| 116 |
+
setattr(self, '{}{:d}'.format("lnorm1_", i),
|
| 117 |
+
nn.LayerNorm(n_units))
|
| 118 |
+
setattr(self, '{}{:d}'.format("self_att_", i),
|
| 119 |
+
MultiHeadSelfAttention(n_units, h, dropout_rate))
|
| 120 |
+
setattr(self, '{}{:d}'.format("lnorm2_", i),
|
| 121 |
+
nn.LayerNorm(n_units))
|
| 122 |
+
setattr(self, '{}{:d}'.format("ff_", i),
|
| 123 |
+
PositionwiseFeedForward(n_units, e_units, dropout_rate))
|
| 124 |
+
self.lnorm_out = nn.LayerNorm(n_units)
|
| 125 |
+
|
| 126 |
+
def forward(self, x):
|
| 127 |
+
# x: (B, T, F) ... batch, time, (mel)freq
|
| 128 |
+
BT_size = x.shape[0] * x.shape[1]
|
| 129 |
+
# e: (BT, F)
|
| 130 |
+
e = self.linear_in(x.reshape(BT_size, -1))
|
| 131 |
+
# Encoder stack
|
| 132 |
+
for i in range(self.n_layers):
|
| 133 |
+
# layer normalization
|
| 134 |
+
e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
|
| 135 |
+
# self-attention
|
| 136 |
+
s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0])
|
| 137 |
+
# residual
|
| 138 |
+
e = e + self.dropout(s)
|
| 139 |
+
# layer normalization
|
| 140 |
+
e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
|
| 141 |
+
# positionwise feed-forward
|
| 142 |
+
s = getattr(self, '{}{:d}'.format("ff_", i))(e)
|
| 143 |
+
# residual
|
| 144 |
+
e = e + self.dropout(s)
|
| 145 |
+
# final layer normalization
|
| 146 |
+
# output: (BT, F)
|
| 147 |
+
return self.lnorm_out(e)
|
UniSpeech/downstreams/speaker_diarization/models/utils.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import fairseq
|
| 3 |
+
from packaging import version
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from fairseq import tasks
|
| 6 |
+
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
|
| 7 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
from s3prl.upstream.interfaces import UpstreamBase
|
| 10 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 11 |
+
|
| 12 |
+
def load_model(filepath):
|
| 13 |
+
state = torch.load(filepath, map_location=lambda storage, loc: storage)
|
| 14 |
+
# state = load_checkpoint_to_cpu(filepath)
|
| 15 |
+
state["cfg"] = OmegaConf.create(state["cfg"])
|
| 16 |
+
|
| 17 |
+
if "args" in state and state["args"] is not None:
|
| 18 |
+
cfg = convert_namespace_to_omegaconf(state["args"])
|
| 19 |
+
elif "cfg" in state and state["cfg"] is not None:
|
| 20 |
+
cfg = state["cfg"]
|
| 21 |
+
else:
|
| 22 |
+
raise RuntimeError(
|
| 23 |
+
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
task = tasks.setup_task(cfg.task)
|
| 27 |
+
if "task_state" in state:
|
| 28 |
+
task.load_state_dict(state["task_state"])
|
| 29 |
+
|
| 30 |
+
model = task.build_model(cfg.model)
|
| 31 |
+
|
| 32 |
+
return model, cfg, task
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
###################
|
| 36 |
+
# UPSTREAM EXPERT #
|
| 37 |
+
###################
|
| 38 |
+
class UpstreamExpert(UpstreamBase):
|
| 39 |
+
def __init__(self, ckpt, **kwargs):
|
| 40 |
+
super().__init__(**kwargs)
|
| 41 |
+
assert version.parse(fairseq.__version__) > version.parse(
|
| 42 |
+
"0.10.2"
|
| 43 |
+
), "Please install the fairseq master branch."
|
| 44 |
+
|
| 45 |
+
model, cfg, task = load_model(ckpt)
|
| 46 |
+
self.model = model
|
| 47 |
+
self.task = task
|
| 48 |
+
|
| 49 |
+
if len(self.hooks) == 0:
|
| 50 |
+
module_name = "self.model.encoder.layers"
|
| 51 |
+
for module_id in range(len(eval(module_name))):
|
| 52 |
+
self.add_hook(
|
| 53 |
+
f"{module_name}[{module_id}]",
|
| 54 |
+
lambda input, output: input[0].transpose(0, 1),
|
| 55 |
+
)
|
| 56 |
+
self.add_hook("self.model.encoder", lambda input, output: output[0])
|
| 57 |
+
|
| 58 |
+
def forward(self, wavs):
|
| 59 |
+
if self.task.cfg.normalize:
|
| 60 |
+
wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
|
| 61 |
+
|
| 62 |
+
device = wavs[0].device
|
| 63 |
+
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
|
| 64 |
+
wav_padding_mask = ~torch.lt(
|
| 65 |
+
torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
|
| 66 |
+
wav_lengths.unsqueeze(1),
|
| 67 |
+
)
|
| 68 |
+
padded_wav = pad_sequence(wavs, batch_first=True)
|
| 69 |
+
|
| 70 |
+
features, feat_padding_mask = self.model.extract_features(
|
| 71 |
+
padded_wav,
|
| 72 |
+
padding_mask=wav_padding_mask,
|
| 73 |
+
mask=None,
|
| 74 |
+
)
|
| 75 |
+
return {
|
| 76 |
+
"default": features,
|
| 77 |
+
}
|
| 78 |
+
|
UniSpeech/downstreams/speaker_diarization/requirements.txt
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SoundFile==0.10.3.post1 \
|
| 2 |
+
--hash=sha256:2d17e0a6fc2af0d6c1d868bafa5ec80aae6e186a97fec8db07ad6af29842fbc7 \
|
| 3 |
+
--hash=sha256:4555438c2c4f02b39fea2ed40f6ddeda88a80cd1ee9dd129be4d5f5134698cc2 \
|
| 4 |
+
--hash=sha256:490cff42650733d1832728b937fe99fa1802896f5ef4d61bcf78cf7ebecb107b \
|
| 5 |
+
--hash=sha256:5e342ee293b896d31da67617fe65d0bdca217af193991b0cb6052353b1e0e506 \
|
| 6 |
+
--hash=sha256:b361d4ac1519a2e516cabafa6bf7e93492f999f35d7d25350cd87fdc3e5cb27e
|
| 7 |
+
fire==0.4.0 \
|
| 8 |
+
--hash=sha256:c5e2b8763699d1142393a46d0e3e790c5eb2f0706082df8f647878842c216a62
|
| 9 |
+
sentencepiece==0.1.96 \
|
| 10 |
+
--hash=sha256:1dac8c2ad02b5ebc1179c0a14cbc7d7c6f4fd73d4dd51820626402d0aefc974e \
|
| 11 |
+
--hash=sha256:26d20d713b3ba1b7a19205336afb1e93a4327c372b2f795e907b8dc2315ac92e \
|
| 12 |
+
--hash=sha256:335bf84d72112cc91f3c3b691d61802fc963503b7772fd8280d20368048b8f3e \
|
| 13 |
+
--hash=sha256:36e9ff61e7b67c5b7ee96733613622620b4802fc8cf188a4dbc1f355b03dde02 \
|
| 14 |
+
--hash=sha256:384148cead5cdab34a4d74fe1fb6a5a8abaafed25eaa4a7698b49dd9482e4c4e \
|
| 15 |
+
--hash=sha256:3c703e68ea192e45b65c5d5836f6980849d828a18da4189899d7150fad82dc9e \
|
| 16 |
+
--hash=sha256:3e61e0757e49c306fff78ea75d6b75773418fe22214b4a460959203be934e834 \
|
| 17 |
+
--hash=sha256:466e381f0a812da8fda97a9707498cef3210ea8385a3421bcbadcb5384063969 \
|
| 18 |
+
--hash=sha256:48c6d13b3bfff08060c138248e85df60f6fad11135ad7a8fc2ef6005aacca839 \
|
| 19 |
+
--hash=sha256:4997c7ccf2ae462320250314aa5709a88d8a09fa271d073458a07bebf33f8e7c \
|
| 20 |
+
--hash=sha256:5388882bb24d083f6cc8cffc5c435f3694a7772b018e06ea6fd84d1044009efb \
|
| 21 |
+
--hash=sha256:5513298d62fe63dd0862d08a6eb52a9aa3537006f597f2386184e3f95bb88889 \
|
| 22 |
+
--hash=sha256:78e18d9106c36dcca929e18fd2c412378deac661d47fa3ee25defc55eef8a215 \
|
| 23 |
+
--hash=sha256:8179785883b556cd517416cdbda6244745414b00ec83132cfe1d26000971f3ae \
|
| 24 |
+
--hash=sha256:81bb77ba3651114943b2f8f77829cf764137dff06e38f4bf7fa43efea12c7f84 \
|
| 25 |
+
--hash=sha256:89c038da7f827a6e2ca4c73aeb4e4b25b99d981ce47dd61b04d446c8200cba1e \
|
| 26 |
+
--hash=sha256:940a6999c7d3f55e9d7b194fd5e1f41a7dbed26d3519fb95333216292a39599e \
|
| 27 |
+
--hash=sha256:99ea2d9db19e63a2d17d5dc64f9ace83fb9308a735be05a1aaf98eb4b496fba7 \
|
| 28 |
+
--hash=sha256:9bdf097d5bd1d8ce42dfee51f6ff05f5578b96e48c6f6006aa4eff69edfa3639 \
|
| 29 |
+
--hash=sha256:a336575463d75d3aac1f7e32470b8998643ccd9a73786bd726f6b0470520b6b4 \
|
| 30 |
+
--hash=sha256:a697257a2cd7581732d7741a8d32a06927f0311c3d277dbc47fa1043350c9d17 \
|
| 31 |
+
--hash=sha256:a92e1932ee8fd500680ccbe1bf53eb33228f4c9d6524ed6f300bcc80ac359f27 \
|
| 32 |
+
--hash=sha256:aeb090ad462833df03af1debce4ae607a2766ef861f992003ad0c56d074ab805 \
|
| 33 |
+
--hash=sha256:b1c24c1d9405b2148184ff27c062493d5e3be5c144575f95b5a0d7c660a515af \
|
| 34 |
+
--hash=sha256:b77d27f59d515c43b61745b8173fbe7c7b3014b14b3702a75bf1793471e7def6 \
|
| 35 |
+
--hash=sha256:b8b1dd2712f8a7de5b4c8ec912e6c041d25750bf03e1ce325cdba43bae0944ae \
|
| 36 |
+
--hash=sha256:bedf0355117fb4e9b1fc9fc92b4d5ee743a7d468be9f6196e3b94447710ea589 \
|
| 37 |
+
--hash=sha256:cc969e6694fb27fba7cee2953f350804faf03913f25ae1ee713a7b8a1bc08018 \
|
| 38 |
+
--hash=sha256:d45e3f78e746aa161bc9f5a31c6a2839c512101113a4065f4d2e7a3ab8198d8c \
|
| 39 |
+
--hash=sha256:d501713a8396193883aa526f48dc609f5f031a5df1afbafa561cf9ab492ffc76 \
|
| 40 |
+
--hash=sha256:d954d25a8705f972e8bfc1dea5464d7e697dd6f4ade092f1a487387e6d6c829a \
|
| 41 |
+
--hash=sha256:dadccb2e49244b6e64b4527d13ec14d5e094a90b41cf9b963e457e64182f1941 \
|
| 42 |
+
--hash=sha256:e811984b0908c14c56de7d8226fdd494d87a7ccb75af8ac3a07423037aaafc35 \
|
| 43 |
+
--hash=sha256:e88354b61f59dfdeb41023f7be8ae31dc627c2dc2dacbc2de8b2d82a0997135c \
|
| 44 |
+
--hash=sha256:e8ec5bb6777e2060e1499750c50e1b69dca5a0f80f90f2c66656c5f3e5244593 \
|
| 45 |
+
--hash=sha256:e9e9fe8094ca57549d801e9a2017ac5c24108bbf485ea4f8994a72e8e96ee135 \
|
| 46 |
+
--hash=sha256:eba0471ab0bb2e07ed06d91ecf5185d402c83d194155a41d8e2aa547d187712e \
|
| 47 |
+
--hash=sha256:ef59ba19340dc1d002ce5713b911c0ef23c577b08f8ed57998ee3c8e62c5bf6e \
|
| 48 |
+
--hash=sha256:f8c90df663cd9759b2cf8dd29998b63140ac39e51ada2e739dc13bdac0b4f001 \
|
| 49 |
+
--hash=sha256:f8cb24d8d0b2f8b7463815a59183eb81ec1d7a06e3217bed456063f3303eddfb \
|
| 50 |
+
--hash=sha256:fd907a8f744e5337de7fc532dd800c4416b571ea47f8c3c66be10cd1bc67c925 \
|
| 51 |
+
--hash=sha256:ff7d752a7f82d87711ec1a95c2262cb74f98be5b457f0300d81a1aefe5be2a95
|
| 52 |
+
tqdm==4.62.0 \
|
| 53 |
+
--hash=sha256:3642d483b558eec80d3c831e23953582c34d7e4540db86d9e5ed9dad238dabc6 \
|
| 54 |
+
--hash=sha256:706dea48ee05ba16e936ee91cb3791cd2ea6da348a0e50b46863ff4363ff4340
|
| 55 |
+
PyYAML==5.4.1 \
|
| 56 |
+
--hash=sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf \
|
| 57 |
+
--hash=sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696 \
|
| 58 |
+
--hash=sha256:129def1b7c1bf22faffd67b8f3724645203b79d8f4cc81f674654d9902cb4393 \
|
| 59 |
+
--hash=sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77 \
|
| 60 |
+
--hash=sha256:3b2b1824fe7112845700f815ff6a489360226a5609b96ec2190a45e62a9fc922 \
|
| 61 |
+
--hash=sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5 \
|
| 62 |
+
--hash=sha256:4465124ef1b18d9ace298060f4eccc64b0850899ac4ac53294547536533800c8 \
|
| 63 |
+
--hash=sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10 \
|
| 64 |
+
--hash=sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc \
|
| 65 |
+
--hash=sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018 \
|
| 66 |
+
--hash=sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e \
|
| 67 |
+
--hash=sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253 \
|
| 68 |
+
--hash=sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347 \
|
| 69 |
+
--hash=sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183 \
|
| 70 |
+
--hash=sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541 \
|
| 71 |
+
--hash=sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb \
|
| 72 |
+
--hash=sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185 \
|
| 73 |
+
--hash=sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc \
|
| 74 |
+
--hash=sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db \
|
| 75 |
+
--hash=sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa \
|
| 76 |
+
--hash=sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46 \
|
| 77 |
+
--hash=sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122 \
|
| 78 |
+
--hash=sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b \
|
| 79 |
+
--hash=sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63 \
|
| 80 |
+
--hash=sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df \
|
| 81 |
+
--hash=sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc \
|
| 82 |
+
--hash=sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247 \
|
| 83 |
+
--hash=sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6 \
|
| 84 |
+
--hash=sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0
|
| 85 |
+
h5py==3.3.0 \
|
| 86 |
+
--hash=sha256:09e78cefdef0b7566ab66366c5c7d9984c7b23142245bd51b82b744ad1eebf65 \
|
| 87 |
+
--hash=sha256:13355234c004ff8bd819f7d3420188aa1936b17d7f8470d622974a373421b7a5 \
|
| 88 |
+
--hash=sha256:5e2f22e66a3fb1815405cfe5711670450c973b8552507c535a546a23a468af3d \
|
| 89 |
+
--hash=sha256:7ca7d23ebbdd59a4be9b4820de52fe67adc74e6a44d5084881305461765aac47 \
|
| 90 |
+
--hash=sha256:89d7e10409b62fed81c571e35798763cb8375442b98f8ebfc52ba41ac019e081 \
|
| 91 |
+
--hash=sha256:8e09b682e4059c8cd259ddcc34bee35d639b9170105efeeae6ad195e7c1cea7a \
|
| 92 |
+
--hash=sha256:baef1a2cdef287a83e7f95ce9e0f4d762a9852fe7117b471063442c78b973695 \
|
| 93 |
+
--hash=sha256:e0dac887d779929778b3cfd13309a939359cc9e74756fc09af7c527a82797186 \
|
| 94 |
+
--hash=sha256:e0ea3330bf136f8213e43db67448994046ce501585dddc7ea4e8ceef0ef1600c \
|
| 95 |
+
--hash=sha256:f3bba8ffddd1fd2bf06127c5ff7b73f022cc1c8b7164355ddc760dc3f8570136
|
| 96 |
+
yamlargparse==1.31.1 \
|
| 97 |
+
--hash=sha256:2c09fc8e20c147d074f765512b880757a6fea669d57a3dc672a5e1be6c68c667
|
| 98 |
+
sklearn==0.0 \
|
| 99 |
+
--hash=sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31
|
| 100 |
+
matplotlib==3.4.2 \
|
| 101 |
+
--hash=sha256:0bea5ec5c28d49020e5d7923c2725b837e60bc8be99d3164af410eb4b4c827da \
|
| 102 |
+
--hash=sha256:1c1779f7ab7d8bdb7d4c605e6ffaa0614b3e80f1e3c8ccf7b9269a22dbc5986b \
|
| 103 |
+
--hash=sha256:21b31057bbc5e75b08e70a43cefc4c0b2c2f1b1a850f4a0f7af044eb4163086c \
|
| 104 |
+
--hash=sha256:32fa638cc10886885d1ca3d409d4473d6a22f7ceecd11322150961a70fab66dd \
|
| 105 |
+
--hash=sha256:3a5c18dbd2c7c366da26a4ad1462fe3e03a577b39e3b503bbcf482b9cdac093c \
|
| 106 |
+
--hash=sha256:5826f56055b9b1c80fef82e326097e34dc4af8c7249226b7dd63095a686177d1 \
|
| 107 |
+
--hash=sha256:6382bc6e2d7e481bcd977eb131c31dee96e0fb4f9177d15ec6fb976d3b9ace1a \
|
| 108 |
+
--hash=sha256:6475d0209024a77f869163ec3657c47fed35d9b6ed8bccba8aa0f0099fbbdaa8 \
|
| 109 |
+
--hash=sha256:6a6a44f27aabe720ec4fd485061e8a35784c2b9ffa6363ad546316dfc9cea04e \
|
| 110 |
+
--hash=sha256:7a58f3d8fe8fac3be522c79d921c9b86e090a59637cb88e3bc51298d7a2c862a \
|
| 111 |
+
--hash=sha256:7ad19f3fb6145b9eb41c08e7cbb9f8e10b91291396bee21e9ce761bb78df63ec \
|
| 112 |
+
--hash=sha256:85f191bb03cb1a7b04b5c2cca4792bef94df06ef473bc49e2818105671766fee \
|
| 113 |
+
--hash=sha256:956c8849b134b4a343598305a3ca1bdd3094f01f5efc8afccdebeffe6b315247 \
|
| 114 |
+
--hash=sha256:a9d8cb5329df13e0cdaa14b3b43f47b5e593ec637f13f14db75bb16e46178b05 \
|
| 115 |
+
--hash=sha256:b1d5a2cedf5de05567c441b3a8c2651fbde56df08b82640e7f06c8cd91e201f6 \
|
| 116 |
+
--hash=sha256:b26535b9de85326e6958cdef720ecd10bcf74a3f4371bf9a7e5b2e659c17e153 \
|
| 117 |
+
--hash=sha256:c541ee5a3287efe066bbe358320853cf4916bc14c00c38f8f3d8d75275a405a9 \
|
| 118 |
+
--hash=sha256:d8d994cefdff9aaba45166eb3de4f5211adb4accac85cbf97137e98f26ea0219 \
|
| 119 |
+
--hash=sha256:df815378a754a7edd4559f8c51fc7064f779a74013644a7f5ac7a0c31f875866
|
| 120 |
+
torchaudio==0.9.0 \
|
| 121 |
+
--hash=sha256:0a387e78eeaf6e0abd36df70e9d8a15d242b49c2507dbd9522568f5d4af5fb96 \
|
| 122 |
+
--hash=sha256:18763c05cb7d85a08b8ea960e40f6984e9513b02e76f4526d920493c701b0671 \
|
| 123 |
+
--hash=sha256:48e33bb96b7ff2dc10a778a695429dbd6dfc8c8baa0d7c9b63569cb002bb87cd \
|
| 124 |
+
--hash=sha256:62fd9393ddbe40aadaabef7595f5bff0057e39f7e519195a010731542815f5a4 \
|
| 125 |
+
--hash=sha256:76a5b8ea0e4ddafd5b8f24abdf1a6f7afe847d892570da13cf0fc9bceeac437f \
|
| 126 |
+
--hash=sha256:87520525da10b5f00d3e5e1180db6ee37b1fa305edb2260c7335e0859dbe634e \
|
| 127 |
+
--hash=sha256:9d3f5d6df7d91676e67a38a448253b74d77da723f8e24bd833ff7ed0f82fa4ef \
|
| 128 |
+
--hash=sha256:acf0d736a5c1ea6b94adf08b0a31670009b6e78dfe50a1b0bdabf2b0f7895dc0 \
|
| 129 |
+
--hash=sha256:ad221258fc5d1d446f2c1ce9a1bb54cc05ca2b208491d4eaa5af443f1c0f16a2 \
|
| 130 |
+
--hash=sha256:ba52ae64611773bec7fc664c29f9ea3e02c9e5c817693726b978ed1bdedd07f2 \
|
| 131 |
+
--hash=sha256:c6126556d529df73b676e023063388d551be3c0cb2d42a4ff5c4cfd44ef3e012 \
|
| 132 |
+
--hash=sha256:ef5f0b22646a94f95869001b40ab940468b1ae399d0ffd3bc73d5c43342a013a \
|
| 133 |
+
--hash=sha256:ef8dc4ab1ec807382a713e71e8493d1985930537c933273e3c0739f02183cedc \
|
| 134 |
+
--hash=sha256:efb16c593b2a5ada07b180580c7612617e84f4714ce86928ad54baefe71ef29d
|
| 135 |
+
s3prl==0.3.1 \
|
| 136 |
+
--hash=sha256:e497989b10d4e058b619cf3e7a547820fceb3fe18c14c566427eb7b8c770d62e
|
UniSpeech/downstreams/speaker_diarization/tmp/mix_0000496.wav
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/wav/dev_clean_2_ns3_beta2_500/100/mix_0000496.wav
|
UniSpeech/downstreams/speaker_diarization/utils/dataset.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*- #
|
| 2 |
+
"""*********************************************************************************************"""
|
| 3 |
+
# FileName [ dataset.py ]
|
| 4 |
+
# Synopsis [ the speaker diarization dataset ]
|
| 5 |
+
# Source [ Refactored from https://github.com/hitachi-speech/EEND ]
|
| 6 |
+
# Author [ Jiatong Shi ]
|
| 7 |
+
# Copyright [ Copyright(c), Johns Hopkins University ]
|
| 8 |
+
"""*********************************************************************************************"""
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
###############
|
| 12 |
+
# IMPORTATION #
|
| 13 |
+
###############
|
| 14 |
+
import io
|
| 15 |
+
import os
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
# -------------#
|
| 20 |
+
import numpy as np
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
import torch
|
| 23 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 24 |
+
|
| 25 |
+
# -------------#
|
| 26 |
+
from torch.utils.data.dataset import Dataset
|
| 27 |
+
# -------------#
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _count_frames(data_len, size, step):
|
| 31 |
+
# no padding at edges, last remaining samples are ignored
|
| 32 |
+
return int((data_len - size + step) / step)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _gen_frame_indices(data_length, size=2000, step=2000):
|
| 36 |
+
i = -1
|
| 37 |
+
for i in range(_count_frames(data_length, size, step)):
|
| 38 |
+
yield i * step, i * step + size
|
| 39 |
+
|
| 40 |
+
if i * step + size < data_length:
|
| 41 |
+
if data_length - (i + 1) * step > 0:
|
| 42 |
+
if i == -1:
|
| 43 |
+
yield (i + 1) * step, data_length
|
| 44 |
+
else:
|
| 45 |
+
yield data_length - size, data_length
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _gen_chunk_indices(data_len, chunk_size):
|
| 49 |
+
step = chunk_size
|
| 50 |
+
start = 0
|
| 51 |
+
while start < data_len:
|
| 52 |
+
end = min(data_len, start + chunk_size)
|
| 53 |
+
yield start, end
|
| 54 |
+
start += step
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
#######################
|
| 58 |
+
# Diarization Dataset #
|
| 59 |
+
#######################
|
| 60 |
+
class DiarizationDataset(Dataset):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
mode,
|
| 64 |
+
data_dir,
|
| 65 |
+
chunk_size=2000,
|
| 66 |
+
frame_shift=256,
|
| 67 |
+
sampling_rate=16000,
|
| 68 |
+
subsampling=1,
|
| 69 |
+
use_last_samples=True,
|
| 70 |
+
num_speakers=3,
|
| 71 |
+
filter_spk=False
|
| 72 |
+
):
|
| 73 |
+
super(DiarizationDataset, self).__init__()
|
| 74 |
+
|
| 75 |
+
self.mode = mode
|
| 76 |
+
self.data_dir = data_dir
|
| 77 |
+
self.chunk_size = chunk_size
|
| 78 |
+
self.frame_shift = frame_shift
|
| 79 |
+
self.subsampling = subsampling
|
| 80 |
+
self.n_speakers = num_speakers
|
| 81 |
+
self.chunk_indices = [] if mode != "test" else {}
|
| 82 |
+
|
| 83 |
+
self.data = KaldiData(self.data_dir)
|
| 84 |
+
self.all_speakers = sorted(self.data.spk2utt.keys())
|
| 85 |
+
self.all_n_speakers = len(self.all_speakers)
|
| 86 |
+
|
| 87 |
+
# make chunk indices: filepath, start_frame, end_frame
|
| 88 |
+
for rec in self.data.wavs:
|
| 89 |
+
data_len = int(self.data.reco2dur[rec] * sampling_rate / frame_shift)
|
| 90 |
+
data_len = int(data_len / self.subsampling)
|
| 91 |
+
if mode == "test":
|
| 92 |
+
self.chunk_indices[rec] = []
|
| 93 |
+
if mode != "test":
|
| 94 |
+
for st, ed in _gen_frame_indices(data_len, chunk_size, chunk_size):
|
| 95 |
+
self.chunk_indices.append(
|
| 96 |
+
(rec, st * self.subsampling, ed * self.subsampling)
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
for st, ed in _gen_chunk_indices(data_len, chunk_size):
|
| 100 |
+
self.chunk_indices[rec].append(
|
| 101 |
+
(rec, st, ed)
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
if mode != "test":
|
| 105 |
+
if filter_spk:
|
| 106 |
+
self.filter_spk()
|
| 107 |
+
print(len(self.chunk_indices), " chunks")
|
| 108 |
+
else:
|
| 109 |
+
self.rec_list = list(self.chunk_indices.keys())
|
| 110 |
+
print(len(self.rec_list), " recordings")
|
| 111 |
+
|
| 112 |
+
def __len__(self):
|
| 113 |
+
return (
|
| 114 |
+
len(self.rec_list)
|
| 115 |
+
if type(self.chunk_indices) == dict
|
| 116 |
+
else len(self.chunk_indices)
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
def filter_spk(self):
|
| 120 |
+
# filter the spk in spk2utt but will not be used in training
|
| 121 |
+
# i.e. the chunks contains more spk than self.n_speakers
|
| 122 |
+
occur_spk_set = set()
|
| 123 |
+
|
| 124 |
+
new_chunk_indices = [] # filter the chunk that more than self.num_speakers
|
| 125 |
+
for idx in range(self.__len__()):
|
| 126 |
+
rec, st, ed = self.chunk_indices[idx]
|
| 127 |
+
|
| 128 |
+
filtered_segments = self.data.segments[rec]
|
| 129 |
+
# all the speakers in this recording not the chunk
|
| 130 |
+
speakers = np.unique(
|
| 131 |
+
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
|
| 132 |
+
).tolist()
|
| 133 |
+
n_speakers = self.n_speakers
|
| 134 |
+
# we assume that in each chunk the speaker number is less or equal than self.n_speakers
|
| 135 |
+
# but the speaker number in the whole recording may exceed self.n_speakers
|
| 136 |
+
if self.n_speakers < len(speakers):
|
| 137 |
+
n_speakers = len(speakers)
|
| 138 |
+
|
| 139 |
+
# Y: (length,), T: (frame_num, n_speakers)
|
| 140 |
+
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
|
| 141 |
+
# the spk index exist in this chunk data
|
| 142 |
+
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
|
| 143 |
+
chunk_spk_num = np.sum(exist_spk_idx)
|
| 144 |
+
if chunk_spk_num <= self.n_speakers:
|
| 145 |
+
spk_arr = np.array(speakers)
|
| 146 |
+
valid_spk_arr = spk_arr[exist_spk_idx[:spk_arr.shape[0]]]
|
| 147 |
+
for spk in valid_spk_arr:
|
| 148 |
+
occur_spk_set.add(spk)
|
| 149 |
+
|
| 150 |
+
new_chunk_indices.append((rec, st, ed))
|
| 151 |
+
self.chunk_indices = new_chunk_indices
|
| 152 |
+
self.all_speakers = sorted(list(occur_spk_set))
|
| 153 |
+
self.all_n_speakers = len(self.all_speakers)
|
| 154 |
+
|
| 155 |
+
def __getitem__(self, i):
|
| 156 |
+
if self.mode != "test":
|
| 157 |
+
rec, st, ed = self.chunk_indices[i]
|
| 158 |
+
|
| 159 |
+
filtered_segments = self.data.segments[rec]
|
| 160 |
+
# all the speakers in this recording not the chunk
|
| 161 |
+
speakers = np.unique(
|
| 162 |
+
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
|
| 163 |
+
).tolist()
|
| 164 |
+
n_speakers = self.n_speakers
|
| 165 |
+
# we assume that in each chunk the speaker number is less or equal than self.n_speakers
|
| 166 |
+
# but the speaker number in the whole recording may exceed self.n_speakers
|
| 167 |
+
if self.n_speakers < len(speakers):
|
| 168 |
+
n_speakers = len(speakers)
|
| 169 |
+
|
| 170 |
+
# Y: (length,), T: (frame_num, n_speakers)
|
| 171 |
+
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
|
| 172 |
+
# the spk index exist in this chunk data
|
| 173 |
+
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
|
| 174 |
+
chunk_spk_num = np.sum(exist_spk_idx)
|
| 175 |
+
if chunk_spk_num > self.n_speakers:
|
| 176 |
+
# the speaker number in a chunk exceed our pre-set value
|
| 177 |
+
return None, None, None
|
| 178 |
+
|
| 179 |
+
# the map from within recording speaker index to global speaker index
|
| 180 |
+
S_arr = -1 * np.ones(n_speakers).astype(np.int64)
|
| 181 |
+
for seg in filtered_segments:
|
| 182 |
+
speaker_index = speakers.index(self.data.utt2spk[seg['utt']])
|
| 183 |
+
try:
|
| 184 |
+
all_speaker_index = self.all_speakers.index(
|
| 185 |
+
self.data.utt2spk[seg['utt']])
|
| 186 |
+
except:
|
| 187 |
+
# we have pre-filter some spk in self.filter_spk
|
| 188 |
+
all_speaker_index = -1
|
| 189 |
+
S_arr[speaker_index] = all_speaker_index
|
| 190 |
+
# If T[:, n_speakers - 1] == 0.0, then S_arr[n_speakers - 1] == -1,
|
| 191 |
+
# so S_arr[n_speakers - 1] is not used for training,
|
| 192 |
+
# e.g., in the case of training 3-spk model with 2-spk data
|
| 193 |
+
|
| 194 |
+
# filter the speaker not exist in this chunk and ensure there are self.num_speakers outputs
|
| 195 |
+
T_exist = T[:,exist_spk_idx]
|
| 196 |
+
T = np.zeros((T_exist.shape[0], self.n_speakers), dtype=np.int32)
|
| 197 |
+
T[:,:T_exist.shape[1]] = T_exist
|
| 198 |
+
# subsampling for Y will be done in the model forward function
|
| 199 |
+
T = T[::self.subsampling]
|
| 200 |
+
|
| 201 |
+
S_arr_exist = S_arr[exist_spk_idx]
|
| 202 |
+
S_arr = -1 * np.ones(self.n_speakers).astype(np.int64)
|
| 203 |
+
S_arr[:S_arr_exist.shape[0]] = S_arr_exist
|
| 204 |
+
|
| 205 |
+
n = np.arange(self.all_n_speakers, dtype=np.int64).reshape(self.all_n_speakers, 1)
|
| 206 |
+
return Y, T, S_arr, n, T.shape[0]
|
| 207 |
+
else:
|
| 208 |
+
len_ratio = self.frame_shift * self.subsampling
|
| 209 |
+
chunks = self.chunk_indices[self.rec_list[i]]
|
| 210 |
+
Ys = []
|
| 211 |
+
chunk_len_list = []
|
| 212 |
+
for (rec, st, ed) in chunks:
|
| 213 |
+
chunk_len = ed - st
|
| 214 |
+
if chunk_len != self.chunk_size:
|
| 215 |
+
st = max(0, ed - self.chunk_size)
|
| 216 |
+
Y, _ = self.data.load_wav(rec, st * len_ratio, ed * len_ratio)
|
| 217 |
+
Ys.append(Y)
|
| 218 |
+
chunk_len_list.append(chunk_len)
|
| 219 |
+
return Ys, self.rec_list[i], chunk_len_list
|
| 220 |
+
|
| 221 |
+
def get_allnspk(self):
|
| 222 |
+
return self.all_n_speakers
|
| 223 |
+
|
| 224 |
+
def _get_labeled_speech(
|
| 225 |
+
self, rec, start, end, n_speakers=None, use_speaker_id=False
|
| 226 |
+
):
|
| 227 |
+
"""Extracts speech chunks and corresponding labels
|
| 228 |
+
|
| 229 |
+
Extracts speech chunks and corresponding diarization labels for
|
| 230 |
+
given recording id and start/end times
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
rec (str): recording id
|
| 234 |
+
start (int): start frame index
|
| 235 |
+
end (int): end frame index
|
| 236 |
+
n_speakers (int): number of speakers
|
| 237 |
+
if None, the value is given from data
|
| 238 |
+
Returns:
|
| 239 |
+
data: speech chunk
|
| 240 |
+
(n_samples)
|
| 241 |
+
T: label
|
| 242 |
+
(n_frmaes, n_speakers)-shaped np.int32 array.
|
| 243 |
+
"""
|
| 244 |
+
data, rate = self.data.load_wav(
|
| 245 |
+
rec, start * self.frame_shift, end * self.frame_shift
|
| 246 |
+
)
|
| 247 |
+
frame_num = end - start
|
| 248 |
+
filtered_segments = self.data.segments[rec]
|
| 249 |
+
# filtered_segments = self.data.segments[self.data.segments['rec'] == rec]
|
| 250 |
+
speakers = np.unique(
|
| 251 |
+
[self.data.utt2spk[seg["utt"]] for seg in filtered_segments]
|
| 252 |
+
).tolist()
|
| 253 |
+
if n_speakers is None:
|
| 254 |
+
n_speakers = len(speakers)
|
| 255 |
+
T = np.zeros((frame_num, n_speakers), dtype=np.int32)
|
| 256 |
+
|
| 257 |
+
if use_speaker_id:
|
| 258 |
+
all_speakers = sorted(self.data.spk2utt.keys())
|
| 259 |
+
S = np.zeros((frame_num, len(all_speakers)), dtype=np.int32)
|
| 260 |
+
|
| 261 |
+
for seg in filtered_segments:
|
| 262 |
+
speaker_index = speakers.index(self.data.utt2spk[seg["utt"]])
|
| 263 |
+
if use_speaker_id:
|
| 264 |
+
all_speaker_index = all_speakers.index(self.data.utt2spk[seg["utt"]])
|
| 265 |
+
start_frame = np.rint(seg["st"] * rate / self.frame_shift).astype(int)
|
| 266 |
+
end_frame = np.rint(seg["et"] * rate / self.frame_shift).astype(int)
|
| 267 |
+
rel_start = rel_end = None
|
| 268 |
+
if start <= start_frame and start_frame < end:
|
| 269 |
+
rel_start = start_frame - start
|
| 270 |
+
if start < end_frame and end_frame <= end:
|
| 271 |
+
rel_end = end_frame - start
|
| 272 |
+
if rel_start is not None or rel_end is not None:
|
| 273 |
+
T[rel_start:rel_end, speaker_index] = 1
|
| 274 |
+
if use_speaker_id:
|
| 275 |
+
S[rel_start:rel_end, all_speaker_index] = 1
|
| 276 |
+
|
| 277 |
+
if use_speaker_id:
|
| 278 |
+
return data, T, S
|
| 279 |
+
else:
|
| 280 |
+
return data, T
|
| 281 |
+
|
| 282 |
+
def collate_fn(self, batch):
|
| 283 |
+
valid_samples = [sample for sample in batch if sample[0] is not None]
|
| 284 |
+
|
| 285 |
+
wav_list, binary_label_list, spk_label_list= [], [], []
|
| 286 |
+
all_spk_idx_list, len_list = [], []
|
| 287 |
+
for sample in valid_samples:
|
| 288 |
+
wav_list.append(torch.from_numpy(sample[0]).float())
|
| 289 |
+
binary_label_list.append(torch.from_numpy(sample[1]).long())
|
| 290 |
+
spk_label_list.append(torch.from_numpy(sample[2]).long())
|
| 291 |
+
all_spk_idx_list.append(torch.from_numpy(sample[3]).long())
|
| 292 |
+
len_list.append(sample[4])
|
| 293 |
+
wav_batch = pad_sequence(wav_list, batch_first=True, padding_value=0.0)
|
| 294 |
+
binary_label_batch = pad_sequence(binary_label_list, batch_first=True, padding_value=1).long()
|
| 295 |
+
spk_label_batch = torch.stack(spk_label_list)
|
| 296 |
+
all_spk_idx_batch = torch.stack(all_spk_idx_list)
|
| 297 |
+
len_batch = torch.LongTensor(len_list)
|
| 298 |
+
|
| 299 |
+
return wav_batch, binary_label_batch.float(), spk_label_batch, all_spk_idx_batch, len_batch
|
| 300 |
+
|
| 301 |
+
def collate_fn_infer(self, batch):
|
| 302 |
+
assert len(batch) == 1 # each batch should contain one recording
|
| 303 |
+
Ys, rec, chunk_len_list = batch[0]
|
| 304 |
+
wav_list = [torch.from_numpy(Y).float() for Y in Ys]
|
| 305 |
+
|
| 306 |
+
return wav_list, rec, chunk_len_list
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
#######################
|
| 310 |
+
# Kaldi-style Dataset #
|
| 311 |
+
#######################
|
| 312 |
+
class KaldiData:
|
| 313 |
+
"""This class holds data in kaldi-style directory."""
|
| 314 |
+
|
| 315 |
+
def __init__(self, data_dir):
|
| 316 |
+
"""Load kaldi data directory."""
|
| 317 |
+
self.data_dir = data_dir
|
| 318 |
+
self.segments = self._load_segments_rechash(
|
| 319 |
+
os.path.join(self.data_dir, "segments")
|
| 320 |
+
)
|
| 321 |
+
self.utt2spk = self._load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
|
| 322 |
+
self.wavs = self._load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
|
| 323 |
+
self.reco2dur = self._load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
|
| 324 |
+
self.spk2utt = self._load_spk2utt(os.path.join(self.data_dir, "spk2utt"))
|
| 325 |
+
|
| 326 |
+
def load_wav(self, recid, start=0, end=None):
|
| 327 |
+
"""Load wavfile given recid, start time and end time."""
|
| 328 |
+
data, rate = self._load_wav(self.wavs[recid], start, end)
|
| 329 |
+
return data, rate
|
| 330 |
+
|
| 331 |
+
def _load_segments(self, segments_file):
|
| 332 |
+
"""Load segments file as array."""
|
| 333 |
+
if not os.path.exists(segments_file):
|
| 334 |
+
return None
|
| 335 |
+
return np.loadtxt(
|
| 336 |
+
segments_file,
|
| 337 |
+
dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
|
| 338 |
+
ndmin=1,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
def _load_segments_hash(self, segments_file):
|
| 342 |
+
"""Load segments file as dict with uttid index."""
|
| 343 |
+
ret = {}
|
| 344 |
+
if not os.path.exists(segments_file):
|
| 345 |
+
return None
|
| 346 |
+
for line in open(segments_file):
|
| 347 |
+
utt, rec, st, et = line.strip().split()
|
| 348 |
+
ret[utt] = (rec, float(st), float(et))
|
| 349 |
+
return ret
|
| 350 |
+
|
| 351 |
+
def _load_segments_rechash(self, segments_file):
|
| 352 |
+
"""Load segments file as dict with recid index."""
|
| 353 |
+
ret = {}
|
| 354 |
+
if not os.path.exists(segments_file):
|
| 355 |
+
return None
|
| 356 |
+
for line in open(segments_file):
|
| 357 |
+
utt, rec, st, et = line.strip().split()
|
| 358 |
+
if rec not in ret:
|
| 359 |
+
ret[rec] = []
|
| 360 |
+
ret[rec].append({"utt": utt, "st": float(st), "et": float(et)})
|
| 361 |
+
return ret
|
| 362 |
+
|
| 363 |
+
def _load_wav_scp(self, wav_scp_file):
|
| 364 |
+
"""Return dictionary { rec: wav_rxfilename }."""
|
| 365 |
+
if os.path.exists(wav_scp_file):
|
| 366 |
+
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
|
| 367 |
+
return {x[0]: x[1] for x in lines}
|
| 368 |
+
else:
|
| 369 |
+
wav_dir = os.path.join(self.data_dir, "wav")
|
| 370 |
+
return {
|
| 371 |
+
os.path.splitext(filename)[0]: os.path.join(wav_dir, filename)
|
| 372 |
+
for filename in sorted(os.listdir(wav_dir))
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
def _load_wav(self, wav_rxfilename, start=0, end=None):
|
| 376 |
+
"""This function reads audio file and return data in numpy.float32 array.
|
| 377 |
+
"lru_cache" holds recently loaded audio so that can be called
|
| 378 |
+
many times on the same audio file.
|
| 379 |
+
OPTIMIZE: controls lru_cache size for random access,
|
| 380 |
+
considering memory size
|
| 381 |
+
"""
|
| 382 |
+
if wav_rxfilename.endswith("|"):
|
| 383 |
+
# input piped command
|
| 384 |
+
p = subprocess.Popen(
|
| 385 |
+
wav_rxfilename[:-1],
|
| 386 |
+
shell=True,
|
| 387 |
+
stdout=subprocess.PIPE,
|
| 388 |
+
)
|
| 389 |
+
data, samplerate = sf.read(
|
| 390 |
+
io.BytesIO(p.stdout.read()),
|
| 391 |
+
dtype="float32",
|
| 392 |
+
)
|
| 393 |
+
# cannot seek
|
| 394 |
+
data = data[start:end]
|
| 395 |
+
elif wav_rxfilename == "-":
|
| 396 |
+
# stdin
|
| 397 |
+
data, samplerate = sf.read(sys.stdin, dtype="float32")
|
| 398 |
+
# cannot seek
|
| 399 |
+
data = data[start:end]
|
| 400 |
+
else:
|
| 401 |
+
# normal wav file
|
| 402 |
+
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
|
| 403 |
+
return data, samplerate
|
| 404 |
+
|
| 405 |
+
def _load_utt2spk(self, utt2spk_file):
|
| 406 |
+
"""Returns dictionary { uttid: spkid }."""
|
| 407 |
+
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
|
| 408 |
+
return {x[0]: x[1] for x in lines}
|
| 409 |
+
|
| 410 |
+
def _load_spk2utt(self, spk2utt_file):
|
| 411 |
+
"""Returns dictionary { spkid: list of uttids }."""
|
| 412 |
+
if not os.path.exists(spk2utt_file):
|
| 413 |
+
return None
|
| 414 |
+
lines = [line.strip().split() for line in open(spk2utt_file)]
|
| 415 |
+
return {x[0]: x[1:] for x in lines}
|
| 416 |
+
|
| 417 |
+
def _load_reco2dur(self, reco2dur_file):
|
| 418 |
+
"""Returns dictionary { recid: duration }."""
|
| 419 |
+
if not os.path.exists(reco2dur_file):
|
| 420 |
+
return None
|
| 421 |
+
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
|
| 422 |
+
return {x[0]: float(x[1]) for x in lines}
|
| 423 |
+
|
| 424 |
+
def _process_wav(self, wav_rxfilename, process):
|
| 425 |
+
"""This function returns preprocessed wav_rxfilename.
|
| 426 |
+
Args:
|
| 427 |
+
wav_rxfilename:
|
| 428 |
+
input
|
| 429 |
+
process:
|
| 430 |
+
command which can be connected via pipe, use stdin and stdout
|
| 431 |
+
Returns:
|
| 432 |
+
wav_rxfilename: output piped command
|
| 433 |
+
"""
|
| 434 |
+
if wav_rxfilename.endswith("|"):
|
| 435 |
+
# input piped command
|
| 436 |
+
return wav_rxfilename + process + "|"
|
| 437 |
+
# stdin "-" or normal file
|
| 438 |
+
return "cat {0} | {1} |".format(wav_rxfilename, process)
|
| 439 |
+
|
| 440 |
+
def _extract_segments(self, wavs, segments=None):
|
| 441 |
+
"""This function returns generator of segmented audio.
|
| 442 |
+
Yields (utterance id, numpy.float32 array).
|
| 443 |
+
TODO?: sampling rate is not converted.
|
| 444 |
+
"""
|
| 445 |
+
if segments is not None:
|
| 446 |
+
# segments should be sorted by rec-id
|
| 447 |
+
for seg in segments:
|
| 448 |
+
wav = wavs[seg["rec"]]
|
| 449 |
+
data, samplerate = self.load_wav(wav)
|
| 450 |
+
st_sample = np.rint(seg["st"] * samplerate).astype(int)
|
| 451 |
+
et_sample = np.rint(seg["et"] * samplerate).astype(int)
|
| 452 |
+
yield seg["utt"], data[st_sample:et_sample]
|
| 453 |
+
else:
|
| 454 |
+
# segments file not found,
|
| 455 |
+
# wav.scp is used as segmented audio list
|
| 456 |
+
for rec in wavs:
|
| 457 |
+
data, samplerate = self.load_wav(wavs[rec])
|
| 458 |
+
yield rec, data
|
| 459 |
+
|
| 460 |
+
if __name__ == "__main__":
|
| 461 |
+
args = {
|
| 462 |
+
'mode': 'train',
|
| 463 |
+
'data_dir': "/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/data/train_clean_5_ns3_beta2_500",
|
| 464 |
+
'chunk_size': 2001,
|
| 465 |
+
'frame_shift': 256,
|
| 466 |
+
'sampling_rate': 8000,
|
| 467 |
+
'num_speakers':3
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
torch.manual_seed(6)
|
| 471 |
+
dataset = DiarizationDataset(**args)
|
| 472 |
+
|
| 473 |
+
from torch.utils.data import DataLoader
|
| 474 |
+
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)
|
| 475 |
+
data_iter = iter(dataloader)
|
| 476 |
+
# wav_batch, binary_label_batch, spk_label_batch, all_spk_idx_batch, len_batch = next(data_iter)
|
| 477 |
+
data = next(data_iter)
|
| 478 |
+
for val in data:
|
| 479 |
+
print(val.shape)
|
| 480 |
+
|
| 481 |
+
# from torch.utils.data import DataLoader
|
| 482 |
+
# dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn_infer)
|
| 483 |
+
# data_iter = iter(dataloader)
|
| 484 |
+
# wav_list, binary_label_list, rec = next(data_iter)
|
UniSpeech/downstreams/speaker_diarization/utils/kaldi_data.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
#
|
| 4 |
+
# This library provides utilities for kaldi-style data directory.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from __future__ import print_function
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import numpy as np
|
| 11 |
+
import subprocess
|
| 12 |
+
import soundfile as sf
|
| 13 |
+
import io
|
| 14 |
+
from functools import lru_cache
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_segments(segments_file):
|
| 18 |
+
""" load segments file as array """
|
| 19 |
+
if not os.path.exists(segments_file):
|
| 20 |
+
return None
|
| 21 |
+
return np.loadtxt(
|
| 22 |
+
segments_file,
|
| 23 |
+
dtype=[('utt', 'object'),
|
| 24 |
+
('rec', 'object'),
|
| 25 |
+
('st', 'f'),
|
| 26 |
+
('et', 'f')],
|
| 27 |
+
ndmin=1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_segments_hash(segments_file):
|
| 31 |
+
ret = {}
|
| 32 |
+
if not os.path.exists(segments_file):
|
| 33 |
+
return None
|
| 34 |
+
for line in open(segments_file):
|
| 35 |
+
utt, rec, st, et = line.strip().split()
|
| 36 |
+
ret[utt] = (rec, float(st), float(et))
|
| 37 |
+
return ret
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_segments_rechash(segments_file):
|
| 41 |
+
ret = {}
|
| 42 |
+
if not os.path.exists(segments_file):
|
| 43 |
+
return None
|
| 44 |
+
for line in open(segments_file):
|
| 45 |
+
utt, rec, st, et = line.strip().split()
|
| 46 |
+
if rec not in ret:
|
| 47 |
+
ret[rec] = []
|
| 48 |
+
ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
|
| 49 |
+
return ret
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_wav_scp(wav_scp_file):
|
| 53 |
+
""" return dictionary { rec: wav_rxfilename } """
|
| 54 |
+
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
|
| 55 |
+
return {x[0]: x[1] for x in lines}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@lru_cache(maxsize=1)
|
| 59 |
+
def load_wav(wav_rxfilename, start=0, end=None):
|
| 60 |
+
""" This function reads audio file and return data in numpy.float32 array.
|
| 61 |
+
"lru_cache" holds recently loaded audio so that can be called
|
| 62 |
+
many times on the same audio file.
|
| 63 |
+
OPTIMIZE: controls lru_cache size for random access,
|
| 64 |
+
considering memory size
|
| 65 |
+
"""
|
| 66 |
+
if wav_rxfilename.endswith('|'):
|
| 67 |
+
# input piped command
|
| 68 |
+
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
|
| 69 |
+
stdout=subprocess.PIPE)
|
| 70 |
+
data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
|
| 71 |
+
dtype='float32')
|
| 72 |
+
# cannot seek
|
| 73 |
+
data = data[start:end]
|
| 74 |
+
elif wav_rxfilename == '-':
|
| 75 |
+
# stdin
|
| 76 |
+
data, samplerate = sf.read(sys.stdin, dtype='float32')
|
| 77 |
+
# cannot seek
|
| 78 |
+
data = data[start:end]
|
| 79 |
+
else:
|
| 80 |
+
# normal wav file
|
| 81 |
+
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
|
| 82 |
+
return data, samplerate
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_utt2spk(utt2spk_file):
|
| 86 |
+
""" returns dictionary { uttid: spkid } """
|
| 87 |
+
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
|
| 88 |
+
return {x[0]: x[1] for x in lines}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def load_spk2utt(spk2utt_file):
|
| 92 |
+
""" returns dictionary { spkid: list of uttids } """
|
| 93 |
+
if not os.path.exists(spk2utt_file):
|
| 94 |
+
return None
|
| 95 |
+
lines = [line.strip().split() for line in open(spk2utt_file)]
|
| 96 |
+
return {x[0]: x[1:] for x in lines}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_reco2dur(reco2dur_file):
|
| 100 |
+
""" returns dictionary { recid: duration } """
|
| 101 |
+
if not os.path.exists(reco2dur_file):
|
| 102 |
+
return None
|
| 103 |
+
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
|
| 104 |
+
return {x[0]: float(x[1]) for x in lines}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def process_wav(wav_rxfilename, process):
|
| 108 |
+
""" This function returns preprocessed wav_rxfilename
|
| 109 |
+
Args:
|
| 110 |
+
wav_rxfilename: input
|
| 111 |
+
process: command which can be connected via pipe,
|
| 112 |
+
use stdin and stdout
|
| 113 |
+
Returns:
|
| 114 |
+
wav_rxfilename: output piped command
|
| 115 |
+
"""
|
| 116 |
+
if wav_rxfilename.endswith('|'):
|
| 117 |
+
# input piped command
|
| 118 |
+
return wav_rxfilename + process + "|"
|
| 119 |
+
else:
|
| 120 |
+
# stdin "-" or normal file
|
| 121 |
+
return "cat {} | {} |".format(wav_rxfilename, process)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def extract_segments(wavs, segments=None):
|
| 125 |
+
""" This function returns generator of segmented audio as
|
| 126 |
+
(utterance id, numpy.float32 array)
|
| 127 |
+
TODO?: sampling rate is not converted.
|
| 128 |
+
"""
|
| 129 |
+
if segments is not None:
|
| 130 |
+
# segments should be sorted by rec-id
|
| 131 |
+
for seg in segments:
|
| 132 |
+
wav = wavs[seg['rec']]
|
| 133 |
+
data, samplerate = load_wav(wav)
|
| 134 |
+
st_sample = np.rint(seg['st'] * samplerate).astype(int)
|
| 135 |
+
et_sample = np.rint(seg['et'] * samplerate).astype(int)
|
| 136 |
+
yield seg['utt'], data[st_sample:et_sample]
|
| 137 |
+
else:
|
| 138 |
+
# segments file not found,
|
| 139 |
+
# wav.scp is used as segmented audio list
|
| 140 |
+
for rec in wavs:
|
| 141 |
+
data, samplerate = load_wav(wavs[rec])
|
| 142 |
+
yield rec, data
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class KaldiData:
|
| 146 |
+
def __init__(self, data_dir):
|
| 147 |
+
self.data_dir = data_dir
|
| 148 |
+
self.segments = load_segments_rechash(
|
| 149 |
+
os.path.join(self.data_dir, 'segments'))
|
| 150 |
+
self.utt2spk = load_utt2spk(
|
| 151 |
+
os.path.join(self.data_dir, 'utt2spk'))
|
| 152 |
+
self.wavs = load_wav_scp(
|
| 153 |
+
os.path.join(self.data_dir, 'wav.scp'))
|
| 154 |
+
self.reco2dur = load_reco2dur(
|
| 155 |
+
os.path.join(self.data_dir, 'reco2dur'))
|
| 156 |
+
self.spk2utt = load_spk2utt(
|
| 157 |
+
os.path.join(self.data_dir, 'spk2utt'))
|
| 158 |
+
|
| 159 |
+
def load_wav(self, recid, start=0, end=None):
|
| 160 |
+
data, rate = load_wav(
|
| 161 |
+
self.wavs[recid], start, end)
|
| 162 |
+
return data, rate
|
UniSpeech/downstreams/speaker_diarization/utils/parse_options.sh
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
| 4 |
+
# Arnab Ghoshal, Karel Vesely
|
| 5 |
+
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
| 13 |
+
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
| 14 |
+
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
| 15 |
+
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
| 16 |
+
# See the Apache 2 License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Parse command-line options.
|
| 21 |
+
# To be sourced by another script (as in ". parse_options.sh").
|
| 22 |
+
# Option format is: --option-name arg
|
| 23 |
+
# and shell variable "option_name" gets set to value "arg."
|
| 24 |
+
# The exception is --help, which takes no arguments, but prints the
|
| 25 |
+
# $help_message variable (if defined).
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
###
|
| 29 |
+
### The --config file options have lower priority to command line
|
| 30 |
+
### options, so we need to import them first...
|
| 31 |
+
###
|
| 32 |
+
|
| 33 |
+
# Now import all the configs specified by command-line, in left-to-right order
|
| 34 |
+
for ((argpos=1; argpos<$#; argpos++)); do
|
| 35 |
+
if [ "${!argpos}" == "--config" ]; then
|
| 36 |
+
argpos_plus1=$((argpos+1))
|
| 37 |
+
config=${!argpos_plus1}
|
| 38 |
+
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
| 39 |
+
. $config # source the config file.
|
| 40 |
+
fi
|
| 41 |
+
done
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
###
|
| 45 |
+
### Now we process the command line options
|
| 46 |
+
###
|
| 47 |
+
while true; do
|
| 48 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
| 49 |
+
case "$1" in
|
| 50 |
+
# If the enclosing script is called with --help option, print the help
|
| 51 |
+
# message and exit. Scripts should put help messages in $help_message
|
| 52 |
+
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
| 53 |
+
else printf "$help_message\n" 1>&2 ; fi;
|
| 54 |
+
exit 0 ;;
|
| 55 |
+
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
| 56 |
+
exit 1 ;;
|
| 57 |
+
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
| 58 |
+
# then work out the variable name as $name, which will equal "foo_bar".
|
| 59 |
+
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
| 60 |
+
# Next we test whether the variable in question is undefned-- if so it's
|
| 61 |
+
# an invalid option and we die. Note: $0 evaluates to the name of the
|
| 62 |
+
# enclosing script.
|
| 63 |
+
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
| 64 |
+
# is undefined. We then have to wrap this test inside "eval" because
|
| 65 |
+
# foo_bar is itself inside a variable ($name).
|
| 66 |
+
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
| 67 |
+
|
| 68 |
+
oldval="`eval echo \\$$name`";
|
| 69 |
+
# Work out whether we seem to be expecting a Boolean argument.
|
| 70 |
+
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
| 71 |
+
was_bool=true;
|
| 72 |
+
else
|
| 73 |
+
was_bool=false;
|
| 74 |
+
fi
|
| 75 |
+
|
| 76 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
| 77 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
| 78 |
+
eval $name=\"$2\";
|
| 79 |
+
|
| 80 |
+
# Check that Boolean-valued arguments are really Boolean.
|
| 81 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
| 82 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
| 83 |
+
exit 1;
|
| 84 |
+
fi
|
| 85 |
+
shift 2;
|
| 86 |
+
;;
|
| 87 |
+
*) break;
|
| 88 |
+
esac
|
| 89 |
+
done
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Check for an empty argument to the --cmd option, which can easily occur as a
|
| 93 |
+
# result of scripting errors.
|
| 94 |
+
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
true; # so this script returns exit code 0.
|
UniSpeech/downstreams/speaker_diarization/utils/utils.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import struct
|
| 3 |
+
import logging
|
| 4 |
+
import torch
|
| 5 |
+
import math
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
import yaml
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ------------------------------ Logger ------------------------------
|
| 14 |
+
# log to console or a file
|
| 15 |
+
def get_logger(
|
| 16 |
+
name,
|
| 17 |
+
format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
|
| 18 |
+
date_format="%Y-%m-%d %H:%M:%S",
|
| 19 |
+
file=False):
|
| 20 |
+
"""
|
| 21 |
+
Get python logger instance
|
| 22 |
+
"""
|
| 23 |
+
logger = logging.getLogger(name)
|
| 24 |
+
logger.setLevel(logging.INFO)
|
| 25 |
+
# file or console
|
| 26 |
+
handler = logging.StreamHandler() if not file else logging.FileHandler(
|
| 27 |
+
name)
|
| 28 |
+
handler.setLevel(logging.INFO)
|
| 29 |
+
formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
|
| 30 |
+
handler.setFormatter(formatter)
|
| 31 |
+
logger.addHandler(handler)
|
| 32 |
+
return logger
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# log to concole and file at the same time
|
| 36 |
+
def get_logger_2(
|
| 37 |
+
name,
|
| 38 |
+
format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
|
| 39 |
+
date_format="%Y-%m-%d %H:%M:%S"):
|
| 40 |
+
logger = logging.getLogger(name)
|
| 41 |
+
logger.setLevel(logging.INFO)
|
| 42 |
+
|
| 43 |
+
# Create handlers
|
| 44 |
+
c_handler = logging.StreamHandler()
|
| 45 |
+
f_handler = logging.FileHandler(name)
|
| 46 |
+
c_handler.setLevel(logging.INFO)
|
| 47 |
+
f_handler.setLevel(logging.INFO)
|
| 48 |
+
|
| 49 |
+
# Create formatters and add it to handlers
|
| 50 |
+
c_format = logging.Formatter(fmt=format_str, datefmt=date_format)
|
| 51 |
+
f_format = logging.Formatter(fmt=format_str, datefmt=date_format)
|
| 52 |
+
c_handler.setFormatter(c_format)
|
| 53 |
+
f_handler.setFormatter(f_format)
|
| 54 |
+
|
| 55 |
+
# Add handlers to the logger
|
| 56 |
+
logger.addHandler(c_handler)
|
| 57 |
+
logger.addHandler(f_handler)
|
| 58 |
+
|
| 59 |
+
return logger
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ------------------------------ Logger ------------------------------
|
| 63 |
+
|
| 64 |
+
# ------------------------------ Pytorch Distributed Training ------------------------------
|
| 65 |
+
def getoneNode():
|
| 66 |
+
nodelist = os.environ['SLURM_JOB_NODELIST']
|
| 67 |
+
nodelist = nodelist.strip().split(',')[0]
|
| 68 |
+
import re
|
| 69 |
+
text = re.split('[-\[\]]', nodelist)
|
| 70 |
+
if ('' in text):
|
| 71 |
+
text.remove('')
|
| 72 |
+
return text[0] + '-' + text[1] + '-' + text[2]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def dist_init(host_addr, rank, local_rank, world_size, port=23456):
|
| 76 |
+
host_addr_full = 'tcp://' + host_addr + ':' + str(port)
|
| 77 |
+
dist.init_process_group("nccl", init_method=host_addr_full,
|
| 78 |
+
rank=rank, world_size=world_size)
|
| 79 |
+
num_gpus = torch.cuda.device_count()
|
| 80 |
+
# torch.cuda.set_device(local_rank)
|
| 81 |
+
assert dist.is_initialized()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def cleanup():
|
| 85 |
+
dist.destroy_process_group()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def average_gradients(model, world_size):
|
| 89 |
+
size = float(world_size)
|
| 90 |
+
for param in model.parameters():
|
| 91 |
+
if (param.requires_grad and param.grad is not None):
|
| 92 |
+
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
|
| 93 |
+
param.grad.data /= size
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def data_reduce(data):
|
| 97 |
+
dist.all_reduce(data, op=dist.ReduceOp.SUM)
|
| 98 |
+
return data / torch.distributed.get_world_size()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ------------------------------ Pytorch Distributed Training ------------------------------
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
|
| 105 |
+
def reduce_lr(optimizer, initial_lr, final_lr, current_iter, max_iter, coeff=1.0):
|
| 106 |
+
current_lr = coeff * math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
|
| 107 |
+
for param_group in optimizer.param_groups:
|
| 108 |
+
param_group['lr'] = current_lr
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_reduce_lr(initial_lr, final_lr, current_iter, max_iter):
|
| 112 |
+
current_lr = math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
|
| 113 |
+
return current_lr
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def set_lr(optimizer, lr):
|
| 117 |
+
for param_group in optimizer.param_groups:
|
| 118 |
+
param_group['lr'] = lr
|
| 119 |
+
|
| 120 |
+
# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
|
| 121 |
+
|
| 122 |
+
# ---------------------- About Configuration --------------------
|
| 123 |
+
def parse_config_or_kwargs(config_file, **kwargs):
|
| 124 |
+
with open(config_file) as con_read:
|
| 125 |
+
yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
|
| 126 |
+
# passed kwargs will override yaml config
|
| 127 |
+
return dict(yaml_config, **kwargs)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def store_yaml(config_file, store_path, **kwargs):
|
| 131 |
+
with open(config_file, 'r') as f:
|
| 132 |
+
config_lines = f.readlines()
|
| 133 |
+
|
| 134 |
+
keys_list = list(kwargs.keys())
|
| 135 |
+
with open(store_path, 'w') as f:
|
| 136 |
+
for line in config_lines:
|
| 137 |
+
if ':' in line and line.split(':')[0] in keys_list:
|
| 138 |
+
key = line.split(':')[0]
|
| 139 |
+
line = '{}: {}\n'.format(key, kwargs[key])
|
| 140 |
+
f.write(line)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ---------------------- About Configuration --------------------
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def check_dir(dir):
|
| 147 |
+
if not os.path.exists(dir):
|
| 148 |
+
os.mkdir(dir)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def set_seed(seed=66):
|
| 152 |
+
np.random.seed(seed)
|
| 153 |
+
random.seed(seed)
|
| 154 |
+
|
| 155 |
+
torch.manual_seed(seed)
|
| 156 |
+
torch.cuda.manual_seed(seed)
|
| 157 |
+
torch.cuda.manual_seed_all(seed)
|
| 158 |
+
|
| 159 |
+
# torch.backends.cudnn.deterministic = True
|
| 160 |
+
# torch.backends.cudnn.benchmark = False
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# when store the model wrongly with "module" involved,
|
| 164 |
+
# we remove it here
|
| 165 |
+
def correct_key(state_dict):
|
| 166 |
+
keys = list(state_dict.keys())
|
| 167 |
+
if 'module' not in keys[0]:
|
| 168 |
+
return state_dict
|
| 169 |
+
else:
|
| 170 |
+
new_state_dict = {}
|
| 171 |
+
for key in keys:
|
| 172 |
+
new_key = '.'.join(key.split('.')[1:])
|
| 173 |
+
new_state_dict[new_key] = state_dict[key]
|
| 174 |
+
return new_state_dict
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def validate_path(dir_name):
|
| 178 |
+
"""
|
| 179 |
+
:param dir_name: Create the directory if it doesn't exist
|
| 180 |
+
:return: None
|
| 181 |
+
"""
|
| 182 |
+
dir_name = os.path.dirname(dir_name) # get the path
|
| 183 |
+
if not os.path.exists(dir_name) and (dir_name != ''):
|
| 184 |
+
os.makedirs(dir_name)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_lr(optimizer):
|
| 188 |
+
for param_group in optimizer.param_groups:
|
| 189 |
+
return param_group['lr']
|
UniSpeech/downstreams/speaker_verification/README.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Pre-training Representations for Speaker Verification
|
| 2 |
+
|
| 3 |
+
### Pre-trained models
|
| 4 |
+
|
| 5 |
+
| Model | Fix pre-train | Vox1-O | Vox1-E | Vox1-H |
|
| 6 |
+
| ------------------------------------------------------------ | ------------- | --------- | --------- | -------- |
|
| 7 |
+
| [ECAPA-TDNN](https://drive.google.com/file/d/1kWmLyTGkBExTdxtwmrXoP4DhWz_7ZAv3/view?usp=sharing) | - | 1.080 | 1.200 | 2.127 |
|
| 8 |
+
| [HuBERT large](https://1drv.ms/u/s!AqeByhGUtINrgcsDGB2GnI61eMcCWA?e=olWmeG) | Yes | 0.888 | 0.912 | 1.853 |
|
| 9 |
+
| [Wav2Vec2.0 (XLSR)](https://1drv.ms/u/s!AqeByhGUtINrgcsCVGF9JpPlxNkVKg?e=KknWbW) | Yes | 0.915 | 0.945 | 1.895 |
|
| 10 |
+
| [UniSpeech-SAT large](https://1drv.ms/u/s!AqeByhGUtINrgcsEXX-_TS5VBxIF-Q?e=nRfkcT) | Yes | 0.771 | 0.781 | 1.669 |
|
| 11 |
+
| [WavLM Base](https://1drv.ms/u/s!AqeByhGUtINrgcsBsAf8EJT7x-HMhA?e=gXjLRD) | Yes | 0.84 | 0.928 | 1.758 |
|
| 12 |
+
| [**WavLM large**](https://1drv.ms/u/s!AqeByhGUtINrgcp_7CsbcBjYW2Tr-w?e=VeCMic) | Yes | 0.75 | 0.764 | 1.548 |
|
| 13 |
+
| [HuBERT large](https://drive.google.com/file/d/1nit9Z6RyM8Sdb3n8ccaglOQVNnqsjnui/view?usp=sharing) | No | 0.585 | 0.654 | 1.342 |
|
| 14 |
+
| [Wav2Vec2.0 (XLSR)](https://drive.google.com/file/d/1TgKro9pp197TCgIF__IlE_rMVQOk50Eb/view?usp=sharing) | No | 0.564 | 0.605 | 1.23 |
|
| 15 |
+
| [UniSpeech-SAT large](https://drive.google.com/file/d/10o6NHZsPXJn2k8n57e8Z_FkKh3V4TC3g/view?usp=sharing) | No | 0.564 | 0.561 | 1.23 |
|
| 16 |
+
| [**WavLM large**](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view?usp=sharing) | No | **0.431** | **0.538** | **1.154** |
|
| 17 |
+
|
| 18 |
+
### How to use?
|
| 19 |
+
|
| 20 |
+
#### Environment Setup
|
| 21 |
+
|
| 22 |
+
1. `pip install --require-hashes -r requirements.txt`
|
| 23 |
+
2. Install fairseq code
|
| 24 |
+
- For HuBERT_Large and Wav2Vec2.0 (XLSR), we should install the official [fairseq](https://github.com/pytorch/fairseq).
|
| 25 |
+
- For UniSpeech-SAT large, we should install the [Unispeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code.
|
| 26 |
+
- For WavLM, we should install the latest s3prl: `pip install s3prl@git+https://github.com/s3prl/s3prl.git@7ab62aaf2606d83da6c71ee74e7d16e0979edbc3#egg=s3prl`
|
| 27 |
+
|
| 28 |
+
#### Example
|
| 29 |
+
|
| 30 |
+
Take `unispeech_sat ` and `ecapa_tdnn` for example:
|
| 31 |
+
|
| 32 |
+
1. First, you should download the pre-trained model in the above table to `checkpoint_path`.
|
| 33 |
+
2. Then, run the following codes:
|
| 34 |
+
- The wav files are sampled from [voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html).
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
python verification.py --model_name unispeech_sat --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav --checkpoint $checkpoint_path
|
| 38 |
+
# output: The similarity score between two audios is 0.0317 (-1.0, 1.0).
|
| 39 |
+
|
| 40 |
+
python verification.py --model_name unispeech_sat --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav --checkpoint --checkpoint $checkpoint_path
|
| 41 |
+
# output: The similarity score between two audios is 0.5389 (-1.0, 1.0).
|
| 42 |
+
|
| 43 |
+
python verification.py --model_name ecapa_tdnn --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav --checkpoint $checkpoint_path
|
| 44 |
+
# output: The similarity score between two audios is 0.2053 (-1.0, 1.0).
|
| 45 |
+
|
| 46 |
+
python verification.py --model_name ecapa_tdnn --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav --checkpoint --checkpoint $checkpoint_path
|
| 47 |
+
# output: he similarity score between two audios is 0.5302 (-1.0, 1.0).
|
| 48 |
+
```
|
| 49 |
+
|
UniSpeech/downstreams/speaker_verification/config/unispeech_sat.th
ADDED
|
Binary file (20.7 kB). View file
|
|
|
UniSpeech/downstreams/speaker_verification/models/__init__.py
ADDED
|
File without changes
|
UniSpeech/downstreams/speaker_verification/models/ecapa_tdnn.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torchaudio.transforms as trans
|
| 7 |
+
from .utils import UpstreamExpert
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
''' Res2Conv1d + BatchNorm1d + ReLU
|
| 11 |
+
'''
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Res2Conv1dReluBn(nn.Module):
|
| 15 |
+
'''
|
| 16 |
+
in_channels == out_channels == channels
|
| 17 |
+
'''
|
| 18 |
+
|
| 19 |
+
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
|
| 20 |
+
super().__init__()
|
| 21 |
+
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
| 22 |
+
self.scale = scale
|
| 23 |
+
self.width = channels // scale
|
| 24 |
+
self.nums = scale if scale == 1 else scale - 1
|
| 25 |
+
|
| 26 |
+
self.convs = []
|
| 27 |
+
self.bns = []
|
| 28 |
+
for i in range(self.nums):
|
| 29 |
+
self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
|
| 30 |
+
self.bns.append(nn.BatchNorm1d(self.width))
|
| 31 |
+
self.convs = nn.ModuleList(self.convs)
|
| 32 |
+
self.bns = nn.ModuleList(self.bns)
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
out = []
|
| 36 |
+
spx = torch.split(x, self.width, 1)
|
| 37 |
+
for i in range(self.nums):
|
| 38 |
+
if i == 0:
|
| 39 |
+
sp = spx[i]
|
| 40 |
+
else:
|
| 41 |
+
sp = sp + spx[i]
|
| 42 |
+
# Order: conv -> relu -> bn
|
| 43 |
+
sp = self.convs[i](sp)
|
| 44 |
+
sp = self.bns[i](F.relu(sp))
|
| 45 |
+
out.append(sp)
|
| 46 |
+
if self.scale != 1:
|
| 47 |
+
out.append(spx[self.nums])
|
| 48 |
+
out = torch.cat(out, dim=1)
|
| 49 |
+
|
| 50 |
+
return out
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
''' Conv1d + BatchNorm1d + ReLU
|
| 54 |
+
'''
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Conv1dReluBn(nn.Module):
|
| 58 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
| 61 |
+
self.bn = nn.BatchNorm1d(out_channels)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
return self.bn(F.relu(self.conv(x)))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
''' The SE connection of 1D case.
|
| 68 |
+
'''
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SE_Connect(nn.Module):
|
| 72 |
+
def __init__(self, channels, se_bottleneck_dim=128):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
| 75 |
+
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
| 76 |
+
|
| 77 |
+
def forward(self, x):
|
| 78 |
+
out = x.mean(dim=2)
|
| 79 |
+
out = F.relu(self.linear1(out))
|
| 80 |
+
out = torch.sigmoid(self.linear2(out))
|
| 81 |
+
out = x * out.unsqueeze(2)
|
| 82 |
+
|
| 83 |
+
return out
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
''' SE-Res2Block of the ECAPA-TDNN architecture.
|
| 87 |
+
'''
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
| 91 |
+
# return nn.Sequential(
|
| 92 |
+
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
| 93 |
+
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
| 94 |
+
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
| 95 |
+
# SE_Connect(channels)
|
| 96 |
+
# )
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class SE_Res2Block(nn.Module):
|
| 100 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 103 |
+
self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
|
| 104 |
+
self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 105 |
+
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
|
| 106 |
+
|
| 107 |
+
self.shortcut = None
|
| 108 |
+
if in_channels != out_channels:
|
| 109 |
+
self.shortcut = nn.Conv1d(
|
| 110 |
+
in_channels=in_channels,
|
| 111 |
+
out_channels=out_channels,
|
| 112 |
+
kernel_size=1,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def forward(self, x):
|
| 116 |
+
residual = x
|
| 117 |
+
if self.shortcut:
|
| 118 |
+
residual = self.shortcut(x)
|
| 119 |
+
|
| 120 |
+
x = self.Conv1dReluBn1(x)
|
| 121 |
+
x = self.Res2Conv1dReluBn(x)
|
| 122 |
+
x = self.Conv1dReluBn2(x)
|
| 123 |
+
x = self.SE_Connect(x)
|
| 124 |
+
|
| 125 |
+
return x + residual
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
''' Attentive weighted mean and standard deviation pooling.
|
| 129 |
+
'''
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class AttentiveStatsPool(nn.Module):
|
| 133 |
+
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.global_context_att = global_context_att
|
| 136 |
+
|
| 137 |
+
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
|
| 138 |
+
if global_context_att:
|
| 139 |
+
self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
|
| 140 |
+
else:
|
| 141 |
+
self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
|
| 142 |
+
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
|
| 143 |
+
|
| 144 |
+
def forward(self, x):
|
| 145 |
+
|
| 146 |
+
if self.global_context_att:
|
| 147 |
+
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
| 148 |
+
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
| 149 |
+
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
| 150 |
+
else:
|
| 151 |
+
x_in = x
|
| 152 |
+
|
| 153 |
+
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
| 154 |
+
alpha = torch.tanh(self.linear1(x_in))
|
| 155 |
+
# alpha = F.relu(self.linear1(x_in))
|
| 156 |
+
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
| 157 |
+
mean = torch.sum(alpha * x, dim=2)
|
| 158 |
+
residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
|
| 159 |
+
std = torch.sqrt(residuals.clamp(min=1e-9))
|
| 160 |
+
return torch.cat([mean, std], dim=1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class ECAPA_TDNN(nn.Module):
|
| 164 |
+
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
|
| 165 |
+
feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
|
| 166 |
+
super().__init__()
|
| 167 |
+
|
| 168 |
+
self.feat_type = feat_type
|
| 169 |
+
self.feature_selection = feature_selection
|
| 170 |
+
self.update_extract = update_extract
|
| 171 |
+
self.sr = sr
|
| 172 |
+
|
| 173 |
+
if feat_type == "fbank" or feat_type == "mfcc":
|
| 174 |
+
self.update_extract = False
|
| 175 |
+
|
| 176 |
+
win_len = int(sr * 0.025)
|
| 177 |
+
hop_len = int(sr * 0.01)
|
| 178 |
+
|
| 179 |
+
if feat_type == 'fbank':
|
| 180 |
+
self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
|
| 181 |
+
hop_length=hop_len, f_min=0.0, f_max=sr // 2,
|
| 182 |
+
pad=0, n_mels=feat_dim)
|
| 183 |
+
elif feat_type == 'mfcc':
|
| 184 |
+
melkwargs = {
|
| 185 |
+
'n_fft': 512,
|
| 186 |
+
'win_length': win_len,
|
| 187 |
+
'hop_length': hop_len,
|
| 188 |
+
'f_min': 0.0,
|
| 189 |
+
'f_max': sr // 2,
|
| 190 |
+
'pad': 0
|
| 191 |
+
}
|
| 192 |
+
self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
|
| 193 |
+
melkwargs=melkwargs)
|
| 194 |
+
else:
|
| 195 |
+
if config_path is None:
|
| 196 |
+
self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
|
| 197 |
+
else:
|
| 198 |
+
self.feature_extract = UpstreamExpert(config_path)
|
| 199 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
|
| 200 |
+
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
|
| 201 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
|
| 202 |
+
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
|
| 203 |
+
|
| 204 |
+
self.feat_num = self.get_feat_num()
|
| 205 |
+
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
| 206 |
+
|
| 207 |
+
if feat_type != 'fbank' and feat_type != 'mfcc':
|
| 208 |
+
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
|
| 209 |
+
for name, param in self.feature_extract.named_parameters():
|
| 210 |
+
for freeze_val in freeze_list:
|
| 211 |
+
if freeze_val in name:
|
| 212 |
+
param.requires_grad = False
|
| 213 |
+
break
|
| 214 |
+
|
| 215 |
+
if not self.update_extract:
|
| 216 |
+
for param in self.feature_extract.parameters():
|
| 217 |
+
param.requires_grad = False
|
| 218 |
+
|
| 219 |
+
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
| 220 |
+
# self.channels = [channels] * 4 + [channels * 3]
|
| 221 |
+
self.channels = [channels] * 4 + [1536]
|
| 222 |
+
|
| 223 |
+
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
|
| 224 |
+
self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
|
| 225 |
+
self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
|
| 226 |
+
self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
|
| 227 |
+
|
| 228 |
+
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
| 229 |
+
cat_channels = channels * 3
|
| 230 |
+
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
| 231 |
+
self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
|
| 232 |
+
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
| 233 |
+
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def get_feat_num(self):
|
| 237 |
+
self.feature_extract.eval()
|
| 238 |
+
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
features = self.feature_extract(wav)
|
| 241 |
+
select_feature = features[self.feature_selection]
|
| 242 |
+
if isinstance(select_feature, (list, tuple)):
|
| 243 |
+
return len(select_feature)
|
| 244 |
+
else:
|
| 245 |
+
return 1
|
| 246 |
+
|
| 247 |
+
def get_feat(self, x):
|
| 248 |
+
if self.update_extract:
|
| 249 |
+
x = self.feature_extract([sample for sample in x])
|
| 250 |
+
else:
|
| 251 |
+
with torch.no_grad():
|
| 252 |
+
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
|
| 253 |
+
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
|
| 254 |
+
else:
|
| 255 |
+
x = self.feature_extract([sample for sample in x])
|
| 256 |
+
|
| 257 |
+
if self.feat_type == 'fbank':
|
| 258 |
+
x = x.log()
|
| 259 |
+
|
| 260 |
+
if self.feat_type != "fbank" and self.feat_type != "mfcc":
|
| 261 |
+
x = x[self.feature_selection]
|
| 262 |
+
if isinstance(x, (list, tuple)):
|
| 263 |
+
x = torch.stack(x, dim=0)
|
| 264 |
+
else:
|
| 265 |
+
x = x.unsqueeze(0)
|
| 266 |
+
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
| 267 |
+
x = (norm_weights * x).sum(dim=0)
|
| 268 |
+
x = torch.transpose(x, 1, 2) + 1e-6
|
| 269 |
+
|
| 270 |
+
x = self.instance_norm(x)
|
| 271 |
+
return x
|
| 272 |
+
|
| 273 |
+
def forward(self, x):
|
| 274 |
+
x = self.get_feat(x)
|
| 275 |
+
|
| 276 |
+
out1 = self.layer1(x)
|
| 277 |
+
out2 = self.layer2(out1)
|
| 278 |
+
out3 = self.layer3(out2)
|
| 279 |
+
out4 = self.layer4(out3)
|
| 280 |
+
|
| 281 |
+
out = torch.cat([out2, out3, out4], dim=1)
|
| 282 |
+
out = F.relu(self.conv(out))
|
| 283 |
+
out = self.bn(self.pooling(out))
|
| 284 |
+
out = self.linear(out)
|
| 285 |
+
|
| 286 |
+
return out
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
|
| 290 |
+
return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
|
| 291 |
+
feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
|
| 292 |
+
|
| 293 |
+
if __name__ == '__main__':
|
| 294 |
+
x = torch.zeros(2, 32000)
|
| 295 |
+
model = ECAPA_TDNN_SMALL(feat_dim=768, emb_dim=256, feat_type='hubert_base', feature_selection="hidden_states",
|
| 296 |
+
update_extract=False)
|
| 297 |
+
|
| 298 |
+
out = model(x)
|
| 299 |
+
# print(model)
|
| 300 |
+
print(out.shape)
|
| 301 |
+
|
UniSpeech/downstreams/speaker_verification/models/utils.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import fairseq
|
| 3 |
+
from packaging import version
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from fairseq import tasks
|
| 6 |
+
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
|
| 7 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
from s3prl.upstream.interfaces import UpstreamBase
|
| 10 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 11 |
+
|
| 12 |
+
def load_model(filepath):
|
| 13 |
+
state = torch.load(filepath, map_location=lambda storage, loc: storage)
|
| 14 |
+
# state = load_checkpoint_to_cpu(filepath)
|
| 15 |
+
state["cfg"] = OmegaConf.create(state["cfg"])
|
| 16 |
+
|
| 17 |
+
if "args" in state and state["args"] is not None:
|
| 18 |
+
cfg = convert_namespace_to_omegaconf(state["args"])
|
| 19 |
+
elif "cfg" in state and state["cfg"] is not None:
|
| 20 |
+
cfg = state["cfg"]
|
| 21 |
+
else:
|
| 22 |
+
raise RuntimeError(
|
| 23 |
+
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
task = tasks.setup_task(cfg.task)
|
| 27 |
+
if "task_state" in state:
|
| 28 |
+
task.load_state_dict(state["task_state"])
|
| 29 |
+
|
| 30 |
+
model = task.build_model(cfg.model)
|
| 31 |
+
|
| 32 |
+
return model, cfg, task
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
###################
|
| 36 |
+
# UPSTREAM EXPERT #
|
| 37 |
+
###################
|
| 38 |
+
class UpstreamExpert(UpstreamBase):
|
| 39 |
+
def __init__(self, ckpt, **kwargs):
|
| 40 |
+
super().__init__(**kwargs)
|
| 41 |
+
assert version.parse(fairseq.__version__) > version.parse(
|
| 42 |
+
"0.10.2"
|
| 43 |
+
), "Please install the fairseq master branch."
|
| 44 |
+
|
| 45 |
+
model, cfg, task = load_model(ckpt)
|
| 46 |
+
self.model = model
|
| 47 |
+
self.task = task
|
| 48 |
+
|
| 49 |
+
if len(self.hooks) == 0:
|
| 50 |
+
module_name = "self.model.encoder.layers"
|
| 51 |
+
for module_id in range(len(eval(module_name))):
|
| 52 |
+
self.add_hook(
|
| 53 |
+
f"{module_name}[{module_id}]",
|
| 54 |
+
lambda input, output: input[0].transpose(0, 1),
|
| 55 |
+
)
|
| 56 |
+
self.add_hook("self.model.encoder", lambda input, output: output[0])
|
| 57 |
+
|
| 58 |
+
def forward(self, wavs):
|
| 59 |
+
if self.task.cfg.normalize:
|
| 60 |
+
wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
|
| 61 |
+
|
| 62 |
+
device = wavs[0].device
|
| 63 |
+
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
|
| 64 |
+
wav_padding_mask = ~torch.lt(
|
| 65 |
+
torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
|
| 66 |
+
wav_lengths.unsqueeze(1),
|
| 67 |
+
)
|
| 68 |
+
padded_wav = pad_sequence(wavs, batch_first=True)
|
| 69 |
+
|
| 70 |
+
features, feat_padding_mask = self.model.extract_features(
|
| 71 |
+
padded_wav,
|
| 72 |
+
padding_mask=wav_padding_mask,
|
| 73 |
+
mask=None,
|
| 74 |
+
)
|
| 75 |
+
return {
|
| 76 |
+
"default": features,
|
| 77 |
+
}
|
| 78 |
+
|
UniSpeech/downstreams/speaker_verification/requirements.txt
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
scipy==1.7.1 \
|
| 2 |
+
--hash=sha256:2a0eeaab01258e0870c4022a6cd329aef3b7c6c2b606bd7cf7bb2ba9820ae561 \
|
| 3 |
+
--hash=sha256:3304bd5bc32e00954ac4b3f4cc382ca8824719bf348aacbec6347337d6b125fe \
|
| 4 |
+
--hash=sha256:3f52470e0548cdb74fb8ddf06773ffdcca7c97550f903b1c51312ec19243a7f7 \
|
| 5 |
+
--hash=sha256:4729b41a4cdaf4cd011aeac816b532f990bdf97710cef59149d3e293115cf467 \
|
| 6 |
+
--hash=sha256:4ee952f39a4a4c7ba775a32b664b1f4b74818548b65f765987adc14bb78f5802 \
|
| 7 |
+
--hash=sha256:611f9cb459d0707dd8e4de0c96f86e93f61aac7475fcb225e9ec71fecdc5cebf \
|
| 8 |
+
--hash=sha256:6b47d5fa7ea651054362561a28b1ccc8da9368a39514c1bbf6c0977a1c376764 \
|
| 9 |
+
--hash=sha256:71cfc96297617eab911e22216e8a8597703202e95636d9406df9af5c2ac99a2b \
|
| 10 |
+
--hash=sha256:787749110a23502031fb1643c55a2236c99c6b989cca703ea2114d65e21728ef \
|
| 11 |
+
--hash=sha256:90c07ba5f34f33299a428b0d4fa24c30d2ceba44d63f8385b2b05be460819fcb \
|
| 12 |
+
--hash=sha256:a496b42dbcd04ea9924f5e92be63af3d8e0f43a274b769bfaca0a297327d54ee \
|
| 13 |
+
--hash=sha256:bc61e3e5ff92d2f32bb263621d54a9cff5e3f7c420af3d1fa122ce2529de2bd9 \
|
| 14 |
+
--hash=sha256:c9951e3746b68974125e5e3445008a4163dd6d20ae0bbdae22b38cb8951dc11b \
|
| 15 |
+
--hash=sha256:d1388fbac9dd591ea630da75c455f4cc637a7ca5ecb31a6b6cef430914749cde \
|
| 16 |
+
--hash=sha256:d13f31457f2216e5705304d9f28e2826edf75487410a57aa99263fa4ffd792c2 \
|
| 17 |
+
--hash=sha256:d648aa85dd5074b1ed83008ae987c3fbb53d68af619fce1dee231f4d8bd40e2f \
|
| 18 |
+
--hash=sha256:da9c6b336e540def0b7fd65603da8abeb306c5fc9a5f4238665cbbb5ff95cf58 \
|
| 19 |
+
--hash=sha256:e101bceeb9e65a90dadbc5ca31283403a2d4667b9c178db29109750568e8d112 \
|
| 20 |
+
--hash=sha256:efdd3825d54c58df2cc394366ca4b9166cf940a0ebddeb87b6c10053deb625ea
|
| 21 |
+
fire==0.4.0 \
|
| 22 |
+
--hash=sha256:c5e2b8763699d1142393a46d0e3e790c5eb2f0706082df8f647878842c216a62
|
| 23 |
+
sklearn==0.0 \
|
| 24 |
+
--hash=sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31
|
| 25 |
+
s3prl==0.3.1 \
|
| 26 |
+
--hash=sha256:e497989b10d4e058b619cf3e7a547820fceb3fe18c14c566427eb7b8c770d62e
|
| 27 |
+
torchaudio==0.9.0 \
|
| 28 |
+
--hash=sha256:0a387e78eeaf6e0abd36df70e9d8a15d242b49c2507dbd9522568f5d4af5fb96 \
|
| 29 |
+
--hash=sha256:18763c05cb7d85a08b8ea960e40f6984e9513b02e76f4526d920493c701b0671 \
|
| 30 |
+
--hash=sha256:48e33bb96b7ff2dc10a778a695429dbd6dfc8c8baa0d7c9b63569cb002bb87cd \
|
| 31 |
+
--hash=sha256:62fd9393ddbe40aadaabef7595f5bff0057e39f7e519195a010731542815f5a4 \
|
| 32 |
+
--hash=sha256:76a5b8ea0e4ddafd5b8f24abdf1a6f7afe847d892570da13cf0fc9bceeac437f \
|
| 33 |
+
--hash=sha256:87520525da10b5f00d3e5e1180db6ee37b1fa305edb2260c7335e0859dbe634e \
|
| 34 |
+
--hash=sha256:9d3f5d6df7d91676e67a38a448253b74d77da723f8e24bd833ff7ed0f82fa4ef \
|
| 35 |
+
--hash=sha256:acf0d736a5c1ea6b94adf08b0a31670009b6e78dfe50a1b0bdabf2b0f7895dc0 \
|
| 36 |
+
--hash=sha256:ad221258fc5d1d446f2c1ce9a1bb54cc05ca2b208491d4eaa5af443f1c0f16a2 \
|
| 37 |
+
--hash=sha256:ba52ae64611773bec7fc664c29f9ea3e02c9e5c817693726b978ed1bdedd07f2 \
|
| 38 |
+
--hash=sha256:c6126556d529df73b676e023063388d551be3c0cb2d42a4ff5c4cfd44ef3e012 \
|
| 39 |
+
--hash=sha256:ef5f0b22646a94f95869001b40ab940468b1ae399d0ffd3bc73d5c43342a013a \
|
| 40 |
+
--hash=sha256:ef8dc4ab1ec807382a713e71e8493d1985930537c933273e3c0739f02183cedc \
|
| 41 |
+
--hash=sha256:efb16c593b2a5ada07b180580c7612617e84f4714ce86928ad54baefe71ef29d
|
| 42 |
+
sentencepiece==0.1.96 \
|
| 43 |
+
--hash=sha256:1dac8c2ad02b5ebc1179c0a14cbc7d7c6f4fd73d4dd51820626402d0aefc974e \
|
| 44 |
+
--hash=sha256:26d20d713b3ba1b7a19205336afb1e93a4327c372b2f795e907b8dc2315ac92e \
|
| 45 |
+
--hash=sha256:335bf84d72112cc91f3c3b691d61802fc963503b7772fd8280d20368048b8f3e \
|
| 46 |
+
--hash=sha256:36e9ff61e7b67c5b7ee96733613622620b4802fc8cf188a4dbc1f355b03dde02 \
|
| 47 |
+
--hash=sha256:384148cead5cdab34a4d74fe1fb6a5a8abaafed25eaa4a7698b49dd9482e4c4e \
|
| 48 |
+
--hash=sha256:3c703e68ea192e45b65c5d5836f6980849d828a18da4189899d7150fad82dc9e \
|
| 49 |
+
--hash=sha256:3e61e0757e49c306fff78ea75d6b75773418fe22214b4a460959203be934e834 \
|
| 50 |
+
--hash=sha256:466e381f0a812da8fda97a9707498cef3210ea8385a3421bcbadcb5384063969 \
|
| 51 |
+
--hash=sha256:48c6d13b3bfff08060c138248e85df60f6fad11135ad7a8fc2ef6005aacca839 \
|
| 52 |
+
--hash=sha256:4997c7ccf2ae462320250314aa5709a88d8a09fa271d073458a07bebf33f8e7c \
|
| 53 |
+
--hash=sha256:5388882bb24d083f6cc8cffc5c435f3694a7772b018e06ea6fd84d1044009efb \
|
| 54 |
+
--hash=sha256:5513298d62fe63dd0862d08a6eb52a9aa3537006f597f2386184e3f95bb88889 \
|
| 55 |
+
--hash=sha256:78e18d9106c36dcca929e18fd2c412378deac661d47fa3ee25defc55eef8a215 \
|
| 56 |
+
--hash=sha256:8179785883b556cd517416cdbda6244745414b00ec83132cfe1d26000971f3ae \
|
| 57 |
+
--hash=sha256:81bb77ba3651114943b2f8f77829cf764137dff06e38f4bf7fa43efea12c7f84 \
|
| 58 |
+
--hash=sha256:89c038da7f827a6e2ca4c73aeb4e4b25b99d981ce47dd61b04d446c8200cba1e \
|
| 59 |
+
--hash=sha256:940a6999c7d3f55e9d7b194fd5e1f41a7dbed26d3519fb95333216292a39599e \
|
| 60 |
+
--hash=sha256:99ea2d9db19e63a2d17d5dc64f9ace83fb9308a735be05a1aaf98eb4b496fba7 \
|
| 61 |
+
--hash=sha256:9bdf097d5bd1d8ce42dfee51f6ff05f5578b96e48c6f6006aa4eff69edfa3639 \
|
| 62 |
+
--hash=sha256:a336575463d75d3aac1f7e32470b8998643ccd9a73786bd726f6b0470520b6b4 \
|
| 63 |
+
--hash=sha256:a697257a2cd7581732d7741a8d32a06927f0311c3d277dbc47fa1043350c9d17 \
|
| 64 |
+
--hash=sha256:a92e1932ee8fd500680ccbe1bf53eb33228f4c9d6524ed6f300bcc80ac359f27 \
|
| 65 |
+
--hash=sha256:aeb090ad462833df03af1debce4ae607a2766ef861f992003ad0c56d074ab805 \
|
| 66 |
+
--hash=sha256:b1c24c1d9405b2148184ff27c062493d5e3be5c144575f95b5a0d7c660a515af \
|
| 67 |
+
--hash=sha256:b77d27f59d515c43b61745b8173fbe7c7b3014b14b3702a75bf1793471e7def6 \
|
| 68 |
+
--hash=sha256:b8b1dd2712f8a7de5b4c8ec912e6c041d25750bf03e1ce325cdba43bae0944ae \
|
| 69 |
+
--hash=sha256:bedf0355117fb4e9b1fc9fc92b4d5ee743a7d468be9f6196e3b94447710ea589 \
|
| 70 |
+
--hash=sha256:cc969e6694fb27fba7cee2953f350804faf03913f25ae1ee713a7b8a1bc08018 \
|
| 71 |
+
--hash=sha256:d45e3f78e746aa161bc9f5a31c6a2839c512101113a4065f4d2e7a3ab8198d8c \
|
| 72 |
+
--hash=sha256:d501713a8396193883aa526f48dc609f5f031a5df1afbafa561cf9ab492ffc76 \
|
| 73 |
+
--hash=sha256:d954d25a8705f972e8bfc1dea5464d7e697dd6f4ade092f1a487387e6d6c829a \
|
| 74 |
+
--hash=sha256:dadccb2e49244b6e64b4527d13ec14d5e094a90b41cf9b963e457e64182f1941 \
|
| 75 |
+
--hash=sha256:e811984b0908c14c56de7d8226fdd494d87a7ccb75af8ac3a07423037aaafc35 \
|
| 76 |
+
--hash=sha256:e88354b61f59dfdeb41023f7be8ae31dc627c2dc2dacbc2de8b2d82a0997135c \
|
| 77 |
+
--hash=sha256:e8ec5bb6777e2060e1499750c50e1b69dca5a0f80f90f2c66656c5f3e5244593 \
|
| 78 |
+
--hash=sha256:e9e9fe8094ca57549d801e9a2017ac5c24108bbf485ea4f8994a72e8e96ee135 \
|
| 79 |
+
--hash=sha256:eba0471ab0bb2e07ed06d91ecf5185d402c83d194155a41d8e2aa547d187712e \
|
| 80 |
+
--hash=sha256:ef59ba19340dc1d002ce5713b911c0ef23c577b08f8ed57998ee3c8e62c5bf6e \
|
| 81 |
+
--hash=sha256:f8c90df663cd9759b2cf8dd29998b63140ac39e51ada2e739dc13bdac0b4f001 \
|
| 82 |
+
--hash=sha256:f8cb24d8d0b2f8b7463815a59183eb81ec1d7a06e3217bed456063f3303eddfb \
|
| 83 |
+
--hash=sha256:fd907a8f744e5337de7fc532dd800c4416b571ea47f8c3c66be10cd1bc67c925 \
|
| 84 |
+
--hash=sha256:ff7d752a7f82d87711ec1a95c2262cb74f98be5b457f0300d81a1aefe5be2a95
|
| 85 |
+
|
UniSpeech/downstreams/speaker_verification/verification.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import soundfile as sf
|
| 2 |
+
import torch
|
| 3 |
+
import fire
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchaudio.transforms import Resample
|
| 6 |
+
from models.ecapa_tdnn import ECAPA_TDNN_SMALL
|
| 7 |
+
|
| 8 |
+
MODEL_LIST = ['ecapa_tdnn', 'hubert_large', 'wav2vec2_xlsr', 'unispeech_sat', "wavlm_base_plus", "wavlm_large"]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def init_model(model_name, checkpoint=None):
|
| 12 |
+
if model_name == 'unispeech_sat':
|
| 13 |
+
config_path = 'config/unispeech_sat.th'
|
| 14 |
+
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='unispeech_sat', config_path=config_path)
|
| 15 |
+
elif model_name == 'wavlm_base_plus':
|
| 16 |
+
config_path = None
|
| 17 |
+
model = ECAPA_TDNN_SMALL(feat_dim=768, feat_type='wavlm_base_plus', config_path=config_path)
|
| 18 |
+
elif model_name == 'wavlm_large':
|
| 19 |
+
config_path = None
|
| 20 |
+
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=config_path)
|
| 21 |
+
elif model_name == 'hubert_large':
|
| 22 |
+
config_path = None
|
| 23 |
+
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='hubert_large_ll60k', config_path=config_path)
|
| 24 |
+
elif model_name == 'wav2vec2_xlsr':
|
| 25 |
+
config_path = None
|
| 26 |
+
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wav2vec2_xlsr', config_path=config_path)
|
| 27 |
+
else:
|
| 28 |
+
model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
|
| 29 |
+
|
| 30 |
+
if checkpoint is not None:
|
| 31 |
+
state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
|
| 32 |
+
model.load_state_dict(state_dict['model'], strict=False)
|
| 33 |
+
return model
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def verification(model_name, wav1, wav2, use_gpu=True, checkpoint=None):
|
| 37 |
+
|
| 38 |
+
assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
|
| 39 |
+
model = init_model(model_name, checkpoint)
|
| 40 |
+
|
| 41 |
+
wav1, sr1 = sf.read(wav1)
|
| 42 |
+
wav2, sr2 = sf.read(wav2)
|
| 43 |
+
|
| 44 |
+
wav1 = torch.from_numpy(wav1).unsqueeze(0).float()
|
| 45 |
+
wav2 = torch.from_numpy(wav2).unsqueeze(0).float()
|
| 46 |
+
resample1 = Resample(orig_freq=sr1, new_freq=16000)
|
| 47 |
+
resample2 = Resample(orig_freq=sr2, new_freq=16000)
|
| 48 |
+
wav1 = resample1(wav1)
|
| 49 |
+
wav2 = resample2(wav2)
|
| 50 |
+
|
| 51 |
+
if use_gpu:
|
| 52 |
+
model = model.cuda()
|
| 53 |
+
wav1 = wav1.cuda()
|
| 54 |
+
wav2 = wav2.cuda()
|
| 55 |
+
|
| 56 |
+
model.eval()
|
| 57 |
+
with torch.no_grad():
|
| 58 |
+
emb1 = model(wav1)
|
| 59 |
+
emb2 = model(wav2)
|
| 60 |
+
|
| 61 |
+
sim = F.cosine_similarity(emb1, emb2)
|
| 62 |
+
print("The similarity score between two audios is {:.4f} (-1.0, 1.0).".format(sim[0].item()))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
fire.Fire(verification)
|
| 67 |
+
|
UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c76f757c218238a1bb30fa3fb03ccfe3cdb2376e411b984536238296d30b0c0
|
| 3 |
+
size 157486
|
UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc70db51fba0095b2820a1afce3781b2327fa3e74677deeaa300eac9c740679c
|
| 3 |
+
size 285486
|
UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dca90fb57b11034d3adf378a17786336497eee71176fc4966ed4fdcd5c63c142
|
| 3 |
+
size 312366
|
UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:981c26c1adc3a833872d44702ff8d0ee2da114096aa7ba11cb2ec683c3b0b95d
|
| 3 |
+
size 454446
|
UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7263ce6b8c9b7e4bc137270798c94491d5f962854ef9187f84477e2118cee7eb
|
| 3 |
+
size 541486
|
UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a09cc35b58825fac276e20a2cc71146b9c0143d5f64127235f56274a02300ea3
|
| 3 |
+
size 140846
|
UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08fbfaa3d34cbf1abfc56721ece687d306adc7b2bf2d725ca8b5f0dd10d5d84f
|
| 3 |
+
size 160046
|
UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b47b677fba76b3125537c7562b0be223f3ae7cf881bdac6d7b1c59311c02c7d
|
| 3 |
+
size 253486
|
UniSpeech/src/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
## Enforcement
|
| 56 |
+
|
| 57 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 58 |
+
reported by contacting the project team at <conduct@pytorch.org>. All
|
| 59 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 60 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 61 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 62 |
+
Further details of specific enforcement policies may be posted separately.
|
| 63 |
+
|
| 64 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 65 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 66 |
+
members of the project's leadership.
|
| 67 |
+
|
| 68 |
+
## Attribution
|
| 69 |
+
|
| 70 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 71 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 72 |
+
|
| 73 |
+
[homepage]: https://www.contributor-covenant.org
|
| 74 |
+
|
| 75 |
+
For answers to common questions about this code of conduct, see
|
| 76 |
+
https://www.contributor-covenant.org/faq
|
| 77 |
+
|
UniSpeech/src/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `master`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Facebook's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
## License
|
| 26 |
+
By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
|
| 27 |
+
you agree that your contributions will be licensed under the LICENSE file in
|
| 28 |
+
the root directory of this source tree.
|
UniSpeech/src/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
UniSpeech/src/config/config.yaml
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
common:
|
| 3 |
+
no_progress_bar: false
|
| 4 |
+
log_interval: 100
|
| 5 |
+
log_format: null
|
| 6 |
+
tensorboard_logdir: null
|
| 7 |
+
seed: 1
|
| 8 |
+
cpu: false
|
| 9 |
+
tpu: false
|
| 10 |
+
bf16: false
|
| 11 |
+
fp16: false
|
| 12 |
+
memory_efficient_fp16: false
|
| 13 |
+
memory_efficient_bf16: false
|
| 14 |
+
fp16_no_flatten_grads: false
|
| 15 |
+
fp16_init_scale: 128
|
| 16 |
+
fp16_scale_window: null
|
| 17 |
+
fp16_scale_tolerance: 0.0
|
| 18 |
+
min_loss_scale: 1.0e-4
|
| 19 |
+
threshold_loss_scale: null
|
| 20 |
+
user_dir: null
|
| 21 |
+
empty_cache_freq: 0
|
| 22 |
+
all_gather_list_size: 16384
|
| 23 |
+
model_parallel_size: 1
|
| 24 |
+
quantization_config_path: null
|
| 25 |
+
profile: false
|
| 26 |
+
distributed_training:
|
| 27 |
+
distributed_rank: 0
|
| 28 |
+
distributed_backend: "nccl"
|
| 29 |
+
distributed_init_method: null
|
| 30 |
+
distributed_port: -1
|
| 31 |
+
device_id: 0
|
| 32 |
+
local_rank: 0
|
| 33 |
+
distributed_no_spawn: false
|
| 34 |
+
ddp_backend: "c10d"
|
| 35 |
+
bucket_cap_mb: 25
|
| 36 |
+
fix_batches_to_gpus: false
|
| 37 |
+
find_unused_parameters: false
|
| 38 |
+
fast_stat_sync: false
|
| 39 |
+
broadcast_buffers: false
|
| 40 |
+
distributed_wrapper: "DDP"
|
| 41 |
+
slowmo_momentum: null
|
| 42 |
+
slowmo_algorithm: "LocalSGD"
|
| 43 |
+
localsgd_frequency: 3
|
| 44 |
+
dataset:
|
| 45 |
+
num_workers: 1
|
| 46 |
+
skip_invalid_size_inputs_valid_test: false
|
| 47 |
+
max_tokens: null
|
| 48 |
+
batch_size: null
|
| 49 |
+
required_batch_size_multiple: 8
|
| 50 |
+
dataset_impl: null
|
| 51 |
+
data_buffer_size: 10
|
| 52 |
+
train_subset: "train"
|
| 53 |
+
valid_subset: "valid"
|
| 54 |
+
validate_interval: 1
|
| 55 |
+
fixed_validation_seed: null
|
| 56 |
+
disable_validation: false
|
| 57 |
+
curriculum: 0
|
| 58 |
+
gen_subset: "test"
|
| 59 |
+
num_shards: 1
|
| 60 |
+
shard_id: 0
|
| 61 |
+
max_tokens_valid: ${dataset.max_tokens}
|
| 62 |
+
batch_size_valid: ${dataset.batch_size}
|
| 63 |
+
optimization:
|
| 64 |
+
max_epoch: 0
|
| 65 |
+
max_update: 0
|
| 66 |
+
clip_norm: 25.0
|
| 67 |
+
sentence_avg: false
|
| 68 |
+
update_freq: [ 1 ]
|
| 69 |
+
lr: [ 0.25 ]
|
| 70 |
+
min_lr: -1.0
|
| 71 |
+
use_bmuf: false
|
| 72 |
+
checkpoint:
|
| 73 |
+
save_dir: "checkpoints"
|
| 74 |
+
restore_file: "checkpoint_last.pt"
|
| 75 |
+
reset_dataloader: false
|
| 76 |
+
reset_lr_scheduler: false
|
| 77 |
+
reset_meters: false
|
| 78 |
+
reset_optimizer: false
|
| 79 |
+
optimizer_overrides: "{}"
|
| 80 |
+
save_interval: 1
|
| 81 |
+
save_interval_updates: 0
|
| 82 |
+
keep_interval_updates: -1
|
| 83 |
+
keep_last_epochs: -1
|
| 84 |
+
keep_best_checkpoints: -1
|
| 85 |
+
no_save: false
|
| 86 |
+
no_epoch_checkpoints: false
|
| 87 |
+
no_last_checkpoints: false
|
| 88 |
+
no_save_optimizer_state: false
|
| 89 |
+
best_checkpoint_metric: "loss"
|
| 90 |
+
maximize_best_checkpoint_metric: false
|
| 91 |
+
patience: -1
|
| 92 |
+
checkpoint_suffix: ""
|
| 93 |
+
bmuf:
|
| 94 |
+
block_lr: 1
|
| 95 |
+
block_momentum: 0.875
|
| 96 |
+
global_sync_iter: 50
|
| 97 |
+
warmup_iterations: 500
|
| 98 |
+
use_nbm: false
|
| 99 |
+
average_sync: false
|
| 100 |
+
defaults:
|
| 101 |
+
- task: language_modeling
|
| 102 |
+
- model: null
|
| 103 |
+
- criterion: null
|
| 104 |
+
- optimizer: null
|
| 105 |
+
- lr_scheduler: null
|
| 106 |
+
- bpe: null
|
| 107 |
+
- tokenizer: null
|
| 108 |
+
- scoring: null
|
| 109 |
+
- generation: null
|
| 110 |
+
- common_eval: null
|
| 111 |
+
- eval_lm: null
|